# Source code for iowa_forecast.ml_train

"""
BigQuery Model Training and Execution Module.

This module provides functionality for creating, executing, and managing
'ARIMA_PLUS_XREG' models using Google BigQuery. The module includes
functions to generate SQL queries for creating models, executing these
queries with retries, and evaluating model performance.

Functions
---------
* `create_model_query`: Generate an SQL query to create an ARIMA_PLUS_XREG model
  for a specified item and its associated data.

* `execute_query_with_retries`: Execute a given SQL query with retry logic in case of failure.

* `create_models_for_items`: Create 'ARIMA_PLUS_XREG' models for a list of items
  by executing generated SQL queries.

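* `train_arima_models`: Train ARIMA models for the specified columns,
  execute the corresponding SQL queries, and store the model metrics in BigQuery tables.

* `include_test_on_model_train`: Build the SQL fragment that appends test-table
  rows dated after the last training row to an item's ARIMA_PLUS_XREG training data.

* `include_test_on_arima_model_train`: Build the SQL fragment that appends
  test-table rows dated after the last training row to the uni-variate ARIMA training data.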

Notes
-----
This module is designed to work with Google BigQuery and requires a valid
BigQuery client instance. The models generated by this module are intended
for forecasting time series data, with options to handle holiday effects,
step changes, and data cleaning.

See Also
--------
Google BigQuery: https://cloud.google.com/bigquery
BigQuery ML: https://cloud.google.com/bigquery-ml
"""
from __future__ import annotations

import time
from typing import List

from google.cloud import bigquery  # pylint: disable=no-name-in-module
from rich.progress import track

from iowa_forecast.models_configs import ARIMA_PLUS_XREG_Config, ARIMAConfig
from iowa_forecast.utils import normalize_item_name


def create_model_query(  # pylint: disable=too-many-arguments
    item_name: str,
    timestamp_col: str = "date",
    time_series_data_col: str = "total_amount_sold",
    model_name: str = "bqmlforecast.arima_plus_xreg_model",
    train_table_name: str = "bqmlforecast.training_data",
    test_table_name: str | None = "bqmlforecast.test_data",
    **kwargs,
) -> str:
    """
    Generate a BigQuery 'CREATE MODEL' query for a specified item.

    This function constructs an SQL query to create an ARIMA_PLUS_XREG model
    in BigQuery, tailored for the provided item and its associated data.

    Parameters
    ----------
    item_name : str
        Name of the item for which the model is to be created. It is embedded
        in the query as a double-quoted string literal; backslashes and double
        quotes are escaped so the query stays well-formed.
    timestamp_col : str, default="date"
        The column name representing the timestamp in the dataset. The
        training rows are also ordered by this column.
    time_series_data_col : str, default="total_amount_sold"
        The column name representing the time series data.
    model_name : str, default="bqmlforecast.arima_plus_xreg_model"
        The base name for the model. The normalized item name is appended
        to it to form the final model name.
    train_table_name : str, default="bqmlforecast.training_data"
        The name of the table containing training data.
    test_table_name : str | None, default="bqmlforecast.test_data"
        The name of the table containing test data. When it is not a string
        (e.g. ``None``), only the training table is used.
    **kwargs : Any
        Additional keyword arguments forwarded to ``ARIMA_PLUS_XREG_Config``,
        such as:

        holiday_region : str, default="US"
            The holiday region to be used by the model.
        auto_arima : bool, default=True
            Whether to enable AUTO_ARIMA.
        adjust_step_changes : bool, default=True
            Whether to adjust for step changes in the data.
        clean_spikes_and_dips : bool, default=True
            Whether to clean spikes and dips in the data.

    Returns
    -------
    str
        A SQL query string for creating the specified model.
    """
    configs = ARIMA_PLUS_XREG_Config(**kwargs)
    item_name_norm = normalize_item_name(item_name)
    # Escape for use inside a double-quoted BigQuery string literal.
    item_name_sql = item_name.replace("\\", "\\\\").replace('"', '\\"')
    test_table_query = include_test_on_model_train(
        item_name, timestamp_col, train_table_name, test_table_name
    )
    # ORDER BY must follow the configurable timestamp column; the previous
    # hard-coded `date` broke any call with a different `timestamp_col`.
    return f"""
    CREATE OR REPLACE MODEL `{model_name}_{item_name_norm}`
    OPTIONS(
        MODEL_TYPE='ARIMA_PLUS_XREG',
        TIME_SERIES_TIMESTAMP_COL='{timestamp_col}',
        TIME_SERIES_DATA_COL='{time_series_data_col}',
        HOLIDAY_REGION='{configs.holiday_region}',
        AUTO_ARIMA={configs.auto_arima},
        ADJUST_STEP_CHANGES={configs.adjust_step_changes},
        CLEAN_SPIKES_AND_DIPS={configs.clean_spikes_and_dips}
    ) AS
    SELECT *
    FROM `{train_table_name}`
    WHERE item_name = "{item_name_sql}"
    {test_table_query}
    ORDER BY {timestamp_col}
    """
def include_test_on_model_train(
    item_name: str,
    timestamp_col: str,
    train_table_name: str,
    test_table_name: str | None = None,
) -> str:
    """
    Include test data in the model training process.

    This function generates an SQL query component that unions test data with
    training data if a test table is specified. Only test rows for
    ``item_name`` whose timestamp is later than the item's latest timestamp
    in the training table are selected.

    Parameters
    ----------
    item_name : str
        The name of the item being modeled. It is embedded as a double-quoted
        string literal; backslashes and double quotes are escaped.
    timestamp_col : str
        The column name representing the timestamp in the dataset.
    train_table_name : str
        The name of the table containing training data.
    test_table_name : str or None, optional
        The name of the table containing test data. If it is ``None``, not a
        string, or an empty string, no test data is included.

    Returns
    -------
    str
        An SQL query string component to include test data, or ``""`` when
        there is no test table.
    """
    # An empty name would otherwise render as FROM `` and break the query.
    if not isinstance(test_table_name, str) or not test_table_name:
        return ""
    item_name_sql = item_name.replace("\\", "\\\\").replace('"', '\\"')
    return f"""
    UNION ALL
    (
        SELECT * FROM (
            SELECT t2.*
            FROM `{test_table_name}` AS t2
            JOIN (
                SELECT item_name, MAX({timestamp_col}) AS max_date
                FROM `{train_table_name}`
                GROUP BY item_name
            ) AS md
            ON t2.item_name = md.item_name
            WHERE t2.{timestamp_col} > md.max_date
            AND t2.item_name = "{item_name_sql}"
        )
    )
    """
def include_test_on_arima_model_train(
    column: str,
    time_series_timestamp_col: str,
    time_series_id_col: str,
    train_table_name: str,
    test_table_name: str | None = None,
) -> str:
    """
    Include test data in the uni-variate ARIMA model training process.

    This function generates an SQL query component that unions test data with
    training data if a test table is specified. For each series identifier,
    only test rows whose timestamp is later than that identifier's latest
    timestamp in the training table are selected.

    Parameters
    ----------
    column : str
        The name of the feature being modeled.
    time_series_timestamp_col : str
        The column name representing the timestamp in the dataset.
    time_series_id_col : str
        The column name representing the identifier.
    train_table_name : str
        The name of the table containing training data.
    test_table_name : str or None, optional
        The name of the table containing test data. If it is ``None``, not a
        string, or an empty string, no test data is included.

    Returns
    -------
    str
        An SQL query string component to include test data, or ``""`` when
        there is no test table.
    """
    # An empty name would otherwise render as FROM `` and break the query.
    if not isinstance(test_table_name, str) or not test_table_name:
        return ""
    # Selected column order must match the training SELECT for UNION ALL.
    return f"""
    UNION ALL
    (
        SELECT * FROM (
            SELECT t2.{time_series_timestamp_col}, t2.{column}, t2.{time_series_id_col}
            FROM `{test_table_name}` AS t2
            JOIN (
                SELECT {time_series_id_col}, MAX({time_series_timestamp_col}) AS max_date
                FROM `{train_table_name}`
                GROUP BY {time_series_id_col}
            ) AS md
            ON t2.{time_series_id_col} = md.{time_series_id_col}
            WHERE t2.{time_series_timestamp_col} > md.max_date
        )
    )
    """
def execute_query_with_retries(
    client: bigquery.Client,
    query: str,
    max_retries: int = 3,
    base_delay: float = 120,
) -> None:
    """
    Execute a BigQuery SQL query with retries on failure.

    The query is submitted with ``client.query`` and waited on with
    ``.result()``. If an attempt raises, the error is printed, the function
    sleeps, and the query is submitted again. At most ``max_retries``
    attempts are made in total. If the final attempt also fails, its
    exception is re-raised instead of being swallowed.

    Parameters
    ----------
    client : bigquery.Client
        Instance of the BigQuery client used to execute the query.
    query : str
        The SQL query to be executed.
    max_retries : int, default=3
        Maximum number of attempts. When it is smaller than 1, the query is
        not executed at all.
    base_delay : float, default=120
        Seconds to sleep after the first failed attempt. The delay after the
        n-th failed attempt is ``base_delay * n``.

    Raises
    ------
    Exception
        The exception raised by the last attempt, if every attempt fails.

    Notes
    -----
    The delay between retries increases linearly: ``base_delay`` seconds
    multiplied by the current attempt number. No sleep happens after the
    final attempt.

    Examples
    --------
    Execute a query with the default number of retries:

    >>> client = bigquery.Client()
    >>> query = "SELECT * FROM `my_dataset.my_table`"
    >>> execute_query_with_retries(client, query)
    """
    for attempt in range(1, max_retries + 1):
        try:
            query_job = client.query(query)
            query_job.result()
        except Exception as exc:  # pylint: disable=broad-except
            print(exc)
            if attempt >= max_retries:
                # Out of attempts: surface the failure as documented rather
                # than returning as if the query had succeeded.
                print(f"Attempt {attempt} failed. No retries left.")
                raise
            sleep_time = base_delay * attempt
            print(f"Attempt {attempt} failed. Sleeping for {sleep_time} seconds...")
            time.sleep(sleep_time)
        else:
            return
def create_models_for_items(  # pylint: disable=too-many-arguments
    client: bigquery.Client,
    items_list: List[str],
    max_items: int | None = None,
    timestamp_col: str = "date",
    time_series_data_col: str = "total_amount_sold",
    model_name: str = "bqmlforecast.arima_plus_xreg_model",
    train_table_name: str = "bqmlforecast.training_data",
    test_table_name: str | None = "bqmlforecast.test_data",
    **kwargs,
) -> None:
    """
    Create `'ARIMA_PLUS_XREG'` models for a list of items.

    For every selected item, a `'CREATE MODEL'` statement is built with
    `create_model_query` and run through `execute_query_with_retries`.
    Progress is reported with a ``rich`` progress bar.

    Parameters
    ----------
    client : bigquery.Client
        BigQuery client used to run each query.
    items_list : List[str]
        Names of the items to build models for.
    max_items : int or None, default=None
        When an ``int``, only the first ``max_items`` entries of `items_list`
        are processed; otherwise every item is processed. See 'Notes'.
    timestamp_col : str, default="date"
        Column holding the timestamp of each row.
    time_series_data_col : str, default="total_amount_sold"
        Column holding the value being forecast.
    model_name : str, default="bqmlforecast.arima_plus_xreg_model"
        Base name shared by the generated models.
    train_table_name : str, default="bqmlforecast.training_data"
        Table holding the training rows.
    test_table_name : str | None, default="bqmlforecast.test_data"
        Table holding the test rows. With `None`, models are trained on
        `train_table_name` alone. See 'Notes'.
    **kwargs : Any
        Extra options passed through to `create_model_query`, for example:

        holiday_region : str, default="US"
            Holiday region used by the models.
        auto_arima : bool, default=True
            Whether `'AUTO_ARIMA'` is enabled.
        adjust_step_changes : bool, default=True
            Whether step changes in the data are adjusted for.
        clean_spikes_and_dips : bool, default=True
            Whether spikes and dips in the data are cleaned.

    Notes
    -----
    Processing every item (leaving `max_items` unset) requires a Google Cloud
    account with billing enabled. Without billing, keep `max_items` at 4 or
    fewer.

    .. important::

        With billing enabled, running this code might incur charges.

    When evaluating the models, avoid training on all available data: set
    `test_table_name` to `None` so training uses only `train_table_name`,
    leaving the rows in the test table available for evaluation.
    """
    if isinstance(max_items, int):
        selected_items = items_list[:max_items]
    else:
        selected_items = items_list

    for current_item in track(selected_items, description="Creating models..."):
        model_query = create_model_query(
            current_item,
            timestamp_col=timestamp_col,
            time_series_data_col=time_series_data_col,
            model_name=model_name,
            train_table_name=train_table_name,
            test_table_name=test_table_name,
            **kwargs,
        )
        execute_query_with_retries(client, model_query)
def _arima_training_query(  # pylint: disable=too-many-arguments
    *,
    model_name: str,
    column: str,
    config: ARIMAConfig,
    horizon: int,
    time_series_timestamp_col: str,
    time_series_id_col: str,
    train_table_name: str,
    test_data_query: str,
) -> str:
    """Build the 'CREATE OR REPLACE MODEL' statement for one ARIMA model."""
    return f"""
    CREATE OR REPLACE MODEL `{model_name}`
    OPTIONS(
        MODEL_TYPE = '{config.model_type}',
        AUTO_ARIMA = {config.auto_arima},
        HORIZON = {horizon},
        TIME_SERIES_TIMESTAMP_COL = '{time_series_timestamp_col}',
        TIME_SERIES_DATA_COL = '{column}',
        TIME_SERIES_ID_COL = '{time_series_id_col}',
        FORECAST_LIMIT_LOWER_BOUND = {config.forecast_limit_lower_bound},
        DECOMPOSE_TIME_SERIES = {config.decompose_time_series},
        HOLIDAY_REGION = '{config.holiday_region}',
        DATA_FREQUENCY = '{config.data_frequency}',
        ADJUST_STEP_CHANGES = {config.adjust_step_changes},
        CLEAN_SPIKES_AND_DIPS = {config.clean_spikes_and_dips}
    ) AS
    SELECT
        {time_series_timestamp_col},
        {column},
        {time_series_id_col}
    FROM `{train_table_name}`
    {test_data_query}
    """


def _arima_metrics_query(  # pylint: disable=too-many-arguments
    *,
    model_name: str,
    column: str,
    model_metrics_table_name: str,
    test_table_name: str,
    time_series_timestamp_col: str,
    time_series_id_col: str,
    horizon: int,
    confidence_level: float,
) -> str:
    """Build the statement that stores ``ML.EVALUATE`` output for one model."""
    return f"""
    CREATE OR REPLACE TABLE `{model_metrics_table_name}_{column}` AS (
        SELECT * FROM ML.EVALUATE(
            MODEL `{model_name}`,
            (
                SELECT
                    {time_series_timestamp_col},
                    {time_series_id_col},
                    {column}
                FROM `{test_table_name}`
            ),
            STRUCT({horizon} AS HORIZON, {confidence_level} AS CONFIDENCE_LEVEL)
        )
    )
    """


def train_arima_models(  # pylint: disable=too-many-locals, too-many-arguments
    client: bigquery.Client,
    columns: List[str],
    model: str = "bqmlforecast.arima_model",
    train_table_name: str = "bqmlforecast.training_data",
    test_table_name: str | None = "bqmlforecast.test_data",
    model_metrics_table_name: str | None = "bqmlforecast.arima_model_metrics",
    time_series_timestamp_col: str = "date",
    time_series_id_col: str = "item_name",
    confidence_level: float = 0.9,
    horizon: int = 7,
    use_test_data_on_train: bool = True,
    **kwargs,
) -> None:
    """
    Train ARIMA models for a list of columns and store their metrics.

    For each column, this function runs a `'CREATE OR REPLACE MODEL'` query
    for an ARIMA model named ``f"{model}_{column}"``. When metrics are
    requested and a test table is available, it then stores the output of
    ``ML.EVALUATE`` in a table named
    ``f"{model_metrics_table_name}_{column}"``. These ARIMA models are then
    used to generate the future feature values used for forecasting the
    liquor sales.

    Parameters
    ----------
    client : bigquery.Client
        Instance of the BigQuery client used to execute queries.
    columns : List[str]
        List of columns to be used for creating ARIMA models.
    model : str, default="bqmlforecast.arima_model"
        The base name for the ARIMA models.
    train_table_name : str, default="bqmlforecast.training_data"
        The name of the table containing training data.
    test_table_name : str | None, default="bqmlforecast.test_data"
        The name of the table containing test data. If it is not a non-empty
        string, test data is neither added to training nor used for
        evaluation, and no metrics tables are created.
    model_metrics_table_name : str or None, default="bqmlforecast.arima_model_metrics"
        The base name for the tables where model metrics will be stored.
        If it is not a string, no metrics are computed.
    time_series_timestamp_col : str, default="date"
        The column name representing the timestamp in the dataset.
    time_series_id_col : str, default="item_name"
        The column name representing the identifier for the time series.
    confidence_level : float, default=0.9
        The confidence level used in the model evaluation.
    horizon : int, default=7
        The number of time steps (days) to forecast.
    use_test_data_on_train : bool, default=True
        Whether to append test rows dated after the last training row to the
        training data.
    **kwargs : Any
        Forwarded to ``ARIMAConfig``. The attributes read from the resulting
        config are ``model_type``, ``auto_arima``,
        ``forecast_limit_lower_bound``, ``decompose_time_series``,
        ``holiday_region``, ``data_frequency``, ``adjust_step_changes`` and
        ``clean_spikes_and_dips``.

    Returns
    -------
    None

    Notes
    -----
    When ``use_test_data_on_train`` is True, rows from ``test_table_name``
    can be part of the training data and are also the rows passed to
    ``ML.EVALUATE``. The stored metrics may therefore not reflect
    out-of-sample performance.
    """
    config = ARIMAConfig(**kwargs)
    # ML.EVALUATE needs a real table to read from; without this guard a
    # missing test table rendered as FROM `None` and the query failed.
    evaluate_models = (
        isinstance(model_metrics_table_name, str)
        and isinstance(test_table_name, str)
        and bool(test_table_name)
    )

    for column in track(columns, description="Creating ARIMA models..."):
        model_name = f"{model}_{column}"

        test_data_query = ""
        if use_test_data_on_train:
            test_data_query = include_test_on_arima_model_train(
                column,
                time_series_timestamp_col,
                time_series_id_col,
                train_table_name,
                test_table_name,
            )

        train_arima_query = _arima_training_query(
            model_name=model_name,
            column=column,
            config=config,
            horizon=horizon,
            time_series_timestamp_col=time_series_timestamp_col,
            time_series_id_col=time_series_id_col,
            train_table_name=train_table_name,
            test_data_query=test_data_query,
        )
        client.query(train_arima_query).result()

        if evaluate_models:
            model_metrics_query = _arima_metrics_query(
                model_name=model_name,
                column=column,
                model_metrics_table_name=model_metrics_table_name,
                test_table_name=test_table_name,
                time_series_timestamp_col=time_series_timestamp_col,
                time_series_id_col=time_series_id_col,
                horizon=horizon,
                confidence_level=confidence_level,
            )
            client.query(model_metrics_query).result()