# Preface

It has been quite a while since I last wrote a blog post. With the Spring Festival holiday just over, on a whim I decided to use DeepSeek to put together a quick post, both as a record for future study and as a small trace of my ongoing graduation project.

The topic of this post is deep learning model training code; as the name suggests, it is mainly about how to write the train.py file of a deep learning project. Anyone who has written training code knows that the basic logic of train.py is largely the same: load the dataset, run the model, compute the loss, optimize with gradients, save the results, and so on. It sounds simple, but writing code that is both clean and fully featured still takes some effort. So it makes sense to build this framework once and, from then on, simply follow the template and make small adjustments as needed.

This post will be updated continuously as I put it into practice.


# Coding Guidelines

To have DeepSeek produce a good training-code framework, the first step is to spell out what I consider good coding guidelines and requirements. In my view, a good train.py implementation should satisfy the following:

Modular design

  • Separate data loading, model construction, training logic, logging, and other modules.
  • Keep function responsibilities clear so the code is easy to debug and reuse.

Configurability

  • Load default parameters from a config.json configuration file.
  • Support overriding the configuration with command-line arguments.
  • Parameter priority: command line > config file > code defaults (see the sketch below).
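
A minimal sketch of this priority rule (purely illustrative, not part of the final train.py): declaring argparse options without a default, i.e. default=None, makes it easy to tell whether a flag was actually passed, so only explicitly given flags override the file values, and the file in turn overrides the code defaults.

```python
import argparse
import json

DEFAULTS = {"lr": 0.001, "batch_size": 64}            # code defaults (lowest priority)

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config.json")
parser.add_argument("--lr", type=float)               # default is None -> "not passed"
parser.add_argument("--batch_size", type=int)
args = parser.parse_args()

with open(args.config) as f:
    file_cfg = json.load(f)                           # config file (middle priority)

cli_cfg = {k: v for k, v in vars(args).items()
           if v is not None and k != "config"}        # command line (highest priority)

config = {**DEFAULTS, **file_cfg, **cli_cfg}
```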

Device management

  • Automatically detect GPUs and assign the device accordingly.

Reproducibility

  • Fix the random seeds (covering random, numpy, torch, etc.).
  • Save the full configuration used for training (back up config.json); both points are sketched below.
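
The train.py framework later in this post only calls torch.manual_seed, so here is a minimal sketch covering both points. The helper names set_seed and backup_config and the backup filename are my own choices, not part of the framework:

```python
import os
import random
import shutil

import numpy as np
import torch

def set_seed(seed: int = 42):
    # Seed Python's random, NumPy, and PyTorch (CPU and all visible GPUs)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def backup_config(config_path: str, log_dir: str):
    # Keep a copy of the exact config used for this run next to the logs
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, os.path.join(log_dir, "config_backup.json"))
```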

Logging and monitoring

  • Record metrics such as loss and accuracy during training.
  • Save the log to a file as well as printing it to the console.
  • Support visualization tools such as TensorBoard or WandB.

Model saving and resuming

  • Periodically save model checkpoints (weights and optimizer state).
  • Support resuming training from a checkpoint.

Progress feedback

  • Show a training progress bar (e.g. with tqdm).
  • Print a summary of the metrics for each epoch.

Exception handling

  • Handle problems such as missing file paths and invalid parameters.
  • Catch keyboard interrupts (Ctrl+C) and save the model safely (see the sketch below).
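
The framework below only saves the model weights when Ctrl+C is pressed. If you also want the interrupted run to be resumable, the checkpoint should include the optimizer state and the current epoch as well. A minimal sketch of such a helper (the function name is mine; the checkpoint keys match the resume logic in the train.py that follows), meant to be called from an except KeyboardInterrupt: block inside the training loop where optimizer and epoch are in scope:

```python
import os
import torch

def save_interrupt_checkpoint(model, optimizer, epoch, config):
    # Save enough state so the interrupted run can later be resumed
    # via config["resume_checkpoint"].
    os.makedirs(config["checkpoint_dir"], exist_ok=True)
    path = os.path.join(config["checkpoint_dir"], "interrupted.pt")
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)
    return path
```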

# Code Framework

Based on the requirements above, I asked DeepSeek to produce a detailed training-code framework.

# Training Code

The main body of the training script train.py is as follows:

# train.py
import torch
import argparse
import json
import os
import logging
import torch.nn as nn
from datetime import datetime
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Define the model architecture based on the configuration
def build_model(config):
    model = nn.Sequential(
        nn.Linear(config["input_dim"], config["hidden_dim"]),
        nn.ReLU(),
        nn.Linear(config["hidden_dim"], config["output_dim"])
    )
    return model
# Load and preprocess the dataset
def load_data(config):
    train_loader = DataLoader(...)
    val_loader = DataLoader(...)
    return train_loader, val_loader
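# (Illustrative only, not part of the framework.) A concrete load_data could look
# like this sketch: a synthetic TensorDataset whose shapes match the example config
# (input_dim=784, output_dim=10). Replace it with your real Dataset and transforms.
def example_load_data(config):
    from torch.utils.data import TensorDataset, random_split
    x = torch.randn(1000, config["input_dim"])
    y = torch.randint(0, config["output_dim"], (1000,))
    train_set, val_set = random_split(TensorDataset(x, y), [800, 200])
    train_loader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_set, batch_size=config["batch_size"])
    return train_loader, val_loader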
# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="config.json", help="Path to the configuration file")
    parser.add_argument("--lr", type=float, help="Learning rate (overrides config file)")
    parser.add_argument("--batch_size", type=int, help="Batch size (overrides config file)")
    parser.add_argument("--epochs", type=int, help="Number of epochs (overrides config file)")
    # Add other overridable parameters as needed...
    return parser.parse_args()
# Load the configuration file
def load_config(config_path):
    with open(config_path, "r") as f:
        config = json.load(f)
    return config
# Configure logging to both file and console
def setup_logging(config):
    log_dir = config["log_dir"]
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"train_{timestamp}.log")
    
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )
# Train the model on the train set
def train_model(model, train_loader, val_loader, config, device, writer):
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    criterion = nn.CrossEntropyLoss()
    
    # Resume training from a checkpoint if specified
    start_epoch = 0
    if config.get("resume_checkpoint"):
        checkpoint = torch.load(config["resume_checkpoint"], map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        logging.info(f"Resuming training from epoch {start_epoch}")
    
    for epoch in range(start_epoch, config["epochs"]):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
        
        for batch_idx, batch in enumerate(progress_bar):
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})
            
            # Log training loss to TensorBoard
            writer.add_scalar("Loss/train", loss.item(), epoch * len(train_loader) + batch_idx)
        
        # Validation step
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        logging.info(
            f"Epoch {epoch+1}: "
            f"Train Loss: {train_loss/len(train_loader):.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        )
        
        # Log validation metrics to TensorBoard
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)
        
        # Save checkpoint periodically
        if (epoch + 1) % config["save_interval"] == 0:
            checkpoint_path = os.path.join(config["checkpoint_dir"], f"model_epoch_{epoch+1}.pt")
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            }, checkpoint_path)
# Evaluate the model on the validation set
def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
    return total_loss / len(val_loader), correct / len(val_loader.dataset)
# Main function
def main():
    # Parse command-line arguments
    args = parse_args()
    config = load_config(args.config)
    
    # Override config with command-line arguments (command line > config file > code defaults)
    for key, value in vars(args).items():
        if value is not None and key != "config":
            config[key] = value
    
    # Validate required configuration keys
    required_keys = ["input_dim", "hidden_dim", "output_dim", "lr", "batch_size",
                     "epochs", "log_dir", "checkpoint_dir", "save_interval"]
    for key in required_keys:
        if key not in config:
            raise ValueError(f"Missing required key in config: {key}")
    
    # Initialize logging, the checkpoint directory, device, and random seed
    setup_logging(config)
    os.makedirs(config["checkpoint_dir"], exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(config.get("seed", 42))
    
    # Initialize TensorBoard writer
    writer = SummaryWriter(log_dir=os.path.join(config["log_dir"], "tensorboard"))
    
    # Build the model and load data
    model = build_model(config).to(device)
    train_loader, val_loader = load_data(config)
    
    # Log model graph to TensorBoard (optional)
    sample_input = next(iter(train_loader))[0].to(device)
    writer.add_graph(model, sample_input)
    
    # Start training
    try:
        train_model(model, train_loader, val_loader, config, device, writer)
        logging.info("Training completed!")
    except KeyboardInterrupt:
        logging.info("Training interrupted. Saving the latest model...")
        torch.save(model.state_dict(), os.path.join(config["checkpoint_dir"], "interrupted.pt"))
    finally:
        # Close the TensorBoard writer whether training finished or was interrupted
        writer.close()

if __name__ == "__main__":
    main()

# Configuration File

An example of the configuration file config.json is shown below. Note that the keys are kept flat (no nested objects), so the simple command-line override logic in main() can map arguments to keys directly:

# config.json
{
    "input_dim": 784,
    "hidden_dim": 128,
    "output_dim": 10,
    "lr": 0.001,
    "batch_size": 64,
    "epochs": 20,
    "seed": 42,
    "log_dir": "./logs",
    "checkpoint_dir": "./checkpoints",
    "save_interval": 5
}

# Usage

Run the training script with the default configuration file:

python train.py --config config.json

Override values from the configuration file via command-line arguments. For example, to change the learning rate and batch size:

python train.py --config config.json --lr 0.01 --batch_size 128

To resume training from a checkpoint, either add a resume_checkpoint entry to config.json or specify it on the command line:

python train.py --config config.json --resume_checkpoint ./checkpoints/model_epoch_10.pt

During training, TensorBoard logs are written to the ./logs/tensorboard directory. Start TensorBoard to view the visualizations:

tensorboard --logdir=./logs/tensorboard

Then open http://localhost:6006 in a browser to see the training and validation loss, accuracy, and other metrics.

After running the script, the generated file structure looks like this:

./
├── logs/
│   ├── train_20231025_153000.log  # training log
│   └── tensorboard/               # TensorBoard logs
├── checkpoints/
│   ├── model_epoch_5.pt           # model checkpoints
│   └── model_epoch_10.pt
└── config.json                    # configuration file