import os
import glob
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import yfinance as yf
import dash
from dash import dcc, html, Input, Output, State, callback, no_update
import dash_bootstrap_components as dbc

# Register this module as a Dash page served at /grafico and labelled "Gráfica"
# (order=1 is its sort position in dash.page_registry).
dash.register_page(__name__, path="/grafico", name="Gráfica", order=1)

# Folder holding the historico_<empresa>.csv files: the "data" directory that
# sits inside the parent of this file's own directory.
DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")


def listar_empresas(data_dir=None):
    """Return the sorted identifiers of companies that have a historical CSV.

    A company ``X`` is listed when a file named ``historico_X.csv`` exists in
    the data directory. Only the leading ``historico_`` prefix and the trailing
    ``.csv`` suffix are stripped, so identifiers that contain those substrings
    elsewhere (e.g. ``historico_X.csv.csv`` -> ``X.csv``) survive intact and
    map back to the same file name.

    Args:
        data_dir: Directory to scan. ``None`` (default) uses ``DATA_DIR``.

    Returns:
        list[str]: Company identifiers in ascending order (empty if none).
    """
    carpeta = DATA_DIR if data_dir is None else data_dir
    prefijo, sufijo = "historico_", ".csv"
    # Escape the directory so characters like "[" or "*" in the path are
    # matched literally instead of being treated as glob wildcards.
    patron = os.path.join(glob.escape(carpeta), f"{prefijo}*{sufijo}")
    nombres = (os.path.basename(ruta) for ruta in glob.glob(patron))
    # The pattern guarantees prefix and suffix are present; slice them off once.
    return sorted(n[len(prefijo):len(n) - len(sufijo)] for n in nombres)


def leer_csv(empresa):
    """Load ``historico_<empresa>.csv`` from DATA_DIR as a date-indexed numeric frame.

    The first column becomes the index and is parsed as datetimes; rows whose
    index cannot be parsed are discarded and the remainder is sorted
    chronologically. File rows 1 and 2 (right after the header) are skipped —
    presumably the extra header lines written by yfinance; TODO confirm
    against whatever writes these files. Every remaining column is coerced to
    numbers, with unparsable cells becoming NaN.

    Raises:
        Whatever ``pandas.read_csv`` raises (e.g. FileNotFoundError) when the
        file is missing or unreadable.
    """
    archivo = os.path.join(DATA_DIR, f"historico_{empresa}.csv")
    tabla = pd.read_csv(archivo, index_col=0, skiprows=[1, 2], header=0)

    tabla.index = pd.to_datetime(tabla.index, errors="coerce")
    fechas_validas = tabla.index.notna()
    tabla = tabla.loc[fechas_validas].sort_index()

    for nombre in tabla.columns:
        tabla[nombre] = pd.to_numeric(tabla[nombre], errors="coerce")
    return tabla


def calcular_indicadores(df):
    """Add technical-indicator columns to *df* and return it with the price column.

    The price series is ``"Adj Close"`` when present, otherwise ``"Close"``.
    NaN prices are dropped before computing; results are aligned back onto
    *df* by index, so rows with a missing price get NaN indicators.

    Columns added (*df* is modified in place):
        SMA20 / SMA50 / SMA200: simple moving averages (partial windows allowed).
        Boll_Up / Boll_Low: SMA20 +/- 2 rolling standard deviations (20 rows).
        RSI: 14-period RSI with Wilder smoothing (EMA, ``com=13``). It is 100
            while the smoothed loss is zero and gains are positive (pure
            uptrend), and NaN while both smoothed averages are zero (flat).
        Ind_Custom: only when a ``"Volume"`` column exists;
            (price * volume) / (14-row mean price * 14-row mean volume),
            where ~1.0 is neutral.

    Returns:
        tuple[pandas.DataFrame, str]: the same *df* and the price column used.

    Raises:
        KeyError: if *df* has neither ``"Adj Close"`` nor ``"Close"``.
    """
    col = "Adj Close" if "Adj Close" in df.columns else "Close"
    s = df[col].dropna()

    df["SMA50"] = s.rolling(50, min_periods=1).mean()
    df["SMA200"] = s.rolling(200, min_periods=1).mean()
    df["SMA20"] = s.rolling(20, min_periods=1).mean()
    std20 = s.rolling(20, min_periods=1).std()
    df["Boll_Up"] = df["SMA20"] + 2 * std20
    df["Boll_Low"] = df["SMA20"] - 2 * std20

    delta = s.diff()
    gain = delta.clip(lower=0).ewm(com=13, adjust=False).mean()
    loss = (-delta.clip(upper=0)).ewm(com=13, adjust=False).mean()
    rsi = 100 - 100 / (1 + gain / loss.replace(0, np.nan))
    # Zero average loss with positive gains is a pure uptrend: RSI is 100 by
    # definition, not undefined. Both averages zero (flat price) stays NaN.
    df["RSI"] = rsi.mask((loss == 0) & (gain > 0), 100.0)

    if "Volume" in df.columns:
        vol         = df["Volume"].replace(0, np.nan)
        sma14       = s.rolling(14, min_periods=14).mean()  # 14-session mean price
        v_med14     = vol.rolling(14, min_periods=14).mean()
        denominador = sma14 * v_med14
        df["Ind_Custom"] = (s * vol) / denominador.replace(0, np.nan)  # neutral ~ 1.0

    return df, col


def precio_vivo(ticker_sym):
    """Return the latest price for *ticker_sym* from Yahoo Finance, or None.

    Best effort: any failure of a source is swallowed and the next source is
    tried, in this order:
      1. the last non-missing ``Close`` of ``Ticker.history(period="2d")``;
      2. ``Ticker.fast_info.last_price``.
    A value is accepted only if it is a finite number greater than zero. This
    keeps a trailing NaN close from being reported as a live price: NaN is
    truthy, so callers testing ``if pv:`` would otherwise accept it.

    Returns:
        float | None: the price, or None when no source produced a valid one.
    """
    ticker = yf.Ticker(ticker_sym)

    def _valido(valor):
        # Convert and validate; non-numeric, NaN/inf or <= 0 values are rejected.
        try:
            numero = float(valor)
        except (TypeError, ValueError):
            return None
        return numero if np.isfinite(numero) and numero > 0 else None

    # Attempt 1: 2-day history (the more stable endpoint)
    try:
        historia = ticker.history(period="2d")
        if "Close" in historia.columns:
            cierres = historia["Close"].dropna()
            if not cierres.empty:
                precio = _valido(cierres.iloc[-1])
                if precio is not None:
                    return precio
    except Exception:
        pass  # deliberate: fall through to the next source
    # Attempt 2: fast_info
    try:
        precio = _valido(ticker.fast_info.last_price)
        if precio is not None:
            return precio
    except Exception:
        pass  # deliberate: no live price available
    return None


# ── Layout ─────────────────────────────────────────────────────────────────────

def layout():
    """Build the "Gráfica Técnica" page layout.

    Contains:
      - a company dropdown;
      - start/end date inputs, defaulting to six months ago and today;
      - the indicator checklist;
      - a "Generar gráfica" button next to the period-return label;
      - the chart, with line-drawing and eraser tools added to its modebar;
      - a ``dcc.Store`` that remembers user-drawn shapes between callbacks.

    Because this is a function rather than a static component tree, the date
    defaults are recomputed every time the layout is built instead of being
    frozen at import time.
    """
    ahora = pd.Timestamp.today()
    # Derive both defaults from a single timestamp so they cannot disagree
    # when the two reads happen on either side of midnight.
    seis_meses = (ahora - pd.DateOffset(months=6)).strftime("%Y-%m-%d")
    hoy = ahora.strftime("%Y-%m-%d")
    return dbc.Container(fluid=True, children=[
        dbc.Row(className="mt-3 mb-2", children=[
            dbc.Col(html.H4("Gráfica Técnica"), width="auto"),
        ]),

        dbc.Row(className="g-2 align-items-end mb-3", children=[
            dbc.Col([
                dbc.Label("Empresa"),
                dcc.Dropdown(id="graf-empresa", options=[], placeholder="Selecciona empresa…", clearable=False),
            ], md=3),
            dbc.Col([
                dbc.Label("Desde"),
                dbc.Input(id="graf-inicio", type="date", value=seis_meses),
            ], md=2),
            dbc.Col([
                dbc.Label("Hasta"),
                dbc.Input(id="graf-fin", type="date", value=hoy),
            ], md=2),
            dbc.Col([
                dbc.Label("Indicadores"),
                dcc.Checklist(
                    id="graf-indicadores",
                    options=[
                        {"label": " SMA 50",      "value": "sma50"},
                        {"label": " SMA 200",     "value": "sma200"},
                        {"label": " Bollinger",   "value": "bollinger"},
                        {"label": " Fibonacci",   "value": "fibonacci"},
                        {"label": " RSI",         "value": "rsi"},
                        {"label": " Volumen",     "value": "volumen"},
                        {"label": " Índ. Vol.",   "value": "custom"},
                    ],
                    value=["sma50", "rsi", "volumen"],
                    inline=True,
                    inputStyle={"marginRight": "4px"},
                    labelStyle={"marginRight": "12px"},
                ),
            ], md=5),
        ]),

        dbc.Row(className="mb-2", children=[
            dbc.Col([
                dbc.Button("Generar gráfica", id="graf-btn", color="success", n_clicks=0),
                html.Span(id="graf-rendimiento", className="ms-4 fw-bold fs-5"),
            ])
        ]),

        dcc.Graph(
            id="graf-figura",
            config={
                "modeBarButtonsToAdd": ["drawline", "drawopenpath", "eraseshape"],
                "scrollZoom": True,
            },
            style={"height": "75vh"},
        ),

        # Store that remembers the user-drawn shapes between callbacks
        dcc.Store(id="graf-shapes-store", data=[]),
    ])


# ── Callbacks ──────────────────────────────────────────────────────────────────

@callback(
    Output("graf-empresa", "options"),
    Output("graf-empresa", "value"),
    Input("graf-empresa", "id"),
)
def cargar_opciones(_):
    """Fill the company dropdown and preselect the first company.

    Listening to the dropdown's own ``id`` makes this run once when the page
    is rendered. Labels show the identifier with underscores replaced by
    spaces; values are the raw identifiers. With no CSVs available the
    options are empty and nothing is selected.
    """
    nombres = listar_empresas()
    if not nombres:
        return [], None
    opciones = [dict(label=nombre.replace("_", " "), value=nombre) for nombre in nombres]
    return opciones, nombres[0]


def _formas_dibujadas(formas):
    """Return only the user-drawn shapes from a stored/relayout shape list.

    ``relayoutData["shapes"]`` lists every shape in the figure, including the
    horizontal lines this page generates with ``add_hline``. Those are
    referenced to an x-axis *domain* (``xref`` like ``"x domain"``), while
    modebar drawings use data coordinates. Dropping domain-referenced shapes
    therefore keeps only drawings and prevents stale Fibonacci / price / RSI
    lines from being re-added on the next render.
    """
    return [
        forma for forma in (formas or [])
        if isinstance(forma, dict)
        and not str(forma.get("xref", "")).endswith("domain")
    ]


@callback(
    Output("graf-figura", "figure"),
    Output("graf-rendimiento", "children"),
    Output("graf-rendimiento", "style"),
    Input("graf-empresa", "value"),
    Input("graf-btn", "n_clicks"),
    State("graf-inicio", "value"),
    State("graf-fin", "value"),
    State("graf-indicadores", "value"),
    State("graf-shapes-store", "data"),
    prevent_initial_call=False,
)
def actualizar_grafica(empresa, _, f_inicio, f_fin, indicadores, shapes_previas):
    """Build the technical chart for *empresa* and the period-return label.

    Runs on a company change or a click on "Generar gráfica". The date range,
    the indicator selection and the previously drawn shapes are read as state.

    Returns:
        tuple: (figure or ``no_update``, label text, label style dict).
    """
    if not empresa:
        return no_update, "Selecciona una empresa", {"color": "gray"}

    # Defensive: treat a missing checklist value as "no indicators selected".
    indicadores = indicadores or []

    try:
        df = leer_csv(empresa)
        # Also inside the try: a CSV with neither "Adj Close" nor "Close"
        # raises KeyError here.
        df, col_precio = calcular_indicadores(df)
    except Exception as e:
        return no_update, f"Error: {e}", {"color": "red"}

    if f_inicio:
        df = df[df.index >= f_inicio]
    if f_fin:
        df = df[df.index <= f_fin]

    # Rows with a missing price would otherwise leak NaN into the last price,
    # the Fibonacci levels, the axis ranges and the period return.
    precio = df[col_precio].dropna()
    if precio.empty:
        return no_update, "Sin datos para ese rango", {"color": "orange"}

    mostrar_rsi = "rsi" in indicadores
    # Only reserve a subplot / draw a series when the data actually exists.
    mostrar_vol = "volumen" in indicadores and "Volume" in df.columns
    mostrar_custom = "custom" in indicadores and "Ind_Custom" in df.columns

    n_subplots = 1 + int(mostrar_rsi) + int(mostrar_vol)
    ratios = [4] + [1.2] * (n_subplots - 1)
    subplot_titles = (
        [empresa.replace("_", " ")]
        + (["RSI (14)"] if mostrar_rsi else [])
        + (["Volumen"] if mostrar_vol else [])
    )

    fig = make_subplots(
        rows=n_subplots, cols=1,
        shared_xaxes=True,
        row_heights=ratios,
        subplot_titles=subplot_titles,
        vertical_spacing=0.04,
    )

    # Price
    fig.add_trace(go.Scatter(
        x=df.index, y=df[col_precio],
        name="Precio", line=dict(color="#1f77b4", width=1.5),
        hovertemplate="<b>%{x|%d/%m/%Y}</b>  %{y:.3f} €<extra>Precio</extra>",
    ), row=1, col=1)

    if "sma50" in indicadores:
        fig.add_trace(go.Scatter(
            x=df.index, y=df["SMA50"],
            name="SMA 50", line=dict(color="#FF9800", width=1, dash="dash"),
            hovertemplate="%{y:.3f} €<extra>SMA 50</extra>",
        ), row=1, col=1)

    if "sma200" in indicadores:
        fig.add_trace(go.Scatter(
            x=df.index, y=df["SMA200"],
            name="SMA 200", line=dict(color="#E91E63", width=1.2),
            hovertemplate="%{y:.3f} €<extra>SMA 200</extra>",
        ), row=1, col=1)

    if "bollinger" in indicadores:
        fig.add_trace(go.Scatter(
            x=df.index, y=df["Boll_Up"],
            name="Bollinger", line=dict(color="gray", width=1, dash="dashdot"),
            legendgroup="boll",
            hovertemplate="%{y:.3f} €<extra>Boll. Superior</extra>",
        ), row=1, col=1)
        # Fills down to the upper band trace added just before it.
        fig.add_trace(go.Scatter(
            x=df.index, y=df["Boll_Low"],
            line=dict(color="gray", width=1, dash="dashdot"),
            showlegend=False, legendgroup="boll",
            fill="tonexty", fillcolor="rgba(128,128,128,0.07)",
            hovertemplate="%{y:.3f} €<extra>Boll. Inferior</extra>",
        ), row=1, col=1)

    if "fibonacci" in indicadores:
        max_p, min_p = precio.max(), precio.min()
        diff = max_p - min_p
        for ratio, color in zip(
            [0, 0.236, 0.382, 0.5, 0.618, 0.786, 1],
            ["#7B1FA2", "#9C27B0", "#AB47BC", "#BA68C8", "#CE93D8", "#E1BEE7", "#7B1FA2"],
        ):
            nivel = max_p - diff * ratio
            fig.add_hline(y=nivel, line=dict(color=color, width=0.8, dash="dash"),
                          annotation_text=f"{ratio*100:.1f}% ({nivel:.2f})",
                          annotation_position="left", row=1, col=1)

    # Current price: live when available, last CSV close as fallback
    from activos import ACTIVOS
    from database import get_tickers_extra
    extra = {r["nombre"]: r["ticker"] for r in get_tickers_extra()}
    todos = {**ACTIVOS, **extra}

    precio_mostrar = float(precio.iloc[-1])
    etiqueta = f"Último cierre: {precio_mostrar:.3f} €"
    color_linea = "gray"

    if empresa in todos:
        pv = precio_vivo(todos[empresa])
        # Defensive: accept only a finite, positive live quote (NaN is truthy).
        if pv is not None and np.isfinite(pv) and pv > 0:
            precio_mostrar = pv
            etiqueta = f"VIVO: {pv:.3f} €"
            color_linea = "red"

    fig.add_hline(
        y=precio_mostrar,
        line=dict(color=color_linea, width=1.5, dash="dot"),
        annotation_text=f" {etiqueta} ",
        annotation_position="top right",
        row=1, col=1,
    )

    # Left Y axis: absolute price, with an explicit range covering everything
    # drawn on row 1 (price, visible overlays and the reference line). The
    # right-hand percentage axis is derived from the very same bounds, so both
    # scales line up; with autorange, Plotly pads the left axis on its own and
    # the two scales would drift apart.
    visibles = [precio]
    if "sma50" in indicadores:
        visibles.append(df["SMA50"])
    if "sma200" in indicadores:
        visibles.append(df["SMA200"])
    if "bollinger" in indicadores:
        visibles.extend([df["Boll_Up"], df["Boll_Low"]])
    valores = pd.concat(visibles).dropna()
    ymin = min(valores.min(), precio_mostrar) * 0.995
    ymax = max(valores.max(), precio_mostrar) * 1.005
    fig.update_yaxes(row=1, col=1, title_text="Precio (€)", range=[ymin, ymax])

    # Right Y axis: percentage vs. the current price.
    # Subplots use y, y2, ..., y{n_subplots}; the next index is free.
    pct_axis_n = n_subplots + 1
    pct_yid    = f"y{pct_axis_n}"       # reference used by the trace
    pct_ykey   = f"yaxis{pct_axis_n}"   # key used in update_layout

    pct_min = (ymin - precio_mostrar) / precio_mostrar * 100
    pct_max = (ymax - precio_mostrar) / precio_mostrar * 100

    # Invisible trace that anchors the percentage axis
    pct_series = (df[col_precio] - precio_mostrar) / precio_mostrar * 100
    fig.add_trace(go.Scatter(
        x=df.index, y=pct_series,
        yaxis=pct_yid,
        showlegend=False,
        hoverinfo="skip",
        line=dict(color="rgba(0,0,0,0)"),
        mode="lines",
    ))

    # RSI (the custom volume index shares this panel, so it only appears with RSI)
    row_rsi = 2 if mostrar_rsi else None
    if mostrar_rsi:
        fig.add_trace(go.Scatter(
            x=df.index, y=df["RSI"],
            name="RSI", line=dict(color="#9C27B0", width=1.2),
            hovertemplate="RSI: %{y:.1f}<extra></extra>",
        ), row=row_rsi, col=1)
        fig.add_hline(y=70, line=dict(color="red", dash="dash", width=0.8), row=row_rsi, col=1)
        fig.add_hline(y=30, line=dict(color="green", dash="dash", width=0.8), row=row_rsi, col=1)
        # The custom index is plotted x100 (neutral ~100), so free the 0-100 range.
        rsi_range = None if mostrar_custom else [0, 100]
        fig.update_yaxes(row=row_rsi, col=1, range=rsi_range, title_text="RSI")

        if mostrar_custom:
            ind_display = df["Ind_Custom"] * 100
            fig.add_trace(go.Scatter(
                x=df.index, y=ind_display,
                name="Índ. Vol.", line=dict(color="#00BCD4", width=1.2),
                hovertemplate="Índ.Vol: %{y:.1f}<extra></extra>",
            ), row=row_rsi, col=1)

    # Volume
    row_vol = (3 if mostrar_rsi else 2) if mostrar_vol else None
    if mostrar_vol:
        fig.add_trace(go.Bar(
            x=df.index, y=df["Volume"],
            name="Volumen", marker_color="rgba(117,117,117,0.5)",
            hovertemplate="Vol: %{y:,.0f}<extra></extra>",
        ), row=row_vol, col=1)
        fig.update_yaxes(row=row_vol, col=1, title_text="Volumen")

    spike_style = dict(
        showspikes=True, spikemode="across", spikesnap="cursor",
        spikedash="dot", spikecolor="#888", spikethickness=1,
    )
    fig.update_layout(
        template="plotly_white",
        hovermode="x unified",
        hoverdistance=50,
        legend=dict(orientation="h", yanchor="bottom", y=1.01, xanchor="right", x=1),
        margin=dict(l=50, r=70, t=60, b=40),
        newshape=dict(line_color="#2E7D32", line_width=2),
        dragmode="zoom",
        **{pct_ykey: dict(
            overlaying="y",
            side="right",
            title_text="% vs actual",
            range=[pct_min, pct_max],
            tickformat="+.1f",
            ticksuffix="%",
            showgrid=False,
            zeroline=True,
            zerolinecolor="rgba(200,0,0,0.3)",
            zerolinewidth=1,
            tickfont=dict(color="#777"),
            title_font=dict(color="#777"),
            anchor="x",
        )},
    )
    fig.update_xaxes(rangeslider_visible=False, **spike_style)
    fig.update_yaxes(**spike_style)

    # Restore previously drawn shapes by appending them. Passing them through
    # update_layout(shapes=...) would merge them element-wise into the hlines
    # already present instead of adding them.
    for forma in _formas_dibujadas(shapes_previas):
        fig.add_shape(forma)

    # Period return (first vs. last valid close in the selected range)
    p0 = precio.iloc[0]
    p1 = precio.iloc[-1]
    rend = (p1 - p0) / p0 * 100
    signo = "+" if rend >= 0 else ""
    color_rend = "#2E7D32" if rend >= 0 else "#C62828"
    texto_rend = f"{signo}{rend:.2f}% ({p0:.2f} € → {p1:.2f} €)"

    return fig, texto_rend, {"color": color_rend}


@callback(
    Output("graf-shapes-store", "data"),
    Input("graf-figura", "relayoutData"),
    State("graf-shapes-store", "data"),
    prevent_initial_call=True,
)
def guardar_shapes(relayout, shapes_actuales):
    """Persist the user's drawn shapes whenever the figure reports a shape list.

    When ``relayoutData`` contains a ``"shapes"`` key (e.g. after drawing or
    erasing), the store is replaced with that list. Any other relayout event
    (zoom, pan, ...) keeps the stored value.

    The reported list also contains the horizontal lines the page generates
    with ``add_hline``; these are referenced to an x-axis domain (e.g.
    ``xref="x domain"``). They are dropped so the store only remembers real
    drawings, and a later render does not re-add stale Fibonacci / price /
    RSI lines.
    """
    if relayout and "shapes" in relayout:
        return [
            forma for forma in (relayout["shapes"] or [])
            if isinstance(forma, dict)
            and not str(forma.get("xref", "")).endswith("domain")
        ]
    return shapes_actuales or []
