当前位置: 首页 > news >正文

Pytorch02 神经网路搭建步骤

文章目录

import numpy as np
import torch
from PIL.Image import Image
from torch.autograd import Variable# 获取数据
def get_data():train_X=np.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,7.042,10.791,5.313,7.997,5.654,9.27,3.1])train_Y=np.asarray([1.7,2.76,2.09,3.19,1.694,1.537,3.366,2.596,2.53,1.221,2.827,3.465,1.65,2.904,2.42,2.94,1.3])dtype=torch.FloatTensorX=Variable(torch.from_numpy(train_X).type(dtype),requires_grad=False).view(17,1)y=Variable(torch.from_numpy(train_Y).type(dtype),requires_grad=False)return X,y
# 随机参数
def get_weights():w=Variable(torch.randn(1),requires_grad=True)b=Variable(torch.randn(1),requires_grad=True)return w,b
w,b=get_weights()
# 模型计算
def simple_network(x):y_pred=torch.matmul(x,w)+breturn y_pred
# 计算损失进行评估
def loss_fn(y,y_pred):loss=(y_pred-y).pow(2).sum()for param in [w,b]:if not param.grad is None:param.grad.data.zero_()loss.backward()return loss.data[0]
# 优化网络
def optimize(learning_rate):w.data-=learning_rate * w.grad.datab.data-=learning_rate * b.grad.data
from torch.utils.data import Dataset
class DogsAndCatsDataset(Dataset):def __init__(self,root_dir,size=(224,224)):self.files=globals(root_dir)self.size=sizedef __len__(self):return len(self.files)def __getitem__(self, item):img=np.asarray(Image.open(self.files[item]).resize(self.size))label=self.files[item].split('/')[-2]return img,label
class myFirstNetwork(torch.nn.Module):def __init__(self,input_size,hidden_size,output_size):super(myFirstNetwork,self).__init__()self.layer1=torch.nn.Linear(input_size,hidden_size)self.layer2=torch.nn.Linear(hidden_size,output_size)def __forward__(self,input):out=self.layer1(input)out=torch.nn.ReLU(out)out=self.layer2(out)return out
http://www.lryc.cn/news/164232.html

相关文章:

  • 【源码】JavaWeb+Mysql招聘管理系统 课设
  • Java中级编程大师班<第一篇:初识数据结构与算法-数组(2)>
  • 杰哥教你面试之一百问系列:java集合
  • 【数据结构】树和二叉树概念
  • C盘清理教程
  • 【实战-05】 flinksql look up join
  • C++数据结构--红黑树
  • Linux perf使用思考
  • 自定义路由断言工厂
  • Nacos安装及在项目中的使用
  • overleaf中latex语法总结
  • Grafana配置邮件告警
  • setup中的nextTick函数
  • Matlab信号处理3:fft(快速傅里叶变换)标准使用方式
  • Python|合并两个字典的几种方法
  • ElementUI浅尝辄止24:Message 消息提示
  • 让照片动起来的软件,轻松制作照片动效
  • 【图解RabbitMQ-7】图解RabbitMQ五种队列模型(简单模型、工作模型、发布订阅模型、路由模型、主题模型)及代码实现
  • Linux命令200例:write用于向特定用户或特定终端发送信息
  • javaee spring整合mybatis spring帮我们创建dao层
  • 修改Tomcat的默认端口号
  • Open3D Ransac拟合空间直线(python详细过程版)
  • 题目:2729.判断一个数是否迷人
  • 微服务模式:服务发现模式
  • 9.4 数据库 TCP
  • 普通用户使用spark的client无法更新Ranger策略
  • Git超详细教程
  • C++ 回调函数
  • xilinx FPGA IOB约束使用以及注意事项
  • 如何统计iOS产品不同渠道的下载量?