-
Couldn't load subscription status.
- Fork 1.8k
Description
如图所示,使用 2022 年到 2025 年的日线数据做长线预测时,预测值从预测起点之后就直接暴跌,看起来非常不正常。是模型不支持长线预测,还是我的代码有问题?
import pandas as pd
import matplotlib.pyplot as plt
import sys
sys.path.append("./")
from model import Kronos, KronosTokenizer, KronosPredictor
import matplotlib as mpl
# Configure a CJK-capable font so the Chinese labels used in the plots render.
try:
    mpl.rcParams['font.sans-serif'] = ['Heiti TC']
    # Render the minus sign as an ASCII hyphen; many CJK fonts lack U+2212.
    mpl.rcParams['axes.unicode_minus'] = False
except (KeyError, ValueError):
    # rcParams rejects unknown keys (KeyError) and invalid values (ValueError).
    # Catching only those keeps Ctrl-C and interpreter exit working, which a
    # bare ``except:`` would have swallowed.
    print("设置中文字体失败,可能会显示乱码")
def plot_prediction(historical_df, pred_df):
    """Plot observed and forecast close prices and volumes on two stacked axes.

    Args:
        historical_df: DataFrame of observed data. The ``timestamps``,
            ``close`` and ``volume`` columns are read. It may extend past the
            start of the forecast (e.g. to overlay ground truth).
        pred_df: DataFrame of forecast values indexed by timestamp. The
            ``close`` and ``volume`` columns are read.
    """
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

    # 1. Close price: observed series, then the forecast as a dashed line.
    ax1.plot(historical_df['timestamps'], historical_df['close'],
             label='历史数据', color='blue', linewidth=1.5)
    ax1.plot(pred_df.index, pred_df['close'],
             label='预测数据', color='red', linestyle='--', linewidth=1.5)

    # Mark the forecast origin: the last observed row that precedes the first
    # forecast timestamp. When the observed data ends before the forecast (or
    # there is no forecast), this is simply the last observed row.
    if len(historical_df) > 0:
        origin = historical_df.iloc[-1]
        if len(pred_df) > 0:
            before_forecast = historical_df[
                historical_df['timestamps'] < pred_df.index[0]
            ]
            if len(before_forecast) > 0:
                origin = before_forecast.iloc[-1]
        ax1.scatter(origin['timestamps'], origin['close'],
                    color='green', s=50, zorder=5, label='预测起点')

    ax1.set_ylabel('收盘价', fontsize=12)
    ax1.legend(loc='upper left', fontsize=10)
    ax1.grid(True, alpha=0.3)
    ax1.set_title('股票价格历史与预测')

    # 2. Volume: observed bars, then forecast bars.
    ax2.bar(historical_df['timestamps'], historical_df['volume'],
            alpha=0.7, label='历史成交量', color='blue', width=0.8)

    # Scale the forecast bar width to the spacing of the forecast timestamps
    # (in whole days), clamped to [0.1, 1.0].
    if len(pred_df) > 1:
        step_days = (pred_df.index[1] - pred_df.index[0]).days
        bar_width = max(0.1, min(1.0, step_days * 0.8))
    else:
        bar_width = 0.8

    ax2.bar(pred_df.index, pred_df['volume'],
            alpha=0.7, label='预测成交量', color='red', width=bar_width)
    ax2.set_ylabel('成交量', fontsize=12)
    ax2.legend(loc='upper left', fontsize=10)
    ax2.grid(True, alpha=0.3)

    # Rotate the shared x-axis tick labels so dates do not overlap.
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
# 1. Load Model and Tokenizer
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")

# 2. Instantiate Predictor
predictor = KronosPredictor(model, tokenizer, device="mps", max_context=512)

# 3. Prepare Data
df = pd.read_csv("./examples/data/01810_20220201_20251019_Daily_K.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])
total_data_points = len(df)
print(f"数据总条数: {total_data_points}")
if total_data_points == 0:
    raise ValueError("input CSV contains no rows; nothing to forecast from")

# NOTE(review): lookback (700) exceeds the predictor's max_context (512).
# Presumably only the most recent max_context rows condition the model --
# confirm against KronosPredictor and consider lookback <= 512.
lookback = 700
pred_len = 300

x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']]
x_timestamp = df.loc[:lookback-1, 'timestamps']

# Build the forecast timestamps.
# Use the real trading timestamps that follow the lookback window when the
# file has them, then extend on the business-day calendar for the remainder.
# Stepping by the *mean* historical gap is wrong for daily bars: weekends and
# holidays make the mean a fractional number of days, so generated timestamps
# drift onto weekends and arbitrary clock times.
known_future = df['timestamps'].iloc[lookback:lookback + pred_len]
missing = pred_len - len(known_future)
if missing > 0:
    anchor = known_future.iloc[-1] if len(known_future) > 0 else x_timestamp.iloc[-1]
    generated = pd.bdate_range(start=anchor + pd.offsets.BDay(1), periods=missing)
    y_timestamp = pd.concat([known_future, pd.Series(generated)], ignore_index=True)
else:
    y_timestamp = known_future.reset_index(drop=True)

print(f"历史数据长度: {len(x_df)}")
print(f"生成的未来时间戳长度: {len(y_timestamp)}")

# 4. Make Prediction
# NOTE(review): sample_count=1 presumably yields a single sampled path; a
# larger value may give a smoother long-horizon forecast -- confirm against
# the predictor's documentation.
pred_df = predictor.predict(
    df=x_df,
    x_timestamp=x_timestamp,
    y_timestamp=y_timestamp,
    pred_len=pred_len,
    T=1.0,
    top_p=0.9,
    sample_count=1,
    verbose=True
)

# 5. Visualize Results
print("Forecasted Data Head:")
print(pred_df.head(5))

# Plot with the customised plotting function; the full data set is passed so
# the observed values can be compared with the forecast.
plot_prediction(df, pred_df)