首页 >> 大全

Pytorch之ResNet图像分类

2023-11-21 大全 33 作者:考证青年

目录

前言

一、网络结构

1.结构

2.BN(Batch )层

二、网络结构

三、网络实现

1.构建网络

2.加载数据集

3.训练和测试模型

四、实现图像分类

前言

2015 年,微软亚洲研究院何凯明等人发表了基于 Skip 的深度残差网络( ,简称 )算法,并提出了 18 层、34 层、50 层、101层、152 层的 -18、-34、-50、-101 和 -152 等模型,甚至成功训练出层数达到 1202 层的极深层神经网络,斩获当年竞赛中分类任务第一名,目标检测第一名。获得COCO数据集中目标检测第一名,图像分割第一名。 论文至今已经获得超 25000的引用量,可见 在人工智能行业的影响力。

一、网络结构

1.结构

、VGG、 等网络模型的出现将神经网络的发展带入了几十层的阶段,研究人员发现网络的层数越深,越有可能获得更好的泛化能力。

神经网络越深的卷积层理论上可以提取更多的图像特征,但是事实结果并不是,如下图:

从上图中可以看出随着层数的增加,预测效果反而越来越差,网络层数越深,训练误差越高,导致训练和测试效果变差,这一现象称为退化。

为了解决深层网络中的退化问题,给深层神经网络添加一种回退到浅层神经网络的机制。当深层神经网络可以回退到浅层神经网络时,深层神经网络可以获得和浅层神经网络相当的模型性能,而不至于更糟糕。

这种神经网络被称为残差网络 ()。论文提出了结构(残差结构)来减轻退化问题。

block有两种,一种两层结构,一种三层结构。

左图为是以两个3*3的卷积网络串接在一起作为一个残差模块,主分支和支路维度都一直保持相同。

右图为是1*1、3*3、1*1的3个卷积网络串接在一起作为一个残差模块。第一层的1× 1的卷积核的作用是对特征矩阵进行降维操作,将特征矩阵的深度由256降为64; 第三层的1× 1的卷积核是对特征矩阵进行升维操作,将特征矩阵的深度由64升成256。降低特征矩阵的深度主要是为了减少参数的个数。

先降后升为了主分支上输出的特征矩阵和分支上输出的特征矩阵形状相同,以便进行加法操作。一般搭建深层次网络时,采用三层残差结构。

通过在卷积层的输入和输出之间添加 Skip 实现层数回退机制,如下图上所示,输入x通过两个卷积层,得到特征变换后的输出ℱ(x),与输入x进行对应元素的相加运算,得到最终输出ℋ(x):ℋ(x) = x+ ℱ(x)。

ℋ(x)叫作残差模块( Block,简称 )。由于被 Skip 包围的卷积神经网络需要学习映射ℱ(x) = ℋ(x) − x,故称为残差网络。

为了能够满足输入x与卷积层的输出ℱ(x)能够相加运算,需要输入x的 shape 与ℱ(x)的shape 完全一致。

当出现 shape 不一致时,一般通过在 Skip 上添加额外的卷积运算环节将输入x变换到与ℱ(x)相同的 shape,如图上图中(x)函数所示,其中(x)以 × 的卷积运算居多,如1*1卷积,对进行升维操作,调整输入的通道数。

注意,让x和ℱ(x)相加,即特征矩阵对应的位置上的数字进行相加,与中的拼接不一样,是在维度上直接进行的拼接,并不是相加。

下图是和的网络结构

从残差网络结构中给出了两种连接,分别是实线连接和虚线连接。如下图所示,

实线残差结构:对应残差模块而言,输入特征矩阵和输出特征矩阵形状大小相同,能够直接相加。

虚线残差结构:输入特征矩阵和输出特征矩阵不能直接相加,输入特征矩阵需要经过分支上的1×1的卷积核进行了维度处理(特征矩阵在长宽方向降采样,深度方向调整成下一层残差结构所需要的)。

/34实线/虚线残差结构图

/101/152实线/虚线残差结构图

不同深度的网络结构配置,注意表中的残差结构给出了主分支上卷积核的大小与卷积核个数,表中残差块×N 表示将该残差结构重复N次。

, , 所对应的一系列残差结构的第一层残差结构都是虚线残差结构。因为这一系列残差结构的第一层都有调整输入特征矩阵shape的作用(将特征矩阵的高和宽缩减为原来的一半,将深度调整成下一层残差结构所需要的)

注意,对于/101/152,其实所对应的一系列残差结构的第一层也是虚线残差结构,因为它需要调整输入特征矩阵的。

根据表格可知通过3x3的max pool之后输出的特征矩阵shape应该是[56, 56, 64],但所对应的一系列残差结构中的实线残差结构它们期望的输入特征矩阵shape是[56, 56, 256](因为这样才能保证输入输出特征矩阵shape相同,才能将分支的输出与主分支的输出进行相加)。所以第一层残差结构需要将shape从[56, 56, 64] --> [56, 56, 256]。

注意,这里只调整维度,高和宽不变,而, , 所对应的一系列残差结构的第一层虚线残差结构不仅要调整还要将高和宽缩减为原来的一半。

下图是使用结构的卷积网络,可以看到随着网络的不断加深,效果并没有变差,而是变的更好了。(虚线是train error,实线是test error)

2.BN(Batch )层

同时当模型加深以后,网络变得越来越难训练,这主要是由于梯度消失和梯度爆炸现象造成的。在较深层数的神经网络中,梯度信息由网络的末层逐层传向网络的首层时,传递的过程中会出现梯度接近于 0 或梯度值非常大的现象。网络层数越深,这种现象可能会越严重。

梯度消失和梯度爆炸产生原因:

梯度消失:若每一层的误差梯度小于1,反向传播时,网络越深,梯度越趋近于0

梯度爆炸:若每一层的误差梯度大于1,反向传播时,网络越深,梯度越来越大

那么怎么解决深层神经网络的梯度弥散和梯度爆炸现象呢?

为了解决梯度消失或梯度爆炸问题,论文提出通过数据的预处理以及在网络中使用 BN(Batch )层来解决。

Batch 是指批标准化处理,将一批数据的 map满足均值为0,方差为1的分布规律。

在图像预处理过程中通常会对图像进行标准化处理,这样能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的 map就不一定满足某一分布规律。

注意这里所说满足某一分布规律并不是指某一个 map的数据要满足分布规律,理论上是指整个训练样本集所对应 map的数据要满足分布规律。Batch 的目的就是使数据的 map满足均值为0,方差为1的分布规律。

“ 对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。” 假设我们输入的x是RGB三通道的彩色图像,那么这里的d就是输入图像的即d=3,

,其中

就代表我们的R通道所对应的特征矩阵,依此类推。标准化处理也就是分别对我们的R通道,G通道,B通道进行处理。

处理公式如下:

让 map满足某一分布规律,理论上是指整个训练样本集所对应 map的数据要满足分布规律,即要计算出整个训练集的 map然后在进行标准化处理,对于一个大型的数据集明显是不可能的。

所以论文中的Batch ,指的是计算一个Batch数据的 map然后在进行标准化(batch越大越接近整个数据集的分布,效果越好)。

根据上图的公式可以知道

代表着计算的 map每个维度()的均值,注意是一个向量不是一个值,向量的每一个元素代表着一个维度()的均值。

代表着我们计算的 map每个维度()的方差,注意是一个向量不是一个值,向量的每一个元素代表着一个维度()的方差,然后根据

计算标准化处理后得到的值。示例如下所示:

在原论文公式中还有

两个参数,

是用来调整数值分布的方差大小,

是用来调节数值均值的位置。这两个参数是在反向传播过程中学习得到的,

的默认值是1,

的默认值是0。

使用BN需要注意:

1.训练时要将参数设置为True,在验证时将参数设置为False。在中可通过创建模型的model.train()和model.eval()方法控制。

2.batch size尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。

3.一般将BN层放在卷积层(Conv) 和激活层(Relu) 之间,且卷积层不要使用偏置bias。在有无偏置时推导出来的结果一样,使用反而增加运算效率。

图像分类pytorch代码_图像分类pytorch原理_

二、网络结构

( Next)是一种深度神经网络架构,它是对残差网络( ,通常简称为)的扩展和改进。 的设计目标是提高网络的性能和效率,特别是在大规模图像分类任务上表现出色。

网络的一些关键特点:

1.基于残差连接: 仍然基于残差连接,这是 的核心思想之一。残差连接允许信息在网络中跳跃传递,有助于解决梯度消失问题,使得更深的网络可以更容易地训练。

2.分组卷积: 引入了分组卷积( )的概念,这是与传统卷积不同的一种卷积操作。在分组卷积中,卷积核被分为多个组(),每个组对输入数据执行卷积操作。这种结构允许网络在不增加参数数量的情况下增加模型的宽度,从而提高了性能。

3.(基数): 中引入了一个称为 "" 的参数,用于指定分组卷积中的组数。通过调整基数,可以控制模型的宽度,从而平衡模型的性能和计算成本。通常,较大的基数可以提高性能,但也增加了计算成本。

分组卷积(Group ):

假设输入特征矩阵等于4,分为两个组,对每个组分别进行卷积操作,假设对每个Group使用n/2个卷积核,通过第每个Group的卷积可以得到对应的是n/2的特征矩阵,再对两组进行拼接,那么最终特征矩阵得到的是n。

若假设输入矩阵的等于cin,对输入特征矩阵分为g个组,那么对于每个group而言,每个group采用卷积核的参数是(k×k×cin/g×n/g)。当g=cin,n=cin,这就相当于对输入特征矩阵的每一个分配了一个为1的卷积核进行卷积。

中的残差模块(右图),其等价表示如下图三种形式。

a:第三层对每个分支: 先通过1*1的卷积,然后在进行相加(和b的第三层先通过拼接,在进行1*1卷积等价)

b:第一层有32个分支(32*4=128,与c第一层等价),第二层和c的group卷积一样,对于每个分支可理解为一个Group

c:先通过1*1卷积层进行降维处理(128),再通过group对它进行处理(group数32,大小3*3,输出),最后通过一个1*1的卷积升维

网络结构:将网络中的残差模块替换为模块中的残差模块即可。

C()=32是Group数,4对应每个组卷积核的个数。注意,只有block层数大于等于3的时,才能构建出一个比较有意义的block,对之前的浅层block而言,还是使用的block。

三、网络实现 1.构建网络

根据 block的两种结构搭建残差结构。

# resnet18/34 block
class BasicBlock(nn.Module):expansion = 1   # 主分支上卷积核的个数是否相同, BasicBlock第一层和第二层相同,Bottleneck第三层特征矩阵的维度是第一层的4倍def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()# 实线残差结构:stride=1;虚线残差结构:stride = 2(特征图减半)self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)  # 在使用BN层使,无biasself.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsample  # 虚线残差结构:shortcut分支的下采样def forward(self, x):identity = x# 虚线残差结构:输入特征矩阵和输出特征矩阵不能直接相加,输入特征矩阵需要经过shortcut分支上的1×1的卷积核进行了维度处理if self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)# 相加out += identityout = self.relu(out)return out# resnet50/101/152 block
class Bottleneck(nn.Module):"""注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,这么做的好处是能够在top1上提升大概0.5%的准确率。可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch"""# 主分支上卷积核的个数是否相同, BasicBlock第一层和第二层相同,Bottleneck第三层特征矩阵的维度是第一层的4倍expansion = 4# in_channel:传入Bottleneck的输入通道数,out_channel:中间3x3所使用卷积核的个数def __init__(self, in_channel, out_channel, stride=1, downsample=None, groups=1, width_per_group=64):super(Bottleneck, self).__init__()# ResNet网络不使用Grouped Convolution,groups、width_per_group使用默认参数width = out_channel# ResNeXt网络使用Grouped Convolution,groups=32、width_per_group=4使用默认参数width = 2*out_channel# ResNeXt50的输出特征矩阵通道数ResNet50的2倍width = int(out_channel * (width_per_group / 64.)) * groupsself.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width, kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# 第二层:实线残差结构:stride=1;虚线残差结构:stride = 2(特征图减半)self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups, kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# 第三层的输出特征是第一层的特征输出的4(expansion)倍self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion, kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return out

然后就可以根据论文中和网络结构表格搭建网络,

# 残差网络结构
class ResNet(nn.Module):def __init__(self,block,  # 残差模块:BasicBlock / Bottleneckblocks_num,  # conv2/3/4/5_x 残差模块的数量,查看论文中给出的网络配置表格num_classes=1000,include_top=True,  # 用于外部模块调用ResNet网络groups=1,width_per_group=64):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64   # conv_2的输入通道数是64,经过7x7卷积和3x3最大池化层后维度为64self.groups = groupsself.width_per_group = width_per_groupself.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 堆叠残差模块: conv2/3/4/5_xself.layer1 = self._make_layer(block, 64, blocks_num[0])   # conv2_Xself.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)  # conv3_Xself.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)  # conv4_Xself.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)  # conv5_Xif self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')# 用于生产conv2/3/4/5_x各层配置# block:选用的残差结构的模块,resnet18/34:BasicBlock resnet50/101/152:Bottleneck# channel:conv2/3/4/5_x各层第一层的输入通道数# block_num:conv2/3/4/5_x各层堆叠的次数# stride:默认为1,从conv3_x开始stride=2def _make_layer(self, block, channel, block_num, stride=1):downsample = None# 对于resnet18/34不执行该语句,50,101,152:conv2_x列残差结构的第一层也是虚线残差结构,需要调整输入特征矩阵的channelif stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []# 第一层残差结构layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride,groups=self.groups,width_per_group=self.width_per_group))# 更新下一层残差结构的输入通道self.in_channel = channel * block.expansion# conv2/3/4/5_x的第二层残差结构都为实线残差结构for _ in range(1, block_num):layers.append(block(self.in_channel,channel,groups=self.groups,width_per_group=self.width_per_group))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return x

使用类调用生产/网络:

def resnet34(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet50(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet50-19c8e357.pthreturn ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)def resnext50_32x4d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pthgroups = 32width_per_group = 4return ResNet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)def resnext101_32x8d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pthgroups = 32width_per_group = 8return ResNet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)

2.加载数据集

这里使用花朵数据集,数据集制造和数据集使用的脚本的参考:之花朵分类_风间琉璃•的博客-CSDN博客

加载数据集和测试集,并进行相应的预处理操作。

    data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 数据集根目录data_root = os.path.abspath(os.getcwd())print(os.getcwd())# 图片目录image_path = os.path.join(data_root, "data_set", "flower_data")print(image_path)assert os.path.exists(image_path), "{} path does not exit.".format(image_path)# 准备数据集train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)# 定义一个包含花卉类别到索引的字典:雏菊,蒲公英,玫瑰,向日葵,郁金香# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}# 获取包含训练数据集类别名称到索引的字典,这通常用于数据加载器或数据集对象中。flower_list = train_dataset.class_to_idx# 创建一个反向字典,将索引映射回类别名称cla_dict = dict((val, key) for key, val in flower_list.items())# 将字典转换为格式化的JSON字符串,每行缩进4个空格json_str = json.dumps(cla_dict, indent=4)# 打开名为 'class_indices.json' 的JSON文件,并将JSON字符串写入其中with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint("using {} dataloader workers every process".format(nw))# 加载数据集train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=4, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num, val_num))

3.训练和测试模型

数据集预处理完成后,就可以进行网络模型的训练和验证。

    net = resnet34()# 加载预训练权重# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pthmodel_weight_path = "./resnet34-pre.pth"assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)# 加载预训练的权重,这将使用预先训练的模型参数初始化模型。net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))# for param in net.parameters():#     param.requires_grad = False# 修改全连接层结构in_channel = net.fc.in_features  # 获取全连接层的输入特征维度# 输出为5个类别net.fc = nn.Linear(in_channel, 5)  # 替换全连接层以适应新的分类任务,输出5个类别net.to(device)# 定义损失函数loss_function = nn.CrossEntropyLoss()  # 使用交叉熵损失函数来计算损失# 构建优化器# 使用列表推导式,它遍历了模型中的所有参数,并只选择那些requires_grad为True的参数,# 将它们添加到一个名为params的列表中。params 列表包含了需要计算梯度并进行优化的所有参数。params = [p for p in net.parameters() if p.requires_grad]   # 获取需要梯度更新的模型参数optimizer = optim.Adam(params, lr=0.0001)  # 使用Adam优化器来更新模型参数,学习率为0.0001epochs = 100best_acc = 0.0save_path = './ResNet34.pth'train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')

这里使用了官方的预训练权重,在其基础上训练自己的数据集。

训练的准确率能到达95%左右,官方的预训练权重文件训练一个epoch就能到达90%左右。

四、实现图像分类

利用上述训练好的网络模型进行测试,验证是否能完成分类任务。

def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 与训练的预处理一样data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 加载图片img_path = 'roses.jpg'assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)image = Image.open(img_path)# image.show()# [N, C, H, W]img = data_transform(image)# 扩展维度img = torch.unsqueeze(img, dim=0)# 获取标签json_path = 'class_indices.json'assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)with open(json_path, 'r') as f:# 使用json.load()函数加载JSON文件的内容并将其存储在一个Python字典中class_indict = json.load(f)# 加载网络model = resnet34(num_classes=5).to(device)# 加载模型文件weights_path = "./ResNet34.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path, map_location=device))model.eval()with torch.no_grad():# 对输入图像进行预测output = torch.squeeze(model(img.to(device))).cpu()# 对模型的输出进行 softmax 操作,将输出转换为类别概率predict = torch.softmax(output, dim=0)# 得到高概率的类别的索引predict_cla = torch.argmax(predict).numpy()res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())draw = ImageDraw.Draw(image)# 文本的左上角位置position = (10, 10)# fill 指定文本颜色draw.text(position, res, fill='red')image.show()for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))if __name__ == '__main__':main()

测试结果:

结束语

感谢阅读吾之文章,今已至此次旅程之终站 。

吾望斯文献能供尔以宝贵之信息与知识也 。

学习者之途,若藏于天际之星辰,吾等皆当努力熠熠生辉,持续前行。

然而,如若斯文献有益于尔,何不以三连为礼?点赞、留言、收藏 - 此等皆以证尔对作者之支持与鼓励也 。

关于我们

最火推荐

小编推荐

联系我们


版权声明:本站内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 88@qq.com 举报,一经查实,本站将立刻删除。备案号:桂ICP备2021009421号
Powered By Z-BlogPHP.
复制成功
微信号:
我知道了