在训练深度学习模型时,如何高效地追踪训练过程、记录实验参数、可视化损失与准确率曲线,一直是我们开发流程中的痛点。而 Weights & Biases (WandB) 正是为了解决这些问题而生的工具。

本篇文章将带你通过一个完整的 PyTorch 示例,从零快速上手 WandB:训练一个简单的 CNN 模型用于 MNIST 分类,并记录并可视化训练过程。


🧠 为什么选择 WandB?

• 📈 自动可视化指标(loss、accuracy等)

• 🧪 记录超参数

• 💾 自动保存最优模型

• 📊 支持团队协作与在线分享实验结果

• 🔁 便于实验复现


🧰 环境准备

确保你已经安装好以下依赖:

pip install torch torchvision wandb

🏗️ 项目结构概览

我们将实现以下几件事:

• 使用 argparse 支持命令行参数

• 构建一个 CNN 模型用于 MNIST

• 支持多种优化器、动态超参数配置

• 使用 WandB 记录和可视化训练过程


✍️ 代码详解与使用方式

你可以把以下完整代码保存为 train_mnist.py,并通过命令行运行:

python train_mnist.py

 

1. 引入依赖并配置设备

import wandb
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

我们这里支持 macOS MPS、CPU。

 

2. 使用 argparse 接收超参数

parser.add_argument('--project_name', type=str, default='test_project')
...

你可以通过命令行自由调整每次实验的配置,比如 batch size、dropout 等。

 

3. 构建模型

def build_model(hidden_dim):
    return nn.Sequential(OrderedDict([
        ...
    ]))

这是一个简单的 CNN,两层卷积 + 全连接。

 

4. 初始化 WandB

wandb.init(project=config.project_name, config=vars(config))
wandb.watch(model)

• wandb.init 用于开始一个实验记录

• wandb.watch 会追踪模型中的梯度和参数变化

 

5. 训练与评估

for epoch in range(1, config.epochs + 1):
    ...
    wandb.log({
        'epoch': epoch,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc
    })

你可以在 WandB 面板中直观查看每个 epoch 的准确率、损失变化趋势。

 

6. 保存最佳模型

if val_acc > best_val_acc:
    torch.save(model.state_dict(), config.ckpt_path)

每当验证集准确率提升时,我们会保存模型。

 

7. 完整代码示例

import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from collections import OrderedDict
import wandb

# Device configuration
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f'Using device: {device}')


def parse_args():
    parser = argparse.ArgumentParser(description="Train a CNN on MNIST with configurable parameters")
    parser.add_argument('--project_name', type=str, default='test_project')
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--hidden_layer_width', type=int, default=64)
    parser.add_argument('--dropout_p', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--optim_type', type=str, default='Adam', choices=['Adam', 'SGD', 'RMSprop'])
    parser.add_argument('--epochs', type=int, default=15)
    parser.add_argument('--ckpt_path', type=str, default='best_model.pt')
    return parser.parse_args()


def get_dataloaders(batch_size):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=transform)
    val_set = torchvision.datasets.MNIST(root='./mnist', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def build_model(hidden_dim):
    return nn.Sequential(OrderedDict([
        ("conv1", nn.Conv2d(1, hidden_dim, kernel_size=3, padding=1)),
        ("relu1", nn.ReLU()),
        ("pool1", nn.MaxPool2d(2, 2)),

        ("conv2", nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)),
        ("relu2", nn.ReLU()),
        ("pool2", nn.MaxPool2d(2, 2)),

        ("flatten", nn.Flatten()),
        ("fc", nn.Linear(hidden_dim * 7 * 7, 10))
    ]))


def get_optimizer(model, config):
    if config.optim_type == 'Adam':
        return torch.optim.Adam(model.parameters(), lr=config.lr)
    elif config.optim_type == 'SGD':
        return torch.optim.SGD(model.parameters(), lr=config.lr)
    elif config.optim_type == 'RMSprop':
        return torch.optim.RMSprop(model.parameters(), lr=config.lr)


def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy


def train(model, train_loader, val_loader, config):
    wandb.init(project=config.project_name, config=vars(config))
    wandb.watch(model)

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(model, config)
    best_val_acc = 0.0

    for epoch in range(1, config.epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / total
        train_acc = correct / total
        val_loss, val_acc = evaluate(model, val_loader, criterion)

        wandb.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc
        })

        print(f"[Epoch {epoch:02d}] "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), config.ckpt_path)
            print(f"🔥 New best model saved! Val Acc: {best_val_acc:.4f}")

    wandb.finish()


def main():
    config = parse_args()
    train_loader, val_loader = get_dataloaders(config.batch_size)
    model = build_model(config.hidden_layer_width)
    train(model, train_loader, val_loader, config)


if __name__ == '__main__':
    main()

📺 运行效果展示

 

训练完成后,你可以打开控制台提供的链接,看到如下信息:

• 训练 / 验证准确率曲线

• 模型权重变化

• 超参数记录表

• 模型 checkpoint 下载入口

 


🧪 小结与建议

 

WandB 能让你从混乱的 notebook/print 日志中解放出来,专注在 模型本身的改进。而且它还支持团队协作、在线对比不同实验结果,对于科研或工业落地都有很大帮助。

 


📦 完整代码仓库

 

你可以将本文代码作为模板,快速迁移到你自己的模型中。强烈建议你试着改一改:

• 用 CIFAR-10 替代 MNIST

• 增加更多层、dropout、或者 BatchNorm

• 对比 Adam 与 SGD 的训练曲线

 


如果你觉得本文有帮助,不妨点个赞或收藏吧 😉

 

Logo

永洪科技,致力于打造全球领先的数据技术厂商,具备从数据应用方案咨询、BI、AIGC智能分析、数字孪生、数据资产、数据治理、数据实施的端到端大数据价值服务能力。

更多推荐