pytorch深度学习实战lesson25
创始人
2024-02-09 16:45:47
0

第二十五课 network in network(NIN)

NIN 叫做network in network或者叫做网络中的网络。这个网络现在用的不多,几乎很少被用到。但是它里面提出了比较重要的概念,在之后很多网络都会被持续的用到。所以今天认识一下这一个网络。

目录

理论部分

实践部分


理论部分

在 alexnet 和 vgg 的时候都在最后用了比较大的全连接层,在 vgg 和alexnet都是一样的,用了两个4096的全链阶层,最后通过一个全链阶层作为输出。这些全连阶层的参数其实特别占用空间,也会占用很多的计算带宽,它还很容易会产生过拟合。

           它首先有一个卷积层,然后跟了两个全连阶层,其实1乘1的卷积层可以等价是一个全链阶层,具体来说1乘1的卷积层也就是窗口的大小是1乘1、步幅为1,无填充的卷积层,这个卷积层不会改变输入的形状,也不会改变通道数。所有1乘1的卷积层可以当做全连阶层来使用,它的作用就是对每个通道数帮你做一些混合。

       就是说我的池化层的高宽是等于输入的高宽,等价于把每一个通道最大的值给拿出来,再加个 softmax 就会得到我们的概率了。

上图是vgg架构和nin架构的对比图,vgg 就是有四个 vgg块,再加上两个大的全连接层最后得到输出类是1000类;那么 NIN的话主要由nin 块和一个步幅为2的最大池化层组成,不断重复这一个过程,直到最后如果把通道数设成分类个数的话,那么最后直接用全局的平均池化层来得到输出对每一个类的预测即可。

所以整体来讲就是 nin 架构比较简单,就是 nin块 加上最大池化层一直到最后一个全局的平均池化层。而且它的通道参数个数非常少,少是因为整个就没有全链阶层。这就是nin网络。

实践部分

nin与Alex net对比一下。发现nin精度(0.83)还没有之前Alexnet(0.88)高,然后nin的速度是也没有比 alexnet高太多,这是因为nin额外加入了大量的1乘1的卷积层,会使得计算会变慢。然后也因为数据集相对来说比较少。

代码:

#网络中的网络(NiN)
#NiN块
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt
def nin_block(in_channels, out_channels, kernel_size, strides, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),#1*1卷积层使得输入输出通道个数一样nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),#1*1卷积层使得输入输出通道个数一样nn.ReLU())
#NiN模型
net = nn.Sequential(nin_block(1, 96, kernel_size=11, strides=4, padding=0),#一个nin块nn.MaxPool2d(3, stride=2),#加一个最大池化层,卷积核维度为3,步长为2nin_block(96, 256, kernel_size=5, strides=1, padding=2),nn.MaxPool2d(3, stride=2),nin_block(256, 384, kernel_size=3, strides=1, padding=1),nn.MaxPool2d(3, stride=2),nn.Dropout(0.5),#把一半的权值设为0,减少计算量nin_block(384, 10, kernel_size=3, strides=1, padding=1),#最后一个nin块的输出通道数要等于类别数nn.AdaptiveAvgPool2d((1, 1)),#全局平均池化层,高宽都为1nn.Flatten())#把最后两个维度直接消掉,就变成了一个 backsize 乘以10的矩阵。这个东西就可以直接softmax回归
#查看每个块的输出形状
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)
#训练模型
lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

Sequential output shape:     torch.Size([1, 96, 54, 54])
MaxPool2d output shape:     torch.Size([1, 96, 26, 26])
Sequential output shape:     torch.Size([1, 256, 26, 26])
MaxPool2d output shape:     torch.Size([1, 256, 12, 12])
Sequential output shape:     torch.Size([1, 384, 12, 12])
MaxPool2d output shape:     torch.Size([1, 384, 5, 5])
Dropout output shape:     torch.Size([1, 384, 5, 5])
Sequential output shape:     torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:     torch.Size([1, 10, 1, 1])
Flatten output shape:     torch.Size([1, 10])
training on cuda:0


。。。。
loss 0.369, train acc 0.863, test acc 0.853
1226.8 examples/sec on cuda:0

进程已结束,退出代码0

相关内容

热门资讯

原创 就... 【军武次位面】作者:乐乐 日前,美国“Military Watch”网站报道称,中国海军一艘“基洛”...
输球又输点!阿森纳赛后点球3-... 在刚刚结束的季前友谊赛中,阿森纳以2-3不敌比利亚雷亚尔,随后在点球大战中以3-4失利,令人意外的是...
8月1日起乌鲁木齐天山国际机场... 2025年8月1日起,新疆机场集团乌鲁木齐天山国际机场将迎来一项关键服务升级:所有国内出港航班值机手...
“我就在这儿坐着怎么了”,火车... 安全乘车,文明出行,是每一位公民应尽的责任和义务。近日,旅客李某持无座车票强占其他旅客座位,经乘警多...
黑龙江省制定出台20条政策措施... 近日,黑龙江制定出台支持高端智能农机装备产业高质量发展20条政策措施。旨在引导产学研用等各方用好国家...
债券利息收入增值税新规落地在即... 债券利息收入税收新规实施前夕,政策性银行密集发行金融债。 8月5日,中国债券信息网披露的信息显示,中...
静乐县公安局征集“六霸”及殡葬... 为深入开展群众身边不正之风和腐败问题集中整治,严厉打击“六霸”及殡葬等领域涉民生违法犯罪,现向社会各...
原创 欧... 欧洲媒体在8月5日的报道中提到,美国与欧洲似乎达成一致,准备联合打压俄罗斯石油的主要买家——中国和印...
普京与美特使聊了3小时之后,特... 来源:视觉中国 俄罗斯总统普京与美国特使威特科夫的会晤在持续近3小时后结束。 据新华社报道,俄总统助...