[pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类 | 您所在的位置:网站首页 › 3d影像图片 › [pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类 |
[pytorch] MedMNIST 3D医学数据分类
MedMNIST数据集 · OrganMNIST3D 多分类任务 · 加载库 · 加载数据 · 使用Resnet3D预训练网络 · train · 结果
VesselMNIST3D 二分类任务
MedMNIST数据集
# Medical datasets are usually hard to find, and publicly available 3D datasets
# are rarer still. MedMNIST v2 is a large-scale, MNIST-like collection of
# standardized biomedical images that includes 12 2D datasets and 6 3D datasets.
# All images are pre-processed to 28 x 28 (2D) or 28 x 28 x 28 (3D) and come with
# corresponding classification labels, so users need no background knowledge.
# MedMNIST v2 covers the main data modalities in biomedical imaging and is
# designed for classification on lightweight 2D and 3D images across various data
# scales (from 100 to 100,000) and different tasks (binary/multi-class, ordinal
# regression and multi-label). We can use it to test our 3D networks, and so on.
#
# Dataset introduction: "MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D
# and 3D Biomedical Image Classification". GitHub: MedMNIST.
#
# Here we analyse two 3D datasets, OrganMNIST3D and VesselMNIST3D, implementing
# multi-class and binary classification respectively.

# NOTE(review): np, plt, Image, medmnist and INFO are used below, but the
# article's "load libraries" section is missing from this page; the imports are
# restored here so the script can run.
import matplotlib.pyplot as plt
import medmnist
import numpy as np
from medmnist import INFO
from PIL import Image

# Record results with TensorBoard.
from torch.utils.tensorboard import SummaryWriter

summaryWriter = SummaryWriter("./logs/")

# Load data
batch_size = 256


# Data processing
class Transform3D:
    """Scale a voxel volume's intensities and cast it to ``float32``.

    ``mul`` selects the scaling: ``'0.5'`` multiplies by 0.5, ``'random'``
    multiplies by one draw from ``np.random.uniform()``; any other value
    (including the default ``None``) applies no scaling.
    """

    def __init__(self, mul=None):
        self.mul = mul

    def __call__(self, voxel):
        if self.mul == '0.5':
            voxel = voxel * 0.5
        elif self.mul == 'random':
            voxel = voxel * np.random.uniform()
        return voxel.astype(np.float32)


# Download data
print('==> Preparing data...')
train_transform = Transform3D(mul='random')
eval_transform = Transform3D(mul='0.5')

data_flag = 'organmnist3d'  # Multi-Class (11)
download = True

info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

# load the data
train_dataset = DataClass(split='train', transform=train_transform, download=download)
val_dataset = DataClass(split='val', transform=eval_transform, download=download)
test_dataset = DataClass(split='test', transform=eval_transform, download=download)


# 3D data visualization function
def draw_oct(volume, type_volume='np', canal_first=False):
    """Show the three central orthogonal slices of a 4-D volume side by side.

    The volume is cut at the midpoint of three consecutive axes: axes 0-2 when
    ``canal_first`` is False, axes 1-3 when it is True. Each slice is squeezed
    and converted to a PIL image, so the remaining axis presumably has length 1
    (a single channel) -- TODO confirm for inputs other than MedMNIST 3D.

    Args:
        volume: 4-D volume; a numpy array, or an object exposing ``.numpy()``
            when ``type_volume`` is ``'tensor'``.
        type_volume: ``'np'`` or ``'tensor'``.
        canal_first: whether the first axis is skipped when slicing.

    Raises:
        ValueError: if ``type_volume`` is neither ``'np'`` nor ``'tensor'``
            (the original silently drew nothing in that case).
    """
    if type_volume not in ('np', 'tensor'):
        raise ValueError(
            "type_volume must be 'np' or 'tensor', got %r" % (type_volume,))

    # Printed for every combination; the original skipped this line only for
    # tensor input with canal_first=True.
    print("taille du volume = %s (%s)" % (volume.shape, type_volume))

    if type_volume == 'tensor':
        volume = volume.numpy()

    # The three sliced axes start right after the leading axis when it is
    # skipped (canal_first=True).
    first_axis = 1 if canal_first else 0
    slices = []
    for axis in range(first_axis, first_axis + 3):
        middle = volume.shape[axis] // 2
        plane = np.take(volume, middle, axis=axis)
        slices.append(Image.fromarray(np.squeeze(plane)))

    plt.figure(figsize=(21, 7))
    for position, image in enumerate(slices, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image)
        plt.title(image.size)
        plt.axis('off')


x, y = train_dataset[0]
print(x.shape, y)
draw_oct(x * 500, type_volume='np', canal_first=True)
# I use the pretrained ResNet models from MedicalNet. MedicalNet's networks are
# built for segmentation, so their structure is a ResNet that extracts feature
# maps followed by transposed-convolution layers that perform the segmentation.
# Our task is classification, so I replace the final transposed-convolution
# layers with a classification layer.
#
# The pretrained 3D ResNet parameters can be downloaded from the official GitHub
# repository and then loaded directly as shown below. Note: you need the
# `models` folder from the MedicalNet project code; copy that folder and the
# pretrained parameters you want to load into your own project.

# NOTE(review): time, np, torch, nn, data, medmnist, INFO, ExponentialLR and tqdm
# are used below, but the article's "load libraries" section is missing from
# this page; the imports are restored here so the script can run.
import time

import medmnist
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from medmnist import INFO
from sklearn.metrics import auc
from sklearn.metrics import roc_curve
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm

from models import resnet

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device =', device)
if device.type == 'cuda':
    # get_device_name raises on hosts without CUDA, so only query it on GPU.
    print(torch.cuda.get_device_name(0))

# depth -> (constructor name in models/resnet.py, channels entering the head).
# NOTE(review): the original used 256 for depth 10. In MedicalNet's resnet.py the
# last stage emits 512 * block.expansion channels, i.e. 512 for the BasicBlock
# nets (10/18/34) and 2048 for the Bottleneck nets -- confirm against your copy
# of models/resnet.py.
_RESNET_VARIANTS = {
    10: ('resnet10', 512),
    18: ('resnet18', 512),
    34: ('resnet34', 512),
    50: ('resnet50', 2048),
    101: ('resnet101', 2048),
    152: ('resnet152', 2048),
    200: ('resnet200', 2048),
}

_DATA_PARALLEL_PREFIX = 'module.'


def _bare_key(key):
    """Return a state-dict key without the 'module.' prefix nn.DataParallel adds."""
    if key.startswith(_DATA_PARALLEL_PREFIX):
        return key[len(_DATA_PARALLEL_PREFIX):]
    return key


def generate_model(model_type='resnet', model_depth=50,
                   input_W=224, input_H=224, input_D=224, resnet_shortcut='B',
                   no_cuda=False, gpu_id=(0,),
                   pretrain_path='pretrain/resnet_50.pth',
                   nb_class=1):
    """Build a MedicalNet 3D ResNet with a classification head and load weights.

    The network's ``conv_seg`` module is replaced by global average pooling,
    flatten and a linear layer with ``nb_class`` outputs. Weights from
    ``pretrain_path`` are copied into every parameter whose name matches.

    Args:
        model_type: only ``'resnet'`` is supported.
        model_depth: one of 10, 18, 34, 50, 101, 152, 200.
        input_W, input_H, input_D: forwarded to the ResNet constructor as
            ``sample_input_*``.
        resnet_shortcut: forwarded as ``shortcut_type``.
        no_cuda: keep the model on CPU and do not wrap it in DataParallel.
        gpu_id: sequence of GPU ids; more than one enables multi-GPU
            DataParallel over exactly those ids.
        pretrain_path: checkpoint file containing a ``'state_dict'`` entry.
        nb_class: number of outputs of the classification layer.

    Returns:
        The model (wrapped in ``nn.DataParallel`` unless ``no_cuda``).

    Raises:
        ValueError: unsupported ``model_type`` or ``model_depth``.
        RuntimeError: no checkpoint tensor matches any model parameter name.
    """
    if model_type != 'resnet':
        raise ValueError("unsupported model_type: %r" % (model_type,))
    if model_depth not in _RESNET_VARIANTS:
        raise ValueError("unsupported model_depth: %r" % (model_depth,))

    builder_name, fc_input = _RESNET_VARIANTS[model_depth]
    model = getattr(resnet, builder_name)(
        sample_input_W=input_W,
        sample_input_H=input_H,
        sample_input_D=input_D,
        shortcut_type=resnet_shortcut,
        no_cuda=no_cuda,
        num_seg_classes=1)

    # Swap the segmentation head for a classification head.
    model.conv_seg = nn.Sequential(
        nn.AdaptiveAvgPool3d((1, 1, 1)),
        nn.Flatten(),
        nn.Linear(in_features=fc_input, out_features=nb_class, bias=True))

    if not no_cuda:
        if len(gpu_id) > 1:
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=list(gpu_id))
        else:
            import os
            # NOTE(review): this presumably only takes effect if CUDA has not
            # been initialised yet in this process -- confirm.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id[0])
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)
    net_dict = model.state_dict()

    print('loading pretrained model {}'.format(pretrain_path))
    # Load onto CPU first so the checkpoint can be read on hosts without the GPU
    # it was saved from; load_state_dict copies into the model's own device.
    pretrain = torch.load(pretrain_path, map_location='cpu')

    # Match names with the DataParallel 'module.' prefix ignored on both sides;
    # otherwise an unwrapped (no_cuda) model would match nothing in a checkpoint
    # saved from a wrapped model, and vice versa.
    net_key_for = {_bare_key(name): name for name in net_dict}
    pretrain_dict = {}
    for name, weight in pretrain['state_dict'].items():  # name: layer name, weight: its values
        target = net_key_for.get(_bare_key(name))
        if target is not None:
            pretrain_dict[target] = weight
    if not pretrain_dict:
        raise RuntimeError(
            'no parameter in {} matches the model'.format(pretrain_path))

    net_dict.update(pretrain_dict)  # overwrite matching entries with pretrained values
    model.load_state_dict(net_dict)  # copy the merged weights into the model
    print("-------- pre-train model load successfully --------")
    return model


model = generate_model(model_type='resnet', model_depth=50,
                       input_W=224, input_H=224, input_D=224, resnet_shortcut='B',
                       no_cuda=not torch.cuda.is_available(), gpu_id=[0],
                       pretrain_path='./resnet_50_23dataset.pth',
                       nb_class=11)

# Finally, let us look at the training results.

# It is largely the same as the multi-class task; a few pieces of code need to
# be modified.

# Data download
data_flag = 'vesselmnist3d'  # Binary-Class (2)
download = True

info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

# load the data
train_dataset = DataClass(split='train', transform=train_transform, download=download)
val_dataset = DataClass(split='val', transform=eval_transform, download=download)
test_dataset = DataClass(split='test', transform=eval_transform, download=download)

# NOTE(review): the loop below iterates train_loader / val_loader, which this
# page never defines; they are built here from the datasets just loaded.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# Visualization

# The loop below reshapes the logits to (batch,), which requires a single-output
# head, so the 11-class model built above cannot be reused for the binary task.
model = generate_model(model_type='resnet', model_depth=50,
                       input_W=224, input_H=224, input_D=224, resnet_shortcut='B',
                       no_cuda=not torch.cuda.is_available(), gpu_id=[0],
                       pretrain_path='./resnet_50_23dataset.pth',
                       nb_class=1)

# Training parameters; binary classification uses BCEWithLogitsLoss.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# pos_weight compensates for the class imbalance.
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0])).to(device)
scheduler = ExponentialLR(optimizer, gamma=0.99)
num_epochs = 1200


# We use accuracy and AUC as metrics.
def _roc_auc(labels, scores):
    """Return the area under the ROC curve for binary labels and scores."""
    fpr, tpr, _ = roc_curve(np.array(labels), np.array(scores))
    return auc(fpr, tpr)


for epoch in range(num_epochs):
    start = time.time()

    # ---- training ----
    per_epoch_loss = 0
    num_correct = 0
    score_list = []
    label_list = []
    model.train()
    for x, label in tqdm(train_loader):
        x = x.float().to(device)
        # Flatten to (batch,); unlike squeeze this also works for a batch of 1.
        label = label.to(device).reshape(-1)

        # Forward pass
        logits = model(x).reshape(label.shape[0])
        loss = criterion(logits, label.float())

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        per_epoch_loss += loss.item()
        prob_out = torch.sigmoid(logits.detach())
        num_correct += ((prob_out > 0.5) == label).sum().item()
        score_list.extend(prob_out.cpu().numpy())
        label_list.extend(label.cpu().numpy())

    train_auc = _roc_auc(label_list, score_list)
    train_loss = per_epoch_loss / len(train_loader)
    train_acc = num_correct / len(train_loader.dataset)
    print("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f} AUC: {:.6f} ".format(
        epoch, train_loss, train_acc, train_auc))
    summaryWriter.add_scalars('loss', {"loss": train_loss}, epoch)
    summaryWriter.add_scalars('acc', {"acc": train_acc}, epoch)
    summaryWriter.add_scalars('auc', {"auc": train_auc}, epoch)

    # ---- validation ----
    val_num_correct = 0
    val_score_list = []
    val_label_list = []
    model.eval()
    with torch.no_grad():
        for x, label in tqdm(val_loader):
            x = x.float().to(device)
            # Flattened exactly like the training labels so the accuracy
            # comparison and roc_curve both see 1-D labels.
            label = label.to(device).reshape(-1)

            # Forward pass
            logits = model(x).reshape(label.shape[0])
            prob_out = torch.sigmoid(logits)
            val_num_correct += ((prob_out > 0.5) == label).sum().item()
            val_score_list.extend(prob_out.cpu().numpy())
            val_label_list.extend(label.cpu().numpy())

    val_auc = _roc_auc(val_label_list, val_score_list)
    val_acc = val_num_correct / len(val_loader.dataset)
    print("val Epoch: {}\t Acc: {:.6f} AUC: {:.6f} ".format(epoch, val_acc, val_auc))
    summaryWriter.add_scalars('acc', {"val_acc": val_acc}, epoch)
    summaryWriter.add_scalars('auc', {"val_auc": val_auc}, epoch)
    summaryWriter.add_scalars('time', {"time": (time.time() - start)}, epoch)

    scheduler.step()

    # Optional per-epoch checkpointing (disabled in the article); creates the
    # ./weights folder if it does not exist, then saves the weights:
    # os.makedirs("./weights", exist_ok=True)
    # torch.save(model.state_dict(), './weights/model' + str(epoch) + '.pth')

# Make sure everything logged above reaches disk.
summaryWriter.flush()

# Results
CopyRight 2018-2019 实验室设备网 版权所有 |