# Source code for iowa_forecast.models_configs

"""
This module provides classes for managing and validating configuration parameters
for various machine learning models, specifically focusing on ARIMA-based models.
The module includes a base class that standardizes the process of handling model
configuration, ensuring that all derived classes follow a consistent pattern for
validating and setting parameters.

Classes
-------
AbstractBaseModelConfig : ABC
    An abstract base class for the `BaseModelConfig` class.

BaseModelConfig : AbstractBaseModelConfig
    A base class that provides common functionality for model
    configuration, including parameter validation, default value handling, and
    error checking. Subclasses are required to define a `SUPPORTED_PARAMETERS`
    dictionary that specifies the expected parameter types, default values,
    and any valid choices.

ARIMAConfig : BaseModelConfig
    A configuration class for ARIMA model parameters. Inherits from `BaseModelConfig`
    and defines specific parameters used by ARIMA and ARIMA_PLUS models. This class
    ensures that the parameters adhere to the expected types and valid choices.

ARIMA_PLUS_XREG_Config : BaseModelConfig
    A configuration class for ARIMA_PLUS_XREG model parameters. This class extends
    `BaseModelConfig` and includes additional parameters for handling exogenous
    variables (`xreg_features`) and other settings specific to the `ARIMA_PLUS_XREG` model.

Usage
-----
These configuration classes are intended to be used in the setup and validation of
model parameters before they are passed to machine learning model training functions.
By leveraging these classes, developers can ensure that all configuration parameters
are correctly typed, are restricted to the allowed choices where choices are declared,
and that unsupported parameter names are rejected, reducing the likelihood of runtime errors.

Example
-------
>>> config = ARIMAConfig(model_type="ARIMA")
>>> print(config.model_type)
ARIMA

>>> xreg_config = ARIMA_PLUS_XREG_Config(
...     model_type="ARIMA_PLUS_XREG",
...     xreg_features=["feature1", "feature2"],
...     non_negative_forecast=True
... )
>>> print(xreg_config.xreg_features)
['feature1', 'feature2']
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple, List


class AbstractBaseModelConfig(ABC):  # pylint: disable=too-few-public-methods
    """
    Interface that every model-configuration class must satisfy.

    Concrete subclasses expose `SUPPORTED_PARAMETERS`, a read-only property
    describing which keyword arguments the configuration accepts. The class
    cannot be instantiated until that property is implemented.
    """

    @property
    @abstractmethod
    def SUPPORTED_PARAMETERS(  # pylint: disable=invalid-name
        self,
    ) -> Dict[str, Tuple[Any, Any, List[Any]]]:
        """
        Specification of the accepted configuration parameters.

        Returns
        -------
        Dict[str, Tuple[Any, Any, List[Any]]]
            Maps each parameter name to an ``(expected_type, default_value,
            choices)`` tuple, where ``choices`` lists the valid values (an
            empty list when any value of the right type is allowed).
        """
class BaseModelConfig(AbstractBaseModelConfig):
    """
    Base class for model configuration parameters.

    This class provides common functionality for handling configuration
    parameters passed via kwargs, including unpacking, validation, and
    setting default values. Subclasses must define the `SUPPORTED_PARAMETERS`
    property, which maps each parameter name to its expected type, default
    value, and any restricted choices. Validated values are readable as
    attributes (e.g. ``config.model_type``).

    Parameters
    ----------
    **kwargs : Any
        Configuration values. Every key must appear in `SUPPORTED_PARAMETERS`;
        keys that are omitted receive their declared default.

    Raises
    ------
    ValueError
        If a value has the wrong type, is not one of the declared choices,
        or if an unsupported parameter name is supplied.
    """

    @property
    def SUPPORTED_PARAMETERS(self) -> Dict[str, Tuple[Any, Any, List[Any]]]:  # pylint: disable=invalid-name
        """Return the parameter specification; the base class accepts none."""
        return {}

    def __init__(self, **kwargs):
        self._params: Dict[str, Any] = {}
        self._validate_and_set_parameters(kwargs)

    @staticmethod
    def _type_name(expected_type: Any) -> str:
        """Return a readable name for a type or a tuple of types."""
        if isinstance(expected_type, tuple):
            return " or ".join(t.__name__ for t in expected_type)
        return expected_type.__name__

    @staticmethod
    def _matches_type(value: Any, expected_type: Any) -> bool:
        """Return True if ``value`` is an instance of ``expected_type``.

        ``expected_type`` may be a single type or a tuple of types, as with
        ``isinstance``.
        """
        candidates = expected_type if isinstance(expected_type, tuple) else (expected_type,)
        # ``bool`` subclasses ``int``, so ``isinstance(True, int)`` is True.
        # Don't let ``int`` match a bool, otherwise ``True`` would be silently
        # accepted as ``1`` for integer parameters.
        return any(
            isinstance(value, candidate)
            and not (isinstance(value, bool) and candidate is int)
            for candidate in candidates
        )

    def _validate_and_set_parameters(self, kwargs: Dict[str, Any]) -> None:
        """
        Validate ``kwargs`` against `SUPPORTED_PARAMETERS` and store the result.

        Parameters missing from ``kwargs`` receive their declared default.

        Raises
        ------
        ValueError
            On a type mismatch, a value outside the declared choices, or an
            unsupported parameter name.
        """
        # Evaluate the property once; subclasses build a new dict on each access.
        supported = self.SUPPORTED_PARAMETERS
        for key, (expected_type, default_value, choices) in supported.items():
            if key not in kwargs:
                self._params[key] = default_value
                continue
            value = kwargs[key]
            if not self._matches_type(value, expected_type):
                raise ValueError(
                    f"Invalid value for parameter '{key}': expected "
                    f"{self._type_name(expected_type)}, "
                    f"but got {type(value).__name__}."
                )
            if choices and value not in choices:
                raise ValueError(
                    f"Invalid value for parameter '{key}': got '{value}', "
                    f"but expected one of {choices}."
                )
            self._params[key] = value

        # Identify unsupported parameters; sort so the message is deterministic.
        unsupported_params = set(kwargs) - set(supported)
        if unsupported_params:
            raise ValueError(
                f"Unsupported parameters provided: {', '.join(sorted(unsupported_params))}. "
                "Please check your input."
            )

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal attribute lookup fails. Read ``_params``
        # through ``__dict__``: on an instance created without ``__init__``
        # (e.g. by ``copy.copy`` or ``pickle``), ``self._params`` would call
        # back into ``__getattr__`` and recurse until RecursionError.
        params = self.__dict__.get("_params", {})
        if name in params:
            return params[name]
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __repr__(self) -> str:
        params = self.__dict__.get("_params", {})
        rendered = ", ".join(f"{key}={value!r}" for key, value in params.items())
        return f"{type(self).__name__}({rendered})"
class ARIMAConfig(BaseModelConfig):  # pylint: disable=too-few-public-methods
    """
    Configuration for `'ARIMA'` and `'ARIMA_PLUS'` models.

    `BaseModelConfig` performs the type and choice validation. This class
    only declares the accepted parameters, their expected types, their
    defaults, and (where restricted) the allowed values. The default
    `model_type` is `'ARIMA_PLUS'`.
    """

    @property
    def SUPPORTED_PARAMETERS(self) -> Dict[str, Tuple[Any, Any, List[Any]]]:
        # Each entry is (expected type, default value, allowed choices); an
        # empty choices list accepts any value of the expected type.
        frequency_options = ["AUTO_FREQUENCY", "DAILY", "WEEKLY", "MONTHLY"]
        return dict(
            model_type=(str, "ARIMA_PLUS", ["ARIMA_PLUS", "ARIMA"]),
            auto_arima=(bool, True, []),
            forecast_limit_lower_bound=(int, 0, []),
            clean_spikes_and_dips=(bool, True, []),
            decompose_time_series=(bool, True, []),
            holiday_region=(str, "US", []),
            data_frequency=(str, "AUTO_FREQUENCY", frequency_options),
            adjust_step_changes=(bool, True, []),
        )
class ARIMA_PLUS_XREG_Config(BaseModelConfig):  # pylint: disable=invalid-name, too-few-public-methods
    """
    Configuration class for `'ARIMA_PLUS_XREG'` model parameters.

    Inherits common functionality from `BaseModelConfig` and defines specific
    parameters for `'ARIMA_PLUS_XREG'` models, including validation of choices
    for some parameters.

    In addition to the ARIMA-style options, `xreg_features` holds the list of
    exogenous-variable names (defaults to an empty list).
    """

    @property
    def SUPPORTED_PARAMETERS(self) -> Dict[str, Tuple[Any, Any, List[Any]]]:
        # Each entry is (expected type, default value, allowed choices); an
        # empty choices list accepts any value of the expected type.
        return {
            "model_type": (str, "ARIMA_PLUS_XREG", ["ARIMA_PLUS_XREG"]),
            "auto_arima": (bool, True, []),
            "clean_spikes_and_dips": (bool, True, []),
            "holiday_region": (str, "US", []),
            "data_frequency": (str, "AUTO_FREQUENCY", ["AUTO_FREQUENCY", "DAILY", "WEEKLY", "MONTHLY"]),
            "adjust_step_changes": (bool, True, []),
            "non_negative_forecast": (bool, False, []),
            # Exogenous-variable names, as documented for this class. The list
            # literal is rebuilt on every property access, so instances never
            # share a mutable default.
            "xreg_features": (list, [], []),
        }