"""
Module with endpoints for specific macro steps of the project workflow.
"""
import argparse
import time
from wip.logging_config import logger
from wip.mltrainer import train_ml_models
from wip.otm import main_otm
def run_mltrainer_and_optimization():
    """Train ML models and solve the optimization problem."""
    start_time = time.time()
    # Step 1: train the predictive (ML) models.
    try:
        print(
            "\n"
            "+---------------------+\n"
            "|  Predictive Models  |\n"
            "+---------------------+\n"
        )
        train_ml_models()
        end_ml_time = time.time()
    except Exception as exc:  # pylint: disable=broad-except
        logger.exception(exc)
        logger.critical("Failed to train ML models. See above exception for details.")
        raise
    # Step 2: solve the optimization problem.
    try:
        print(
            "\n"
            "+---------------------+\n"
            "|  Optimization Model |\n"
            "+---------------------+\n"
        )
        main_otm()
        end_otm_time = time.time()
    except Exception as exc:  # pylint: disable=broad-except
        logger.exception(exc)
        logger.critical("Failed to run optimization. See above exception for details.")
        raise
    # Report wall-clock timings for each step and the full workflow.
    logger.info("Predictive models took: %ss", round(end_ml_time - start_time, 2))
    logger.info("Optimization problem took: %ss", round(end_otm_time - end_ml_time, 2))
    logger.info("Total time: %ss", round(end_otm_time - start_time, 2))
    logger.info("Finished training predictive models and solving optimization problem.")
def run_mltrainer():
    """Train ML models."""
    parser = argparse.ArgumentParser(description='Train ML models.')
    # argparse's `type` expects a callable such as `str`, not a typing hint;
    # omitted arguments fall back to their `default=None`.
    parser.add_argument('--datasets_filepath', type=str, default=None,
                        help='The filepath to the datasets.joblib file.')
    parser.add_argument('--df_sql_filepath', type=str, default=None,
                        help='The filepath to the df_sql.joblib file.')
    parser.add_argument('--outputs_folder', type=str, default=None,
                        help='The folder path where to save the outputs from `mltrainer.py`.')
    # `type=bool` would treat any non-empty string (including "False") as True,
    # so expose this option as a flag instead.
    parser.add_argument('--skip_transformations', action='store_true',
                        help='Whether to skip the data transformations before training the ML models.')
    args = parser.parse_args()
    train_ml_models(datasets_filepath=args.datasets_filepath,
                    df_sql_filepath=args.df_sql_filepath,
                    outputs_folder=args.outputs_folder,
                    skip_transformations=args.skip_transformations)
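

# Illustrative entry point, included only as a sketch: the project likely wires
# these functions up as packaged console-script endpoints, so treat this guard
# as an assumption rather than the documented invocation path.
if __name__ == "__main__":
    run_mltrainer_and_optimization()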