Skip to content

Added Random Forest Regressor as an additional prediction model. #12767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 25, 2025
commit c8df2ccf38e382707855f83e2c10cf12a162de8e
12 changes: 7 additions & 5 deletions machine_learning/forecasting/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,21 @@
u can just adjust it for ur own purpose
"""

import logging
from warnings import simplefilter

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import Normalizer
from sklearn.svm import SVR
from statsmodels.tsa.statespace.sarimax import SARIMAX
import matplotlib.pyplot as plt

Check failure on line 23 in machine_learning/forecasting/run.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

machine_learning/forecasting/run.py:14:1: I001 Import block is un-sorted or un-formatted

logging.basicConfig(level=logging.Info)
logger = logging.getLogger(__name__)


def linear_regression_prediction(
train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list
) -> float:
Expand Down Expand Up @@ -147,13 +148,14 @@
not_safe += 1
return safe > not_safe


def plot_forecast(actual, predictions):
plt.figure(figsize=(10, 5))
plt.plot(range(len(actual)), actual, label="Actual")
plt.plot(len(actual), predictions[0], 'ro', label="Linear Reg")
plt.plot(len(actual), predictions[1], 'go', label="SARIMAX")
plt.plot(len(actual), predictions[2], 'bo', label="SVR")
plt.plot(len(actual), predictions[3], 'yo', label="RF")
plt.plot(len(actual), predictions[0], "ro", label="Linear Reg")
plt.plot(len(actual), predictions[1], "go", label="SARIMAX")
plt.plot(len(actual), predictions[2], "bo", label="SVR")
plt.plot(len(actual), predictions[3], "yo", label="RF")
plt.legend()
plt.title("Data Safety Forecast")
plt.xlabel("Days")
Expand Down Expand Up @@ -204,5 +206,5 @@
# check the safety of today's data
not_str = "" if data_safety_checker(res_vote, test_user[0]) else "not "
logger.info(f"Today's data is {not_str}safe.")

plot_forecast(train_user, res_vote)
Loading