使用 Python 实现一个简单的图像分类器
在深度学习和人工智能飞速发展的今天,图像分类已经成为计算机视觉中的一个重要任务。图像分类的目标是根据输入的图像内容将其分配到预定义的类别中。例如,我们可以训练一个模型来识别图片是猫、狗还是鸟。
本文将介绍如何使用 Python 和深度学习框架 PyTorch 来实现一个简单的图像分类器。我们将使用经典的卷积神经网络(CNN)架构,并在一个小型数据集上进行训练和测试。
1. 环境准备
首先,我们需要安装必要的库:
pip install torch torchvision matplotlib
torch
:PyTorch 深度学习框架。torchvision
:提供常用的数据集和模型。matplotlib
:用于可视化图像。2. 数据准备
我们将使用 torchvision.datasets.CIFAR10
数据集,它包含 10 类彩色图像(如飞机、汽车、鸟等),每张图像大小为 32x32 像素。
import torchimport torchvisionimport torchvision.transforms as transforms# 定义图像预处理方式transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载训练集和测试集trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
我们还对图像进行了归一化处理,使其像素值分布在 [-1, 1] 区间内。
可视化一些训练图像
import matplotlib.pyplot as pltimport numpy as npdef imshow(img): img = img / 2 + 0.5 # 反归一化 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show()# 获取一批训练图像dataiter = iter(trainloader)images, labels = next(dataiter)# 显示图像imshow(torchvision.utils.make_grid(images))# 打印标签print(' '.join(f'{classes[labels[j]]}' for j in range(4)))
3. 构建卷积神经网络(CNN)
接下来,我们定义一个简单的 CNN 模型:
import torch.nn as nnimport torch.nn.functional as Fclass Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) # 输入通道数3,输出通道数6,卷积核大小5x5 self.pool = nn.MaxPool2d(2, 2) # 最大池化层 self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) # 第一层卷积+池化 x = self.pool(F.relu(self.conv2(x))) # 第二层卷积+池化 x = x.view(-1, 16 * 5 * 5) # 展平 x = F.relu(self.fc1(x)) # 全连接层 x = F.relu(self.fc2(x)) x = self.fc3(x) return xnet = Net()
4. 定义损失函数和优化器
我们使用交叉熵损失函数和随机梯度下降(SGD)优化器:
import torch.optim as optimcriterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
5. 训练模型
现在我们开始训练模型:
for epoch in range(2): # 多次遍历数据集 running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() # 清空梯度 outputs = net(inputs) # 前向传播 loss = criterion(outputs, labels) # 计算损失 loss.backward() # 反向传播 optimizer.step() # 更新参数 running_loss += loss.item() if i % 2000 == 1999: # 每2000个 mini-batch 打印一次 print(f'Epoch {epoch + 1}, Batch {i + 1} loss: {running_loss / 2000:.3f}') running_loss = 0.0print('Finished Training')
训练过程可能需要几分钟时间,取决于你的硬件配置。
6. 测试模型
训练完成后,我们评估模型在测试集上的表现:
correct = 0total = 0with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')
你也可以进一步分析每个类别的准确率:
class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs, 1) c = (predicted == labels).squeeze() for i in range(4): label = labels[i] class_correct[label] += c[i].item() class_total[label] += 1for i in range(10): print(f'Accuracy of {classes[i]} : {100 * class_correct[i] / class_total[i]:.2f}%')
7. 总结与扩展
本篇文章介绍了如何使用 PyTorch 构建一个简单的图像分类器。我们使用了经典的 CNN 结构对 CIFAR-10 数据集进行训练,并实现了基本的图像分类功能。
后续改进方向:
使用更先进的网络结构(如 ResNet、VGG)提高准确率。使用 GPU 加速训练过程(通过.to(device)
将模型和数据转移到 GPU 上)。添加早停机制或学习率衰减策略。使用 TensorBoard 或 wandb 进行训练日志记录和可视化。图像分类是一个非常活跃的研究领域,随着技术的发展,我们可以构建越来越强大的模型来解决更加复杂的视觉任务。希望这篇文章能为你入门深度学习图像分类提供一个良好的起点。
免责声明:本文来自网站作者,不代表CIUIC的观点和立场,本站所发布的一切资源仅限用于学习和研究目的;不得将上述内容用于商业或者非法用途,否则,一切后果请用户自负。本站信息来自网络,版权争议与本站无关。您必须在下载后的24个小时之内,从您的电脑中彻底删除上述内容。如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。客服邮箱:ciuic@ciuic.com