快速上手 WandB:用 PyTorch 训练 MNIST + 可视化全流程
本篇文章将带你通过一个完整的 PyTorch 示例,从零快速上手 WandB:训练一个简单的 CNN 模型用于 MNIST 分类,并记录并可视化训练过程。在训练深度学习模型时,如何高效地追踪训练过程、记录实验参数、可视化损失与准确率曲线,一直是我们开发流程中的痛点。你可以通过命令行自由调整每次实验的配置,比如 batch size、dropout 等。你可以在 WandB 面板中直观查看每个 ep
在训练深度学习模型时,如何高效地追踪训练过程、记录实验参数、可视化损失与准确率曲线,一直是我们开发流程中的痛点。而 Weights & Biases (WandB) 正是为了解决这些问题而生的工具。
本篇文章将带你通过一个完整的 PyTorch 示例,从零快速上手 WandB:训练一个简单的 CNN 模型用于 MNIST 分类,并记录并可视化训练过程。
🧠 为什么选择 WandB?
• 📈 自动可视化指标(loss、accuracy等)
• 🧪 记录超参数
• 💾 自动保存最优模型
• 📊 支持团队协作与在线分享实验结果
• 🔁 便于实验复现
🧰 环境准备
确保你已经安装好以下依赖:
pip install torch torchvision wandb
🏗️ 项目结构概览
我们将实现以下几件事:
• 使用 argparse 支持命令行参数
• 构建一个 CNN 模型用于 MNIST
• 支持多种优化器、动态超参数配置
• 使用 WandB 记录和可视化训练过程
✍️ 代码详解与使用方式
你可以把以下完整代码保存为 train_mnist.py,并通过命令行运行:
python train_mnist.py
1. 引入依赖并配置设备
import wandb
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
我们这里支持 macOS MPS、CPU。
2. 使用 argparse 接收超参数
parser.add_argument('--project_name', type=str, default='test_project')
...
你可以通过命令行自由调整每次实验的配置,比如 batch size、dropout 等。
3. 构建模型
def build_model(hidden_dim):
return nn.Sequential(OrderedDict([
...
]))
这是一个简单的 CNN,两层卷积 + 全连接。
4. 初始化 WandB
wandb.init(project=config.project_name, config=vars(config))
wandb.watch(model)
• wandb.init 用于开始一个实验记录
• wandb.watch 会追踪模型中的梯度和参数变化
5. 训练与评估
for epoch in range(1, config.epochs + 1):
...
wandb.log({
'epoch': epoch,
'train_loss': train_loss,
'train_acc': train_acc,
'val_loss': val_loss,
'val_acc': val_acc
})
你可以在 WandB 面板中直观查看每个 epoch 的准确率、损失变化趋势。
6. 保存最佳模型
if val_acc > best_val_acc:
torch.save(model.state_dict(), config.ckpt_path)
每当验证集准确率提升时,我们会保存模型。
7. 完整代码示例
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from collections import OrderedDict
import wandb
# Device configuration
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f'Using device: {device}')
def parse_args():
parser = argparse.ArgumentParser(description="Train a CNN on MNIST with configurable parameters")
parser.add_argument('--project_name', type=str, default='test_project')
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--hidden_layer_width', type=int, default=64)
parser.add_argument('--dropout_p', type=float, default=0.1)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--optim_type', type=str, default='Adam', choices=['Adam', 'SGD', 'RMSprop'])
parser.add_argument('--epochs', type=int, default=15)
parser.add_argument('--ckpt_path', type=str, default='best_model.pt')
return parser.parse_args()
def get_dataloaders(batch_size):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_set = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=transform)
val_set = torchvision.datasets.MNIST(root='./mnist', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
return train_loader, val_loader
def build_model(hidden_dim):
return nn.Sequential(OrderedDict([
("conv1", nn.Conv2d(1, hidden_dim, kernel_size=3, padding=1)),
("relu1", nn.ReLU()),
("pool1", nn.MaxPool2d(2, 2)),
("conv2", nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)),
("relu2", nn.ReLU()),
("pool2", nn.MaxPool2d(2, 2)),
("flatten", nn.Flatten()),
("fc", nn.Linear(hidden_dim * 7 * 7, 10))
]))
def get_optimizer(model, config):
if config.optim_type == 'Adam':
return torch.optim.Adam(model.parameters(), lr=config.lr)
elif config.optim_type == 'SGD':
return torch.optim.SGD(model.parameters(), lr=config.lr)
elif config.optim_type == 'RMSprop':
return torch.optim.RMSprop(model.parameters(), lr=config.lr)
def evaluate(model, dataloader, criterion):
model.eval()
total_loss, correct, total = 0.0, 0, 0
with torch.no_grad():
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
total_loss += loss.item() * images.size(0)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
avg_loss = total_loss / total
accuracy = correct / total
return avg_loss, accuracy
def train(model, train_loader, val_loader, config):
wandb.init(project=config.project_name, config=vars(config))
wandb.watch(model)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = get_optimizer(model, config)
best_val_acc = 0.0
for epoch in range(1, config.epochs + 1):
model.train()
running_loss, correct, total = 0.0, 0, 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
train_loss = running_loss / total
train_acc = correct / total
val_loss, val_acc = evaluate(model, val_loader, criterion)
wandb.log({
'epoch': epoch,
'train_loss': train_loss,
'train_acc': train_acc,
'val_loss': val_loss,
'val_acc': val_acc
})
print(f"[Epoch {epoch:02d}] "
f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), config.ckpt_path)
print(f"🔥 New best model saved! Val Acc: {best_val_acc:.4f}")
wandb.finish()
def main():
config = parse_args()
train_loader, val_loader = get_dataloaders(config.batch_size)
model = build_model(config.hidden_layer_width)
train(model, train_loader, val_loader, config)
if __name__ == '__main__':
main()
📺 运行效果展示
训练完成后,你可以打开控制台提供的链接,看到如下信息:
• 训练 / 验证准确率曲线
• 模型权重变化
• 超参数记录表
• 模型 checkpoint 下载入口
🧪 小结与建议
WandB 能让你从混乱的 notebook/print 日志中解放出来,专注在 模型本身的改进。而且它还支持团队协作、在线对比不同实验结果,对于科研或工业落地都有很大帮助。
📦 完整代码仓库
你可以将本文代码作为模板,快速迁移到你自己的模型中。强烈建议你试着改一改:
• 用 CIFAR-10 替代 MNIST
• 增加更多层、dropout、或者 BatchNorm
• 对比 Adam 与 SGD 的训练曲线
如果你觉得本文有帮助,不妨点个赞或收藏吧 😉
更多推荐
所有评论(0)