pytorch 手写数字识别1
创始人
2024-02-23 17:55:06
0

目录

  1.       概述
  2.       加载图片
  3.       绘图部分
  4.       backward

前言:

       这里以一个手写数字识别的例子,简单了解一下pytorch 实现神经网络的过程.

本章重点讲一下加载数据过程

参考:

课时9 手写数字识别初体验-1_哔哩哔哩_bilibili

Pytorch中的backward函数 - 知乎


一  概述

    整体流程如下,分为四步

 


二   加载图片

     如下为加载minist 数据集过程

    

# -*- coding: utf-8 -*-
"""
Created on Thu Nov 24 17:17:19 2022@author: chengxf2
"""
import torchvisionfrom matplotlib  import pyplot as plt
import torch
import torchvision.transforms as transforms
import torchvision.datasets
from util import plot_curve,plot_image'''
root : 需要下载地址的根目录位置
train: True  下载训练集trainin.pt  False 下载test.pt
transform: 一系列作用在PIL 图片上的转换操作,返回一个转换版本
dowenload: 是否下载到root 指定的位置
transforms.Compose(): 将多个预处理依次累加在一起, 每次执行transform都会依次执行其中包含的多个预处理程序
transforms.ToTensor():在做数据归一化之前必须要把PIL Image转成Tensor
transforms.Normalize([0.5], [0.5]):归一化,这里的两个0.5分别表示对张量进行归一化的 全局平均值和方差,因为图像是灰色的只有一个通道,所以分别指定一了一个值,如果有多个通道,需要有多个数字,如3个通道,就应该是Normalize([m1, m2, m3], [n1, n2, n3])
'''
def load_data(batch =512):transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.1307], [0.3018])])train_dataset = torchvision.datasets.MNIST('mnist_data', train=True, transform=transform, download=True)test_dataset = torchvision.datasets.MNIST('mnist_data/', train=False, transform=transform, download=False)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch, shuffle=True)test_loader = torch.utils.data.DataLoader(test_dataset)print("\n --end--",type(train_loader))return train_loader, test_loaderdef show(data):#递归所有的元素for step, (x,y) in  enumerate(data):print("\n step ",step,y.shape) #512###单独取一个###   x,y = next(iter(train_loader))print(x.shape, y.shape)print(x.min(), x.max(),type(x)) #Tensorplot_image(x,y,'image sample')
if __name__  =="__main__":train_loader , test_loader = load_data()show(train_loader)

三  绘图部分


​
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 21 17:16:16 2022@author: chengxf2
"""
import torchfrom  matplotlib  import pyplot as  pltdef plot_curve(data):#画训练过程的lossfig = plt.figure()N= len(data)plt.plot(range(N),data, color='green')plt.legend(['value'], loc='up right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_image(img, label, name):#画图片    fig = plt.figure()for i in range(6):plt.subplot(2,3,i+1) #t(nrows ncols plot_number)plt.tight_layout() #会自动调整子图参数,使之填充整个图像区域plt.imshow(img[i][0]*0.3081+0.1307,interpolation ='none')   plt.title("{}:{}".format(name, label[i].item()))plt.xticks([])plt.yticks([])plt.show()'''
生成one-hot
Tensor.scatter_(dim, index, src, reduce=None) → TensorParameters
scatter_(dim, index, src): 将src中所有的值分散到self 中,填法是按照index中所指示的索引来填入。
dim (int) – the axis along which to indexdim=0,按照index行索引的指示来进行散射dim=1 ,按照index列索引的指示来进行散射
index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
src (Tensor or float) – the source element(s) to scatter. 要填进去的元素
reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'. 用的相对较少。'''
def one_hot(label, depth=10):N = label.size(0)print("\n n:",N)out = torch.zeros(N, depth)idx = torch.LongTensor(label).view(-1,1)out.scatter_(dim=1, index=idx, value=1)print("\n out ",out)return out#label =[1,2,5]
#label = torch.LongTensor(label)
#one_hot(label)​

四  Pytorch 中的backward 

     Numpy ,pytorch 可以自己实现反向传播算法,也可以使用pytorch给的API,通过动态图自动

求导

 这里面给出3个例子3.1  

3.1 简单的LR 模型

      

   \hat{y}=xw^T

  L=\frac{(\hat{y}-y)^2}{2}

  梯度:

   \frac{\partial L}{\partial w}=(\hat{y}-y)x

Created on Tue Nov 22 14:58:50 2022@author: chengxf2
"""
import torch
from torch.autograd import Variable'''
自动求梯度例子1
'''
def grad():x = torch.tensor([2.0,1.0],requires_grad=True) w = torch.tensor([1.0,2.0],requires_grad=True) y = torch.matmul(w, x.T)L = (y-1.0)**2/2.0print("\n L ",L)L.backward()print(w.grad)grad()

  bias tensor(3., grad_fn=)
  tensor([6., 3.])

相关内容

热门资讯

「长镜头」从合约纠纷到经济犯罪... 本报(chinatimes.net.cn)记者于玉金 北京报道 近几个月,头部艺人与经纪公司之间的解...
生产车间遭“非法入场,被拍摄、... 据海正生材12月19日晚间披露,公司收到台州市椒江区人民法院《出庭通知书》,被告人钱某某、郑某某与附...
为医美消费纠纷提供指引 山东省... 12月19日,山东省消费者协会与山东省民营整形美容协会在济南联合召开媒体通气会,发布《承诺效果型医疗...
深圳核发首张房票,四个一线城市... 据新华社消息,18日,深圳市在城市轨道交通27号线工程(西丽段)涉及西丽福光楼土地整备项目中核发首张...
原创 汪... 最近,汪小菲在社交媒体上掀起了一波波热议。他宣布要起诉抖音的一位副总裁,并且公开了与前妻大S关于孩子...
原创 揪... 现在买新能源车,谁都被大尺寸中控屏吸引。刷剧、打游戏、听音乐,指尖一划全搞定。车机就像移动娱乐厅。大...
中公教育(002607)披露新... 截至2025年12月19日收盘,中公教育(002607)报收于2.82元,较前一交易日上涨3.68%...
原创 比... 最近国际金融圈炸了个大雷,欧盟刚宣布把俄罗斯的钱“无限期看管”,俄罗斯就直接把官司砸了过来,一开口就...
ST岭南最新公告:公司及控股子... ST岭南(002717.SZ)公告称,截至2025年12月18日,公司及控股子公司连续十二个月内新增...
瑞茂通(600180)全资子公... 瑞茂通(600180)12月20日公告,公司旗下全资子公司河南瑞茂通粮油有限公司于近日收到河南省郑州...