import pandas as pd
from providers.polygon import PolygonProvider
import seaborn as sns
import matplotlib.pyplot as plt
from io import BytesIO

# Comparison prefixes mapped to the pandas Series comparison method that implements
# them. Two-character operators come first so ">=" is not mistaken for ">".
_FILTER_OPERATORS = {
    ">=": "ge",
    "<=": "le",
    "==": "eq",
    "!=": "ne",
    ">": "gt",
    "<": "lt",
}

# Magnitude suffixes accepted on filter values, e.g. ">1M" means "> 1,000,000".
_VALUE_SUFFIXES = {"K": 1e3, "M": 1e6, "B": 1e9, "T": 1e12}


def _parse_condition(condition):
    """Split a condition such as ">1M" into (Series comparison method name, float threshold)."""
    text = str(condition).strip()
    for symbol, method in _FILTER_OPERATORS.items():
        if text.startswith(symbol):
            raw_value = text[len(symbol):].strip()
            break
    else:
        raise ValueError(
            f"Unsupported filter condition {condition!r}; expected it to start with "
            f"one of {', '.join(_FILTER_OPERATORS)}"
        )

    multiplier = 1.0
    if raw_value and raw_value[-1].upper() in _VALUE_SUFFIXES:
        multiplier = _VALUE_SUFFIXES[raw_value[-1].upper()]
        raw_value = raw_value[:-1]

    try:
        threshold = float(raw_value) * multiplier
    except ValueError as err:
        raise ValueError(
            f"Invalid numeric value in filter condition {condition!r}"
        ) from err
    return method, threshold


def filter_stocks(df, filters):
    """
    Filter stocks based on user-defined criteria.

    Each condition is a comparison operator (``>``, ``<``, ``>=``, ``<=``, ``==``,
    ``!=``) followed by a number, optionally with a K/M/B/T magnitude suffix
    (case-insensitive), e.g. ``">1M"`` or ``"<= 30"``. All conditions must hold
    (logical AND). Rows whose value is NaN never satisfy a comparison.

    :param df: DataFrame containing stock data; it is not modified
    :param filters: Dictionary mapping column name to condition
        (e.g., {"Volume": ">1M", "RSI": "<30"}); ``None`` or empty means no filtering
    :return: Filtered copy of the DataFrame
    :raises ValueError: If a condition has an unsupported operator or a
        non-numeric value
    :raises KeyError: If a filter names a column that is not in ``df``
    """
    filtered_df = df.copy()

    for column, condition in (filters or {}).items():
        method, threshold = _parse_condition(condition)
        # e.g. filtered_df[column].gt(threshold) for ">"
        mask = getattr(filtered_df[column], method)(threshold)
        filtered_df = filtered_df[mask]

    return filtered_df

def calculate_rsi(df, length=14):
    """
    Add a per-group Relative Strength Index (RSI) column to ``df`` in place.

    Rows are grouped by the first index level (presumably the stock symbol;
    confirm against the caller). RSI is computed from each group's "Close" column
    using Wilder's smoothing, an exponential moving average with
    ``alpha = 1 / length``:

        RSI = 100 * avg_gain / (avg_gain + avg_loss)

    This is algebraically equal to ``100 - 100 / (1 + RS)``.

    :param df: DataFrame with a "Close" column; modified in place
    :param length: Look-back period of the RSI; must be >= 1
    :return: The same DataFrame with an "RSI" column (0-100) added. The first
        ``length`` rows of each group are NaN (warm-up), and a group whose price
        never changes stays NaN.
    :raises ValueError: If ``length`` is smaller than 1
    :raises KeyError: If ``df`` has no "Close" column
    """
    if length < 1:
        raise ValueError(f"length must be a positive integer, got {length!r}")

    alpha = 1.0 / length

    def _rsi(close):
        delta = close.diff()
        gains = delta.clip(lower=0)
        losses = (-delta).clip(lower=0)
        # min_periods keeps the warm-up window NaN instead of reporting an RSI
        # computed from fewer than `length` price changes.
        avg_gain = gains.ewm(alpha=alpha, adjust=False, min_periods=length).mean()
        avg_loss = losses.ewm(alpha=alpha, adjust=False, min_periods=length).mean()
        # This form yields 100 (not inf) when there are no losses.
        return 100 * avg_gain / (avg_gain + avg_loss)

    # transform() returns a result aligned with df's original index, whereas
    # groupby().apply() prepends the group key as an extra index level and would
    # misalign on assignment.
    df["RSI"] = df.groupby(level=0)["Close"].transform(_rsi)
    return df

def generate_heatmap(filtered_df):
    """
    Generate a heatmap of closing prices for the filtered stocks.

    The data is pivoted to one row per "Symbol" and one column per "Timestamp",
    with "Close" as the cell value, and rendered as a PNG image.

    :param filtered_df: DataFrame with "Symbol", "Timestamp" and "Close" columns.
        Each (Symbol, Timestamp) pair must be unique, because
        ``DataFrame.pivot`` raises ``ValueError`` on duplicates.
    :return: BytesIO buffer, positioned at offset 0, containing the PNG image
    """
    pivot_df = filtered_df.pivot(index="Symbol", columns="Timestamp", values="Close")

    # Draw on an explicit figure rather than pyplot's implicit "current figure",
    # so unrelated pyplot state cannot end up in this image. Close the figure in
    # `finally` so it is not leaked if drawing or saving raises.
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        sns.heatmap(pivot_df, cmap="coolwarm", annot=False, cbar=True, ax=ax)
        ax.set_title("Stock Price Heatmap")
        ax.set_xlabel("Date")
        ax.set_ylabel("Stock Symbol")

        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        plt.close(fig)

    buffer.seek(0)
    return buffer
