"""
This module provides tools for generating interactive plots.
Visualizations require data from a `pandas.DataFrame`, or `pandas.Series`.
Functions
---------
* `plot_time_series`: Plot a time series with optional colormap for data points.
Notes
-----
The core functionality of this module is to assist with visualizing data in an interactive
time series format, especially useful when you have a sequence of data points indexed in time order.
The colormap feature allows users to provide a visual representation based on another data series.
"""
from __future__ import annotations
import datetime
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def plot_time_series(
    series: pd.Series,
    colormap_series: pd.Series | None = None,
    **plot_kwargs,
):
    """
    Plot a time series using Plotly with an optional colormap for data points.

    Given a pandas Series with a `DatetimeIndex`, build an interactive Plotly
    time series plot (with a range selector and range slider) and display it
    via ``fig.show()``. Optionally, a second series can be supplied whose
    values determine the color of each data point.

    Parameters
    ----------
    series : pd.Series
        A `pandas.Series` object with a `DatetimeIndex` and values to plot.
    colormap_series : pd.Series, optional
        A `pandas.Series` indexed by the same timestamps as `series` whose
        values determine the color of each data point. Only timestamps present
        in both series are plotted. Colors are matched to points by index
        label, not by position. When given, the hover text also shows the
        colormap value. A dropdown menu is added that restricts the plotted
        points to one quartile of the colormap values ("All" restores them).
    plot_kwargs : dict
        Additional keyword arguments for customizing the plot appearance.
        Recognized keys are:

        * width: width of the plot in pixels. Default is 1400.
        * height: height of the plot in pixels. Default is 700.
        * title: title of the plot. Default is the name of the series.
        * connectgaps: whether to connect data points with a line. Default is True.
        * mode: mode of the plot. Default is "lines+markers".
        * marker: a dictionary of marker properties. It is copied, never
          modified in place. Keys it does not set default to: "color" (the
          colormap values), "colorscale" ("Viridis"), "size" (5) and
          "colorbar" (titled with the colormap series name).

        Other keys are ignored.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``; nothing is returned.

    Raises
    ------
    ValueError
        * If `series` is not a `pandas.Series` object.
        * If `series` index is not a `DatetimeIndex`.
        * If `colormap_series` is neither None nor a `pandas.Series` object.
        * If `colormap_series` has duplicate index labels (raised by
          `pandas.Series.reindex` while aligning colors to points).

    See Also
    --------
    plotly.subplots.make_subplots
    plotly.graph_objects.Scatter
    pandas.Series

    Examples
    --------
    >>> import pandas as pd
    >>> data = pd.Series([1, 2, 3], index=pd.date_range('2023-01-01', periods=3))
    >>> colormap = pd.Series([10, 20, 30], index=pd.date_range('2023-01-01', periods=3))
    >>> plot_time_series(data, colormap_series=colormap)  # doctest: +SKIP

    Or:

    >>> plot_time_series(data)  # doctest: +SKIP
    """
    if not isinstance(series, pd.Series):
        raise ValueError(
            f"Input 'series' should be a pandas Series object, not {type(series)!r}"
        )
    if not isinstance(series.index, pd.DatetimeIndex):
        raise ValueError("Index of the series should be a DatetimeIndex")
    if colormap_series is not None and not isinstance(colormap_series, pd.Series):
        # Only a Series can be aligned to `series` by index label.
        raise ValueError(
            "Input 'colormap_series' should be a pandas Series object or None, "
            f"not {type(colormap_series)!r}"
        )

    marker_colors, colormap_series_name = None, None
    hoverinfo = "x+y"
    hovertext = None
    if colormap_series is not None:
        # Keep only timestamps that have a colormap value, then look colors up
        # by label so they stay attached to the right point even if the two
        # indexes are ordered differently.
        series = series[series.index.isin(colormap_series.index)]
        marker_colors = colormap_series.reindex(series.index).to_numpy()
        colormap_series_name = colormap_series.name
        # Show the colormap value alongside x and y when hovering.
        hoverinfo = "text"
        hovertext = [
            f"{timestamp}<br>{series.name}: {value}<br>{colormap_series_name}: {color}"
            for timestamp, value, color in zip(series.index, series.values, marker_colors)
        ]

    width = plot_kwargs.get("width", 1400)
    height = plot_kwargs.get("height", 700)
    title = plot_kwargs.get("title", series.name)
    connectgaps = plot_kwargs.get("connectgaps", True)
    mode = plot_kwargs.get("mode", "lines+markers")
    # Build a new dict so the caller's marker dict is never mutated;
    # keys supplied by the caller take precedence over the defaults.
    user_marker = plot_kwargs.get("marker") or {}
    marker = {
        "color": marker_colors,
        "colorscale": "Viridis",
        "size": 5,
        "colorbar": {"title": colormap_series_name},
        **user_marker,
    }

    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(
        go.Scatter(
            x=series.index,
            y=series.values,
            name=series.name,
            mode=mode,
            connectgaps=connectgaps,
            marker=marker,
            hoverinfo=hoverinfo,
            hovertext=hovertext,
        )
    )

    # Dropdown that filters points by colormap quartile (only with a colormap).
    buttons = []
    if marker_colors is not None and len(series) > 0:
        buttons = _quantile_filter_buttons(
            series,
            marker_colors,
            hovertext,
            # Recolor on filter only when colors come from the colormap;
            # a caller-supplied color is left untouched.
            restyle_color="color" not in user_marker,
        )

    fig.update_layout(
        title=title,
        xaxis={
            "rangeselector": {
                "buttons": [
                    {"count": 1, "label": "1m", "step": "month", "stepmode": "backward"},
                    {"count": 6, "label": "6m", "step": "month", "stepmode": "backward"},
                    {"count": 1, "label": "1y", "step": "year", "stepmode": "backward"},
                    # Plotly's native year-to-date: start of range snaps back to
                    # January 1 of the current year (works for tz-aware data too).
                    {"count": 1, "label": "YTD", "step": "year", "stepmode": "todate"},
                    {"step": "all"},
                ]
            },
            "rangeslider": {"visible": True},
            "type": "date",
        },
        width=width,
        height=height,
        updatemenus=(
            [{"buttons": buttons, "direction": "down", "showactive": True}]
            if buttons
            else []
        ),
    )
    fig.show()


def _quantile_filter_buttons(series, marker_colors, hovertext, restyle_color=True):
    """
    Build "restyle" dropdown buttons that show only one colormap quartile.

    The first button, "All", restores every point. Each further button keeps
    the points whose colormap value ``v`` satisfies ``lower <= v < upper``
    for consecutive quartile edges of `marker_colors`. The last bin also
    includes its upper edge, so the maximum value is not dropped. NaN
    colormap values fall into no bin.

    Plotly's "restyle" treats each list-valued argument as one entry per
    trace. Because the figure has a single trace, every filtered array is
    wrapped in a one-element list. ``marker.color`` is filtered only when
    `restyle_color` is True.

    Parameters
    ----------
    series : pd.Series
        The plotted series, positionally aligned with `marker_colors`.
    marker_colors : array-like
        Colormap values, one per point of `series`.
    hovertext : list of str
        Hover text, one entry per point of `series`.
    restyle_color : bool
        Whether to filter ``marker.color`` together with the points.

    Returns
    -------
    list of dict
        Plotly updatemenu button definitions.
    """
    colors = pd.Series(marker_colors)
    color_values = colors.to_numpy()
    y_values = series.to_numpy()
    # Compute the quartile edges once, instead of once per point per bin.
    edges = colors.quantile([0.0, 0.25, 0.5, 0.75, 1.0]).tolist()

    def restyle_args(keep):
        # `keep` is a boolean array aligned positionally with `series`.
        args = {
            "x": [series.index[keep]],
            "y": [y_values[keep]],
            "hovertext": [[text for text, kept in zip(hovertext, keep) if kept]],
        }
        if restyle_color:
            args["marker.color"] = [color_values[keep]]
        return args

    all_points = pd.Series(True, index=colors.index, dtype=bool).to_numpy()
    buttons = [{"args": [restyle_args(all_points)], "label": "All", "method": "restyle"}]

    last_bin = len(edges) - 2
    for position, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
        below_upper = colors <= upper if position == last_bin else colors < upper
        in_bin = ((colors >= lower) & below_upper).to_numpy()
        buttons.append(
            {
                "args": [restyle_args(in_bin)],
                "label": f"{lower:.2f}~{upper:.2f}",
                "method": "restyle",
            }
        )
    return buttons