Skip to content

Added Random Forest Regressor as an additional prediction model. #12767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Update run.py
Used matplotlib to plot the actual vs. predicted user count, forecast confidence intervals, and outlier thresholds derived from the IQR.
Replaced print() with logging, because print() output does not scale in production.
  • Loading branch information
priyanshu-8789 authored May 25, 2025
commit 6c2f7b48ead29db0005e09649091d102b989d33d
27 changes: 23 additions & 4 deletions machine_learning/forecasting/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@
you can just adjust it for your own purpose
"""

import logging
from warnings import simplefilter

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import Normalizer
from sklearn.svm import SVR
from statsmodels.tsa.statespace.sarimax import SARIMAX
import matplotlib.pyplot as plt

Check failure on line 23 in machine_learning/forecasting/run.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

machine_learning/forecasting/run.py:14:1: I001 Import block is un-sorted or un-formatted

# Configure the root logger at INFO so logger.info() messages are emitted.
# The level constant is upper-case: the logging module has no attribute `Info`,
# so `logging.Info` raises AttributeError as soon as the module is imported.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

def linear_regression_prediction(
train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list
Expand Down Expand Up @@ -143,6 +147,21 @@
not_safe += 1
return safe > not_safe

def plot_forecast(actual, predictions):
    """
    Show the observed series next to the four model forecasts.

    ``actual`` is drawn as a line over positions 0 .. len(actual) - 1.
    ``predictions[0]`` .. ``predictions[3]`` are each drawn as one coloured
    marker at position len(actual), labelled in order as
    "Linear Reg", "SARIMAX", "SVR" and "RF".

    Nothing is returned; the figure is displayed with ``plt.show()``.
    An IndexError is raised if ``predictions`` has fewer than four entries.
    """
    # (matplotlib format string, legend label) for predictions[0..3], in order
    marker_styles = (
        ("ro", "Linear Reg"),
        ("go", "SARIMAX"),
        ("bo", "SVR"),
        ("yo", "RF"),
    )

    plt.figure(figsize=(10, 5))

    # Forecast markers sit one step past the last observed point.
    next_position = len(actual)
    plt.plot(range(next_position), actual, label="Actual")
    for index, (fmt, legend_name) in enumerate(marker_styles):
        plt.plot(next_position, predictions[index], fmt, label=legend_name)

    plt.legend()
    plt.title("Data Safety Forecast")
    plt.xlabel("Days")
    plt.ylabel("Normalized User Count")
    plt.grid(True)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
"""
Expand Down Expand Up @@ -179,11 +198,11 @@
),
sarimax_predictor(train_user, train_match, test_match),
support_vector_regressor(x_train, x_test, train_user),
random_forest_regressor(
x_train, x_test, train_user
), # Added Random Forest Regressor
random_forest_regressor(x_train, x_test, train_user),
]

# check the safety of today's data
not_str = "" if data_safety_checker(res_vote, test_user[0]) else "not "
print(f"Today's data is {not_str}safe.")
logger.info(f"Today's data is {not_str}safe.")

Check failure on line 207 in machine_learning/forecasting/run.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

machine_learning/forecasting/run.py:207:1: W293 Blank line contains whitespace
plot_forecast(train_user, res_vote)
Loading