
VGG (PyTorch)

VGG pushed the traditional plain sequential architecture to its practical depth limit.

To understand how VGG works, you first need the basics of CNN receptive fields.
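The key fact: with stride 1, two stacked 3x3 convolutions have a 5x5 receptive field and three have a 7x7 receptive field, matching one large kernel while using fewer parameters and more nonlinearities. Here is a minimal sketch (my own, not from the original post) that computes this with the standard recurrence r_out = r_in + (k - 1) * j_in, where j_in is the product of the strides of the layers below:

def stacked_receptive_field(layers):
    """layers: list of (kernel_size, stride) pairs, applied bottom-up."""
    r, j = 1, 1  # receptive field and cumulative stride ("jump") at the input
    for k, s in layers:
        r += (k - 1) * j  # each layer widens the field by (k - 1) input jumps
        j *= s            # strides compound multiplicatively
    return r

print(stacked_receptive_field([(3, 1), (3, 1)]))          # 5 -> same as one 5x5 conv
print(stacked_receptive_field([(3, 1), (3, 1), (3, 1)]))  # 7 -> same as one 7x7 conv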

model.py

import torch.nn as nn
import torch

# official pretrained weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        # `features` is the already-built convolutional feature extractor;
        # it is passed in and used directly.
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    layers = []
    in_channels = 3
    # `cfg` is a list; walk it to build the VGG feature layers.
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    # `layers` is an iterable of layers/modules; nn.Sequential wraps them into
    # a single module that runs them in the given order, so the whole feature
    # extractor is returned as one nn.Sequential object.
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


# In a function definition, **kwargs is a special parameter form (marked by
# the double star **) that lets the function accept any number of extra
# keyword arguments and collects them into the `kwargs` dict. For example,
# in vgg(model_name="vgg16", num_classes=10, init_weights=True), both
# num_classes and init_weights are collected into kwargs and forwarded to VGG.
def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]
    model = VGG(make_features(cfg), **kwargs)
    return model
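As a quick sanity check (my own sketch, not part of the original post), the vgg() factory can be exercised with a random input to confirm the shapes annotated in forward():

import torch
from model import vgg

net = vgg(model_name="vgg16", num_classes=5, init_weights=True)
x = torch.randn(1, 3, 224, 224)  # N x 3 x 224 x 224
print(net(x).shape)              # torch.Size([1, 5])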

train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # data loader used to iterate over the training set in batches
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                            val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
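For reference (reconstructed here from the class_to_idx comment above, not copied from an actual run), the class_indices.json that train.py writes should look like this; predict.py reads it back to map output indices to class names:

{
    "0": "daisy",
    "1": "dandelion",
    "2": "roses",
    "3": "sunflower",
    "4": "tulips"
}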

Training takes a long time, so this run was interrupted during epoch 19. The output so far:

using cuda:0 device.
Using 8 dataloader workers every process
using 3306 images for training, 364 images for validation.
train epoch[1/30] loss:1.542: 100%|██████████| 104/104 [08:39<00:00,  4.99s/it]
100%|██████████| 12/12 [01:13<00:00,  6.15s/it]
[epoch 1] train_loss: 1.605  val_accuracy: 0.245
train epoch[2/30] loss:1.399: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.12s/it]
[epoch 2] train_loss: 1.476  val_accuracy: 0.401
train epoch[3/30] loss:1.310: 100%|██████████| 104/104 [08:34<00:00,  4.94s/it]
100%|██████████| 12/12 [01:18<00:00,  6.53s/it]
[epoch 3] train_loss: 1.293  val_accuracy: 0.456
train epoch[4/30] loss:0.958: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.11s/it]
[epoch 4] train_loss: 1.185  val_accuracy: 0.519
train epoch[5/30] loss:1.327: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.11s/it]
[epoch 5] train_loss: 1.135  val_accuracy: 0.527
train epoch[6/30] loss:1.209: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.12s/it]
[epoch 6] train_loss: 1.077  val_accuracy: 0.571
train epoch[7/30] loss:0.725: 100%|██████████| 104/104 [1:25:27<00:00, 49.30s/it]
100%|██████████| 12/12 [01:21<00:00,  6.82s/it]
[epoch 7] train_loss: 1.051  val_accuracy: 0.596
train epoch[8/30] loss:1.146: 100%|██████████| 104/104 [08:50<00:00,  5.10s/it]
100%|██████████| 12/12 [01:27<00:00,  7.31s/it]
[epoch 8] train_loss: 1.008  val_accuracy: 0.615
train epoch[9/30] loss:1.381: 100%|██████████| 104/104 [08:48<00:00,  5.08s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 9] train_loss: 0.995  val_accuracy: 0.640
train epoch[10/30] loss:0.466: 100%|██████████| 104/104 [08:34<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 10] train_loss: 0.966  val_accuracy: 0.673
train epoch[11/30] loss:0.867: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.13s/it]
[epoch 11] train_loss: 0.926  val_accuracy: 0.659
train epoch[12/30] loss:0.804: 100%|██████████| 104/104 [08:34<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 12] train_loss: 0.916  val_accuracy: 0.665
train epoch[13/30] loss:0.377: 100%|██████████| 104/104 [08:35<00:00,  4.96s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 13] train_loss: 0.879  val_accuracy: 0.648
train epoch[14/30] loss:0.588: 100%|██████████| 104/104 [08:35<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.16s/it]
[epoch 14] train_loss: 0.841  val_accuracy: 0.676
train epoch[15/30] loss:0.725: 100%|██████████| 104/104 [08:35<00:00,  4.96s/it]
100%|██████████| 12/12 [01:13<00:00,  6.13s/it]
[epoch 15] train_loss: 0.830  val_accuracy: 0.687
train epoch[16/30] loss:0.977: 100%|██████████| 104/104 [08:35<00:00,  4.96s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 16] train_loss: 0.811  val_accuracy: 0.720
train epoch[17/30] loss:0.923: 100%|██████████| 104/104 [08:34<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 17] train_loss: 0.796  val_accuracy: 0.703
train epoch[18/30] loss:1.150: 100%|██████████| 104/104 [08:34<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.15s/it]
[epoch 18] train_loss: 0.794  val_accuracy: 0.720
train epoch[19/30] loss:0.866:  19%|█▉        | 20/104 [01:54<07:59,  5.71s/it]

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = vgg(model_name="vgg16", num_classes=5).to(device)
    # load the model weights saved by train.py
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
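As a small optional extension (my sketch, not in the original script), the per-class printout could be sorted by probability with torch.topk; shown here with dummy stand-ins for the real predict tensor and class_indict dict computed inside main():

import torch

# dummy stand-ins for the values computed inside main()
predict = torch.softmax(torch.randn(5), dim=0)
class_indict = {"0": "daisy", "1": "dandelion", "2": "roses", "3": "sunflower", "4": "tulips"}

probs, idxs = torch.topk(predict, k=len(predict))  # all classes, highest probability first
for p, i in zip(probs, idxs):
    print("class: {:10}   prob: {:.3f}".format(class_indict[str(i.item())], p.item()))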

Prediction output:

class: daisy        prob: 0.00207
class: dandelion    prob: 0.00144
class: roses        prob: 0.101
class: sunflowers   prob: 0.00535
class: tulips       prob: 0.89
