1.数据集
数据集地址:https://www.kaggle.com/slothkong/10-monkey-species 本文采用 Kaggle 上的猴子数据集,包含两个文件夹:训练集和验证集。每个文件夹下包含 10 个标记为 n0-n9 的子文件夹,分别对应 10 个猴子类别。图像尺寸为 400x300 像素或更大,格式为 JPEG(共近 1400 张图像)。 图片样本 图片类别标签,训练集,验证集划分说明
2.代码
2.1 定义需要的库
import os
import sys
import json
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
2.2 定义训练验证函数
def train_and_val(epochs, model, train_loader, val_loader, criterion, optimizer):
torch.cuda.empty_cache()
train_loss = []
val_loss = []
train_acc = []
val_acc = []
best_acc = 0
model.to(device)
fit_time = time.time()
for e in range(epochs):
since = time.time()
running_loss = 0
training_acc = 0
with tqdm(total=len(train_loader)) as pbar:
for image, label in train_loader:
# training phase
# images, labels = data
# optimizer.zero_grad()
# logits = net(images.to(device))
# loss = loss_function(logits, labels.to(device))
# loss.backward()
# optimizer.step()
model.train()
optimizer.zero_grad()
image = image.to(device)
label = label.to(device)
# forward
output = model(image)
loss = criterion(output, label)
predict_t = torch.max(output, dim=1)[1]
# backward
loss.backward()
optimizer.step() # update weight
running_loss += loss.item()
training_acc += torch.eq(predict_t, label).sum().item()
pbar.update(1)
model.eval()
val_losses = 0
validation_acc = 0
# validation loop
with torch.no_grad():
with tqdm(total=len(val_loader)) as pb:
for image, label in val_loader:
image = image.to(device)
label = label.to(device)
output = model(image)
# loss
loss = criterion(output, label)
predict_v = torch.max(output, dim=1)[1]
val_losses += loss.item()
validation_acc += torch.eq(predict_v, label).sum().item()
pb.update(1)
# compute per-epoch averages (the accumulated sums are divided by the dataset size)
train_loss.append(running_loss / len(train_dataset))
val_loss.append(val_losses / len(val_dataset))
train_acc.append(training_acc / len(train_dataset))
val_acc.append(validation_acc / len(val_dataset))
torch.save(model, "last.pth")
if best_acc
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
# Training split: images are read from the training directory and passed
# through the "train" transform pipeline.
train_dataset = datasets.ImageFolder(
    root="../input/10-monkey-species/training/training/",
    transform=data_transform["train"],
)
# Batched loader over the training split, with shuffling enabled.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Validation split: images are read from the validation directory and passed
# through the "val" transform pipeline.
val_dataset = datasets.ImageFolder(
    root="../input/10-monkey-species/validation/validation/",
    transform=data_transform["val"],
)
# Batched loader over the validation split; order is kept fixed (no shuffling).
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)
2.5 开始训练
# Model to train; resnet34 is provided outside this section.
net = resnet34()

# Cross-entropy loss for the classification objective.
loss_function = nn.CrossEntropyLoss()

# Adam optimizer over all model parameters with a learning rate of 1e-4.
optimizer = optim.Adam(params=net.parameters(), lr=1e-4)

# Number of epochs to run.
epoch = 60

# Run the training/validation loop and keep whatever it returns.
history = train_and_val(
    epochs=epoch,
    model=net,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=loss_function,
    optimizer=optimizer,
)
执行结果
Epoch:55/60.. Train Acc: 0.813.. Val Acc: 0.860.. Train Loss: 0.038.. Val Loss: 0.029.. Time: 38.40s
100%|██████████| 69/69 [00:28 |