# Preface
It has been quite a while since I last wrote a blog post. With the Spring Festival holiday just over, I decided on a whim to use DeepSeek to put one together, both as a record for my future studies and as a small trace of my ongoing graduation project.
The topic of this post is "Deep Learning Model Training Code", that is, how to write the train.py file used in deep learning projects. Anyone who has written training code knows that the basic logic of train.py is largely the same: load the dataset, build the model, compute the loss, run gradient-based optimization, save the results, and so on. It sounds simple, but writing code that is both clean and fully featured still takes some effort. So it makes sense to set up this framework once and, in the future, simply follow the template and make small adjustments as needed.
This post will be updated as I continue to put it into practice.
# Coding Guidelines
To get DeepSeek to write a good training-code framework, the first step is to spell out what I consider good coding guidelines and requirements. For me personally, a good train.py implementation should satisfy the following:
Modular design:
- Separate data loading, model construction, training logic, logging, and other concerns into their own modules.
- Give each function a clear responsibility so the code is easy to debug and reuse.
Configurability:
- Load default parameters from a config.json configuration file.
- Support command-line arguments that override the configuration.
- Parameter priority: command line > config file > code defaults.
Device management:
- Automatically detect whether a GPU is available and assign the device accordingly.
Reproducibility:
- Fix the random seeds (controlling random, numpy, torch, etc.; see the sketch after this list).
- Save the complete configuration used for training (back up config.json).
Logging and monitoring:
- Record metrics such as loss and accuracy during training.
- Save a log file in addition to the console output.
- Support visualization tools such as TensorBoard or WandB.
Model saving and resuming:
- Periodically save model checkpoints (weights and optimizer state).
- Support resuming training from a checkpoint.
Progress feedback:
- Show a training progress bar (e.g., with tqdm).
- Print a summary of the metrics for each epoch.
Exception handling:
- Handle problems such as missing file paths and invalid parameters.
- Catch keyboard interrupts (Ctrl+C) and save the model safely before exiting.
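The framework below only calls torch.manual_seed and does not back up the configuration, so here is a minimal sketch of those two reproducibility items. The helper names set_seed and backup_config are my own additions; wire them into main() wherever convenient.

```python
import os
import json
import random
import numpy as np
import torch

def set_seed(seed: int = 42):
    # Fix every relevant random number generator for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def backup_config(config: dict, log_dir: str):
    # Save the exact configuration used for this run alongside the logs
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, "config_backup.json"), "w") as f:
        json.dump(config, f, indent=2)
```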
# Code Framework
Based on the requirements above, I asked DeepSeek to produce a detailed training-code framework.
# Training Code
The main body of the training code, train.py, is as follows:
```python
# train.py
import torch
import argparse
import json
import os
import logging
import torch.nn as nn
from datetime import datetime
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


# Define the model architecture based on the configuration
def build_model(config):
    model = nn.Sequential(
        nn.Linear(config["input_dim"], config["hidden_dim"]),
        nn.ReLU(),
        nn.Linear(config["hidden_dim"], config["output_dim"])
    )
    return model


# Load and preprocess the dataset (replace the ellipses with your own Dataset / DataLoader)
def load_data(config):
    train_loader = DataLoader(...)
    val_loader = DataLoader(...)
    return train_loader, val_loader


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="config.json", help="Path to the configuration file")
    parser.add_argument("--lr", type=float, help="Learning rate (overrides config file)")
    parser.add_argument("--batch_size", type=int, help="Batch size (overrides config file)")
    parser.add_argument("--epochs", type=int, help="Number of epochs (overrides config file)")
    parser.add_argument("--resume_checkpoint", type=str, help="Checkpoint to resume from (overrides config file)")
    # Add other overridable parameters as needed...
    return parser.parse_args()


# Load the configuration file
def load_config(config_path):
    with open(config_path, "r") as f:
        config = json.load(f)
    return config


# Configure logging to both file and console
def setup_logging(config):
    log_dir = config["log_dir"]
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"train_{timestamp}.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )


# Train the model on the train set
def train_model(model, train_loader, val_loader, config, device, writer):
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    criterion = nn.CrossEntropyLoss()

    # Resume training from a checkpoint if specified
    start_epoch = 0
    if config.get("resume_checkpoint"):
        checkpoint = torch.load(config["resume_checkpoint"], map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        logging.info(f"Resuming training from epoch {start_epoch}")

    for epoch in range(start_epoch, config["epochs"]):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
        for batch_idx, batch in enumerate(progress_bar):
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

            # Log training loss to TensorBoard
            writer.add_scalar("Loss/train", loss.item(), epoch * len(train_loader) + batch_idx)

        # Validation step
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        logging.info(
            f"Epoch {epoch+1}: "
            f"Train Loss: {train_loss/len(train_loader):.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        )

        # Log validation metrics to TensorBoard
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)

        # Save checkpoint periodically
        if (epoch + 1) % config["save_interval"] == 0:
            checkpoint_path = os.path.join(config["checkpoint_dir"], f"model_epoch_{epoch+1}.pt")
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            }, checkpoint_path)


# Evaluate the model on the validation set
def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
    return total_loss / len(val_loader), correct / len(val_loader.dataset)


# Main function
def main():
    # Parse command-line arguments and load the configuration
    args = parse_args()
    config = load_config(args.config)

    # Override config with command-line arguments (command line takes priority)
    for key, value in vars(args).items():
        if key != "config" and value is not None:
            config[key] = value

    # Validate required configuration keys (optional)
    required_keys = ["input_dim", "hidden_dim", "output_dim", "lr", "batch_size"]
    for key in required_keys:
        if key not in config:
            raise ValueError(f"Missing required key in config: {key}")

    # Initialize logging, output directories, device, and random seed
    setup_logging(config)
    os.makedirs(config["checkpoint_dir"], exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(config.get("seed", 42))

    # Initialize TensorBoard writer
    writer = SummaryWriter(log_dir=os.path.join(config["log_dir"], "tensorboard"))

    # Build the model and load data
    model = build_model(config).to(device)
    train_loader, val_loader = load_data(config)

    # Log model graph to TensorBoard (optional)
    sample_input = next(iter(train_loader))[0].to(device)
    writer.add_graph(model, sample_input)

    # Start training
    try:
        train_model(model, train_loader, val_loader, config, device, writer)
    except KeyboardInterrupt:
        logging.info("Training interrupted. Saving the latest model...")
        torch.save(model.state_dict(), os.path.join(config["checkpoint_dir"], "interrupted.pt"))
    finally:
        # Close TensorBoard writer
        writer.close()
        logging.info("Training completed!")


if __name__ == "__main__":
    main()
```
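In the framework, load_data is only a placeholder. As a concrete illustration, here is a minimal sketch that builds loaders from random tensors shaped to match the config.json below (784-dimensional inputs, 10 classes); the dummy data is purely an assumption for demonstration, and in practice you would construct your own Dataset here.

```python
# Hypothetical load_data: random tensors standing in for a real dataset.
import torch
from torch.utils.data import DataLoader, TensorDataset

def load_data(config):
    # Fake training and validation sets with config-driven shapes (sizes are assumed)
    train_x = torch.randn(1024, config["input_dim"])
    train_y = torch.randint(0, config["output_dim"], (1024,))
    val_x = torch.randn(256, config["input_dim"])
    val_y = torch.randint(0, config["output_dim"], (256,))

    train_loader = DataLoader(TensorDataset(train_x, train_y),
                              batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(TensorDataset(val_x, val_y),
                            batch_size=config["batch_size"], shuffle=False)
    return train_loader, val_loader
```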
# Configuration File
An example configuration file, config.json, is shown below. Note that the keys are kept flat rather than nested, because the training code reads them directly (e.g., config["lr"], config["batch_size"]):
```json
{
    "input_dim": 784,
    "hidden_dim": 128,
    "output_dim": 10,
    "lr": 0.001,
    "batch_size": 64,
    "epochs": 20,
    "seed": 42,
    "log_dir": "./logs",
    "checkpoint_dir": "./checkpoints",
    "save_interval": 5
}
```
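If you do prefer grouping related options into nested sections, one simple option is to flatten the nested JSON right after loading it, so the rest of the code can keep using flat keys. A minimal sketch with a hypothetical flatten_config helper (not part of the framework above):

```python
# Hypothetical helper: flatten {"model": {"input_dim": 784}} into {"model.input_dim": 784}
def flatten_config(config, parent_key="", sep="."):
    flat = {}
    for key, value in config.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(flatten_config(value, new_key, sep))
        else:
            flat[new_key] = value
    return flat
```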
# Usage
Run the training script with the default configuration file:
```bash
python train.py --config config.json
```
Values from the configuration file can be overridden with command-line arguments. For example, to change the learning rate and batch size:
```bash
python train.py --config config.json --lr 0.01 --batch_size 128
```
To resume training from a checkpoint, either add a resume_checkpoint entry to config.json or pass it on the command line:
```bash
python train.py --config config.json --resume_checkpoint ./checkpoints/model_epoch_10.pt
```
During training, TensorBoard logs are written to the ./logs/tensorboard directory. Start TensorBoard to view the visualizations:
```bash
tensorboard --logdir=./logs/tensorboard
```
Then open http://localhost:6006 in a browser to see the training and validation loss, accuracy, and other metrics.
After running the script, the generated file structure looks like this:
```
./
├── logs/
│   ├── train_20231025_153000.log   # training log
│   └── tensorboard/                # TensorBoard logs
├── checkpoints/
│   ├── model_epoch_5.pt            # model checkpoints
│   └── model_epoch_10.pt
└── config.json                     # configuration file
```