import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from config import *
def preprocess_data():
    """Load the training text and build a character-level next-token dataset.

    Reads the UTF-8 file at ``TEXT_PATH``, maps every distinct character to
    an integer index (in sorted character order), and slices the encoded text
    into overlapping windows of ``SEQ_LENGTH + 1`` indices with stride 1.
    Each window yields an input sequence (its first ``SEQ_LENGTH`` indices)
    and a target sequence (the same window shifted one position right).

    Returns:
        A 3-tuple ``(loader, char_to_idx, idx_to_char)``:
        ``loader`` is a shuffling ``DataLoader`` over ``(inputs, targets)``
        pairs with batch size ``BATCH_SIZE``; ``char_to_idx`` maps each
        character to its index; ``idx_to_char`` is the inverse mapping.

    Raises:
        ValueError: If the text has ``SEQ_LENGTH`` characters or fewer, so
            that not even one input/target window can be formed.
    """
    # Load the raw text.
    with open(TEXT_PATH, 'r', encoding='utf-8') as f:
        text = f.read()

    # Build the character <-> index vocabularies; sorting makes the mapping
    # deterministic for a given text.
    chars = sorted(set(text))
    char_to_idx = {c: i for i, c in enumerate(chars)}
    idx_to_char = {i: c for i, c in enumerate(chars)}

    # Encode the text as an index sequence. LongTensor is 64-bit, so the
    # indices are built as int64 up front.
    text_indices = np.array([char_to_idx[c] for c in text], dtype=np.int64)

    # Without this guard an empty sample array would fail further down with
    # an opaque "too many indices" IndexError.
    if len(text_indices) <= SEQ_LENGTH:
        raise ValueError(
            f"Text at {TEXT_PATH!r} has {len(text_indices)} characters; "
            f"at least {SEQ_LENGTH + 1} are required for SEQ_LENGTH={SEQ_LENGTH}."
        )

    # One row per start offset: shape (len(text) - SEQ_LENGTH, SEQ_LENGTH + 1).
    # The window view is read-only and its rows overlap in memory, so copy it
    # into an ordinary writable array before handing slices to torch.
    samples = np.lib.stride_tricks.sliding_window_view(
        text_indices, SEQ_LENGTH + 1
    ).copy()

    # Split each window into the input and the one-step-shifted target.
    inputs = samples[:, :-1]
    targets = samples[:, 1:]

    # Wrap as a PyTorch dataset.
    dataset = TensorDataset(
        torch.LongTensor(inputs),
        torch.LongTensor(targets),
    )
    return (
        DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True),
        char_to_idx,
        idx_to_char,
    )
3. 模型定义 model.py
import torch.nn as nn
from config import DEVICE
class CharLSTM(nn.Module):
def __init__(self, vocab_size):
super().__init__()
sel