1 change: 1 addition & 0 deletions .gitignore
@@ -20,6 +20,7 @@ wheels/
.installed.cfg
*.egg
MANIFEST
.env310

# Jupyter Notebook
.ipynb_checkpoints
20 changes: 14 additions & 6 deletions finetune/dataset.py
@@ -39,7 +39,7 @@ def __init__(self, data_type: str = 'train'):
self.n_samples = self.config.n_val_iter

with open(self.data_path, 'rb') as f:
self.data = pickle.load(f)
self.data = pickle.load(f) # {symbol: pd.DataFrame(OHLCV columns, index=datetime)}

self.window = self.config.lookback_window + self.config.predict_window + 1

@@ -63,7 +63,7 @@ def __init__(self, data_type: str = 'train'):
df['day'] = df['datetime'].dt.day
df['month'] = df['datetime'].dt.month
# Keep only necessary columns to save memory.
self.data[symbol] = df[self.feature_list + self.time_feature_list]
self.data[symbol] = df[self.feature_list + self.time_feature_list] # store the OHLCV and time-feature columns as a DataFrame in the final self.data

# Add all valid starting indices for this symbol to the global list.
for i in range(num_samples):
@@ -111,16 +111,24 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:

# Extract the sliding window from the dataframe.
df = self.data[symbol]
end_idx = start_idx + self.window
end_idx = start_idx + self.window # self.window is the full time window length, T
win_df = df.iloc[start_idx:end_idx]

# Separate main features and time features.
x = win_df[self.feature_list].values.astype(np.float32)
x_stamp = win_df[self.time_feature_list].values.astype(np.float32)

# Perform instance-level normalization.
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = (x - x_mean) / (x_std + 1e-5)
# # Perform instance-level normalization.
# # Careful about data leakage here: x spans both the lookback and prediction windows,
# # and we use the lookback part to predict the future part of this same window.
# x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
# x = (x - x_mean) / (x_std + 1e-5)
# x = np.clip(x, -self.config.clip, self.config.clip)

L = self.config.lookback_window # only use the lookback window for normalization
x_hist = x[:L]
x_hist_mean, x_hist_std = np.mean(x_hist, axis=0), np.std(x_hist, axis=0)
x = (x - x_hist_mean) / (x_hist_std + 1e-5)
x = np.clip(x, -self.config.clip, self.config.clip)

# Convert to PyTorch tensors.
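
A minimal standalone sketch of the leakage-free normalization introduced above, assuming a window array of shape (lookback + predict + 1, D); the function name and clip default are hypothetical:

import numpy as np

def normalize_window(x: np.ndarray, lookback: int, clip: float = 5.0) -> np.ndarray:
    # x: (lookback + predict + 1, D). Statistics come only from the lookback rows,
    # so nothing from the prediction part of the window leaks into the inputs.
    x_hist = x[:lookback]
    mean, std = x_hist.mean(axis=0), x_hist.std(axis=0)
    x_norm = (x - mean) / (std + 1e-5)
    return np.clip(x_norm, -clip, clip)
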
8 changes: 8 additions & 0 deletions finetune/qlib_data_preprocess.py
@@ -75,6 +75,14 @@ def load_qlib_data(self):
symbol_df['amt'] = (symbol_df['open'] + symbol_df['high'] + symbol_df['low'] + symbol_df['close']) / 4 * symbol_df['vol']
symbol_df = symbol_df[self.config.feature_list]

# finally, symbol_df looks like this (e.g. for symbol 'XSHG.600000'):
#
# field (columns)      open   close  ...
# datetime (index)
# 2024-01-02           100    101    ...
# 2024-01-03           102    103    ...

# Filter out symbols with insufficient data.
symbol_df = symbol_df.dropna()
if len(symbol_df) < self.config.lookback_window + self.config.predict_window + 1:
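
A rough illustration of the DataFrame layout sketched in the comment above; the turnover proxy mirrors the hunk, while the concrete numbers and the two-row index are made up:

import pandas as pd

symbol_df = pd.DataFrame(
    {"open": [100.0, 102.0], "high": [103.0, 104.0], "low": [99.0, 101.0],
     "close": [101.0, 103.0], "vol": [1e6, 1.2e6]},
    index=pd.to_datetime(["2024-01-02", "2024-01-03"]),
)
symbol_df.index.name = "datetime"
# amt: turnover proxy, average of OHLC times volume, as in the hunk above.
symbol_df["amt"] = (symbol_df["open"] + symbol_df["high"] + symbol_df["low"] + symbol_df["close"]) / 4 * symbol_df["vol"]
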
15 changes: 11 additions & 4 deletions finetune/train_predictor.py
@@ -98,14 +98,21 @@ def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_

# Tokenize input data on-the-fly
with torch.no_grad():
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
# token_seq_0, token_seq_1 are batches of coarse and fine subtokens: [LongTensor(B, T), LongTensor(B, T)], with values in [0, 2**s1_bits - 1] and [0, 2**s2_bits - 1] respectively
# shapes go from (B, T, D) to (B, T, 1) by quantizing the last dimension into a single integer index, then squeezing to a (B, T) LongTensor
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) # coarse subtokens and fine subtokens

# Prepare inputs and targets for the language model
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
# token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
# token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
s = config['lookback_window'] - 1 # index of the last lookback position; the slices below restrict the loss to the prediction span
token_in = [token_seq_0[:, s:-1], token_seq_1[:, s:-1]]
token_out = [token_seq_0[:, s+1:], token_seq_1[:, s+1:]]

# Forward pass and loss calculation
logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
# logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
stamp_in = batch_x_stamp[:, s:-1, :]
logits = model(token_in[0], token_in[1], stamp_in)
loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

# Backward pass and optimization
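
A small sketch of the new slicing arithmetic, assuming token sequences of length lookback_window + predict_window + 1 as produced by the dataset above (the tensor here is a stand-in, not the real tokenizer output):

import torch

B, lookback, predict = 2, 4, 3
T = lookback + predict + 1                      # window length produced by the dataset
token_seq_0 = torch.randint(0, 2 ** 4, (B, T))  # stand-in for the coarse subtokens

s = lookback - 1                                # last position of the lookback context
token_in = token_seq_0[:, s:-1]                 # model inputs, length predict + 1
token_out = token_seq_0[:, s + 1:]              # next-token targets, length predict + 1
assert token_in.shape == token_out.shape == (B, predict + 1)
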
6 changes: 3 additions & 3 deletions finetune/train_tokenizer.py
@@ -124,7 +124,7 @@ def train_model(model, device, config, save_dir, logger, rank, world_size):
valid_dataset.set_epoch_seed(0) # Keep validation sampling consistent

for i, (ori_batch_x, _) in enumerate(train_loader):
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) # (B_size, Time_window, D_features)

# --- Gradient Accumulation Loop ---
current_batch_total_loss = 0.0
@@ -145,11 +145,11 @@

loss_scaled = loss / config['accumulation_steps']
current_batch_total_loss += loss.item()
loss_scaled.backward()
loss_scaled.backward() # accumulate gradients (dLoss/dθ) and free the graph, once per accumulation step

# --- Optimizer Step after Accumulation ---
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
optimizer.step()
optimizer.step() # update the weights θ from the accumulated gradients, once per effective batch
scheduler.step()
optimizer.zero_grad()

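
For context, a minimal sketch of the gradient-accumulation pattern used in the loop above (model, batches, and the loss function are placeholders):

import torch

def accumulate_and_step(model, optimizer, scheduler, batches, accumulation_steps: int = 4):
    # Accumulate gradients over several micro-batches before one optimizer update.
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches):
        loss = torch.nn.functional.mse_loss(model(x), y)
        (loss / accumulation_steps).backward()  # scale so the summed grads match one large batch
        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()                    # one weight update per accumulated batch
            scheduler.step()
            optimizer.zero_grad()
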
21 changes: 12 additions & 9 deletions model/kronos.py
@@ -56,12 +56,12 @@ def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, f
self.embed = nn.Linear(self.d_in, self.d_model)
self.head = nn.Linear(self.d_model, self.d_in)

# Encoder Transformer Blocks
# Encoder Transformer Blocks, causal
self.encoder = nn.ModuleList([
TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
for _ in range(self.enc_layers - 1)
])
# Decoder Transformer Blocks
# Decoder Transformer Blocks, causal
self.decoder = nn.ModuleList([
TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
for _ in range(self.dec_layers - 1)
@@ -91,15 +91,18 @@ def forward(self, x):
for layer in self.encoder:
z = layer(z)

z = self.quant_embed(z) # (B, T, codebook)
z = self.quant_embed(z) # (B, T, codebook(k=s1_bits+s2_bits))

bsq_loss, quantized, z_indices = self.tokenizer(z)
# quantized is the float representation before binarization into subtokens, with the same shape as z and z_indices; this float tensor is used as the decoder input.
# z_indices are the quantized indices of the codebook entries for subtoken 1 and subtoken 2, e.g. z_indices [10, 23] corresponds to subtoken1 {1,0,1,0} and subtoken2 {1,0,1,1,1} in the codebook.
bsq_loss, quantized, z_indices = self.tokenizer(z)

quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits)
z_pre = self.post_quant_embed_pre(quantized_pre)

z_pre = self.post_quant_embed_pre(quantized_pre)
z = self.post_quant_embed(quantized)

# Causal Transformer decoder layers
# Decoder layers (for pre part - s1 bits)
for layer in self.decoder:
z_pre = layer(z_pre)
@@ -153,10 +156,10 @@ def encode(self, x, half=False):
z = self.embed(x)
for layer in self.encoder:
z = layer(z)
z = self.quant_embed(z)
z = self.quant_embed(z) # (B, T, codebook(k=s1_bits+s2_bits))

bsq_loss, quantized, z_indices = self.tokenizer(z, half)
return z_indices
return z_indices # [(B, T, s1), (B, T, s2)] ==> [LongTensor(B, T), LongTensor(B, T)], the coarse and fine subtokens

def decode(self, x, half=False):
"""
@@ -258,7 +261,7 @@ def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_for
x = self.token_drop(x)

for layer in self.transformer:
x = layer(x, key_padding_mask=padding_mask)
x = layer(x, key_padding_mask=padding_mask) # padding_mask only hides padded positions (which carry no meaning) from the attention computation; it is not the causal mask

x = self.norm(x)

@@ -273,7 +276,7 @@

x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings
s2_logits = self.head.cond_forward(x2)
return s1_logits, s2_logits
return s1_logits, s2_logits # these correspond to subtoken 3 and subtoken 4 in Figure 2 of the paper

def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
"""
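
A generic bit-packing sketch that illustrates the subtoken indexing described in the comments above; this is not the repository's BSQ tokenizer, just the index arithmetic (e.g. bits {1,0,1,0} pack to 10 and {1,0,1,1,1} to 23):

import torch

def pack_bits_to_indices(codes: torch.Tensor, s1_bits: int) -> tuple[torch.Tensor, torch.Tensor]:
    # codes: (B, T, s1_bits + s2_bits) tensor of 0/1 values.
    # Split each code into a coarse part (first s1_bits) and a fine part (the rest),
    # then pack each part into one integer index per timestep.
    coarse, fine = codes[..., :s1_bits], codes[..., s1_bits:]

    def to_index(bits: torch.Tensor) -> torch.Tensor:
        weights = 2 ** torch.arange(bits.shape[-1] - 1, -1, -1, device=bits.device)
        return (bits.long() * weights).sum(dim=-1)  # (B, T) LongTensor

    return to_index(coarse), to_index(fine)
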
1 change: 1 addition & 0 deletions model/module.py
@@ -309,6 +309,7 @@ def _rotate_half(self, x):
return torch.cat((-x2, x1), dim=-1)


# scaled dot-product attention with optional causal masking (is_causal)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
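
For reference, a simplified sketch of scaled dot-product attention with an explicit causal mask, along the lines of the function this hunk annotates (not the repository's exact implementation):

import math
import torch

def causal_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # query: (..., L, d); key/value: (..., S, d). The lower-triangular mask keeps
    # each query position from attending to later (future) key positions.
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1))
    scores = query @ key.transpose(-2, -1) * scale_factor           # (..., L, S)
    causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ value                    # (..., L, d)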