from config.db_config import get_db_connection

# Watchlist Functions
def add_to_watchlist(user_id, symbol):
    """Add ``symbol`` to the watchlist belonging to ``user_id``.

    Inserts a ``(user_id, symbol)`` row into the ``watchlist`` table and
    commits. Because the statement ends in
    ``ON CONFLICT (user_id, symbol) DO NOTHING``, adding a symbol that is
    already on the user's watchlist is a silent no-op rather than an error.

    Args:
        user_id: Owner of the watchlist; sent to the database as a bound
            parameter, never interpolated into the SQL text.
        symbol: Symbol to add; also sent as a bound parameter.

    The connection is closed on every path, including when the insert raises.
    """
    # NOTE(review): ON CONFLICT (user_id, symbol) requires a unique constraint
    # or unique index on those columns in the database — confirm the schema.
    params = (user_id, symbol)
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO watchlist (user_id, symbol)
                VALUES (%s, %s)
                ON CONFLICT (user_id, symbol) DO NOTHING;
            """, params)
        # Reached only if execute() succeeded. On failure commit() is skipped
        # and the pending transaction is left for close() to discard
        # (driver-dependent — presumably a rollback).
        conn.commit()
    finally:
        conn.close()

def remove_from_watchlist(user_id, symbol):
    """Remove ``symbol`` from the watchlist belonging to ``user_id``.

    Deletes every ``watchlist`` row whose ``user_id`` and ``symbol`` both
    match, then commits. If no row matches, the DELETE simply affects zero
    rows and no error is raised.

    Args:
        user_id: Owner of the watchlist; sent as a bound parameter.
        symbol: Symbol to remove; sent as a bound parameter.

    The connection is closed on every path, including when the delete raises.
    """
    params = (user_id, symbol)
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("""
                DELETE FROM watchlist
                WHERE user_id = %s AND symbol = %s;
            """, params)
        # Only committed when execute() completed without raising.
        conn.commit()
    finally:
        conn.close()

def get_watchlist(user_id):
    """Return the symbols on the watchlist belonging to ``user_id``.

    Args:
        user_id: Owner whose watchlist is read; sent as a bound parameter.

    Returns:
        A list holding the first column (``symbol``) of every matching row.
        The order is whatever the database returns, because the query has no
        ``ORDER BY``. The list is empty when the user has no entries.

    The query only reads data, so no commit is issued. The connection is
    closed before the result is returned or an error propagates.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT symbol FROM watchlist
                WHERE user_id = %s;
            """, (user_id,))
            # Index by position, not unpacking, so dict-like rows are
            # handled exactly as before.
            symbols = [record[0] for record in cur.fetchall()]
        return symbols
    finally:
        conn.close()