使用卷积神经网络的图像分类

使用 PyTorchMNIST 数据集训练卷积神经网络(CNN),以便在 AI 运行时上对图像分类。 MNIST 包含 70,000 张手写数字(0-9)的灰度图像,非常适合学习图像分类技术。

你将了解如何:

  • 使用 A10G GPU 将您的笔记本连接到无服务器的 GPU 计算服务
  • 定义简单的卷积神经网络体系结构
  • 在单个 GPU 上训练模型,并将指标记录到 MLflow
  • 将模型检查点保存到 Unity 目录卷
  • 加载和评估已训练的模型

连接到无服务器 GPU 计算

此笔记本需要 GPU 才能有效地训练神经网络。 按照以下步骤连接到无服务器 GPU 计算:

  1. 单击笔记本顶部的 “连接 ”下拉列表。
  2. 选择 无服务器 GPU
  3. 打开笔记本右侧的环境侧面板。
  4. 将此演示的 Accelerator 设置为 1xA10
  5. “环境”下拉列表中选择 AI v5
  6. 选择 “应用 ”,然后单击“ 确认 ”将此环境应用到笔记本。

有关详细信息,请参阅 无服务器 GPU 计算

配置检查点存储位置

以下单元格将创建控件参数,以指定模型检查点将保存在 Unity Catalog 中的位置。 这些参数定义:

  • uc_catalog:Unity Catalog 目录名称
  • uc_schema:目录中的架构(数据库)
  • uc_volume:用于存储检查点文件的卷
  • uc_model_name:此特定模型的卷中的子目录

这些值在整个笔记本中用于构造检查点路径: /Volumes/{uc_catalog}/{uc_schema}/{uc_volume}/{uc_model_name}

以下单元格使用占位符值作为默认值。 使用笔记本顶部的控件更新值。 或者,直接在下一个单元格中更新默认值。

dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "checkpoints")
dbutils.widgets.text("uc_model_name", "cnn_mnist")

定义卷积神经网络

以下单元格定义用于图像分类的简单 CNN 体系结构。 网络包括:

  • 两个带有最大池化的卷积层,用于从图像中提取特征
  • 两个完全连接的层,用于对提取的特征进行分类
  • 丢弃层以防止过度拟合

该代码还定义了用于将模型和优化器状态检查指向 Unity Catalog 卷的辅助类,以及用于设置分布式训练(用于多 GPU 场景)的函数。

此实现改编自 Horovod PyTorch MNIST 示例

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from datetime import timedelta
import os

from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")

# Ensure that the UC Volume directory exists first
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"

class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

def setup():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Shorter timeouts help surface failures quickly instead of hanging
    dist.init_process_group(
        backend="nccl",
        timeout=timedelta(seconds=120),
        init_method="env://",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    dist.barrier()
    if rank == 0:
        print("PG up; all ranks reached barrier")


def cleanup():
    try:
        dist.barrier()
    finally:
        dist.destroy_process_group()

配置训练参数

以下单元格设置用于训练的超参数:

  • batch_size:每个训练迭代中处理的图像数
  • num_epochs:训练数据集完整遍历的次数
  • momentum:SGD优化器的动量因子
  • log_interval:日志记录训练进度的频率
# Specify training parameters
batch_size = 100
num_epochs = 5
momentum = 0.5
log_interval = 100

定义训练循环

以下单元格定义函数 train_one_epoch ,该函数:

  • 迭代批量训练数据
  • 执行前向传播和反向传播
  • 使用优化器更新模型权重
  • 定期在固定时间间隔将训练损失记录到 MLflow
def train_one_epoch(model, device, data_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(data_loader) * len(data),
                100. * batch_idx / len(data_loader), loss.item()))
            # Log metrics
            mlflow.log_metric('loss', loss.item(), step=epoch * len(data_loader) + batch_idx)

在单个 GPU 上训练模型

以下单元格定义主要训练函数:

  • 加载 MNIST 训练数据集
  • 初始化模型和优化器
  • 针对指定的训练轮数训练模型
  • 在每个时期结束后将检查点保存到 Unity Catalog 卷中
  • 将日志指标记录到 MLflow 用于实验跟踪
import mlflow
import torch.optim as optim
from torchvision import datasets, transforms

def train(learning_rate):

  with mlflow.start_run() as run:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset = datasets.MNIST(
      'data',
      train=True,
      download=True,
      transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
    data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    with torch.no_grad():
      input_example, _ = next(iter(data_loader))
      output_example = model(input_example.to(device))

    for epoch in range(1, num_epochs + 1):
      train_one_epoch(model, device, data_loader, optimizer, epoch)

      state_dict = { "app": AppState(model, optimizer) }
      dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)
      print(f"saved checkpoint to {CHECKPOINT_DIR}")

运行训练函数

以下单元格以学习速率 0.001 执行 train 函数。 训练过程将:

  • 下载 MNIST 数据集(如果尚未缓存)
  • 训练模型 5 个迭代周期
  • 显示训练进度和损失值
  • 将模型检查点保存到 Unity 目录卷
  • 将指标记录到 MLflow

在 A10G GPU 上训练通常需要几分钟时间。

train(learning_rate = 0.001)

加载和评估已训练的模型

训练后,可以从检查点加载模型,并评估其在测试数据集上的性能。

以下单元格定义一个 test 函数:

  • 从 Unity Catalog 卷检查点加载模型状态
  • 下载 MNIST 测试数据集
  • 基于测试数据评估模型
  • 计算并显示平均测试损失
def test():
  # Load model state from checkpoint using dcp
  model = Net()
  optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=momentum)
  app_state = AppState(model, optimizer)
  state_dict = { "app": app_state }
  dcp.load(state_dict, checkpoint_id=CHECKPOINT_DIR)
  model.eval()

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model.to(device)
  test_dataset = datasets.MNIST(
    'data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
  data_loader = torch.utils.data.DataLoader(test_dataset)

  test_loss = 0
  for data, target in data_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target)

  test_loss /= len(data_loader.dataset)
  print("Average test loss: {}".format(test_loss.item()))

执行评估

以下单元格执行 test 函数来评估 MNIST 测试数据集上的训练模型。 较低的测试损失表示更好的模型性能。

test()

结束语

祝贺! 你已使用无服务器 GPU 计算成功训练了图像分类模型。 你已了解如何执行以下操作:

  • 配置并连接到无服务器 GPU 计算资源
  • 定义卷积神经网络体系结构
  • 使用 PyTorch 训练模型,并将指标记录到 MLflow
  • 将模型检查点保存到 Unity Catalog 存储卷中
  • 加载和评估已训练的模型

断开与 GPU 计算的连接

若要避免不必要的 GPU 使用率,请手动断开与 GPU 的连接:

  1. 选择笔记本顶部的 “连接”
  2. 将鼠标悬停在Serverless
  3. 从下拉菜单中选择 “终止
  4. 选择 “确认 ”以终止

注意:如果不手动断开连接,则连接在 60 分钟不活动后自动终止。

后续步骤

浏览以下资源,详细了解 Databricks 上的机器学习:

示例笔记本

使用卷积神经网络的图像分类

获取笔记本