当前位置: 首页 > news >正文

2.线性神经网络--Softmax回归

2.1 从零实现Softmax回归

#数据集导入
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
###################################################################################################################
def get_fashion_mnist_labels(labels):  #@save
    """Map a Fashion-MNIST class index (0-9) to its English text label."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return text_labels[labels]
def draw_fashion_mnist(row, col, num_fig, img_data):
    """Show the first num_fig (image, label) pairs of img_data on a row x col grid."""
    fig, axes = plt.subplots(row, col, figsize=(row * col, 4))
    for idx in range(num_fig):
        img, label = img_data[idx]
        r, c = divmod(idx, col)          # grid position, filled row by row
        ax = axes[r][c]
        # drop the leading channel dim (C,H,W) -> (H,W) before imshow
        ax.imshow(img.squeeze(0).numpy(), cmap='coolwarm')
        ax.set_title(f"{get_fashion_mnist_labels(label)}")
        ax.axis('off')
    plt.tight_layout()
    plt.show()
#softmax
def Softmax(X):
    """Row-wise softmax of a 2-D score matrix.

    Subtracts each row's max before exponentiating — the standard
    numerical-stability trick: exp(x - m) cannot overflow, and the
    shift cancels in the ratio, so the result is unchanged.
    """
    X_stable = X - X.max(dim=1, keepdim=True).values
    X_exp = torch.exp(X_stable)
    sum_exp = X_exp.sum(dim=1, keepdim=True)  # per-row normaliser
    return X_exp / sum_exp

def softmax_model(X, w, b):
    """Flatten X to (batch, num_inputs), apply the linear layer, return class probabilities."""
    y_hat = torch.matmul(X.reshape((-1, w.shape[0])), w) + b
    return Softmax(y_hat)
#-log(pi)
def crossentropyloss(y_hat, y):
    """Mean negative log-likelihood of the true classes.

    y_hat: (batch, classes) probabilities; y: (batch,) integer labels.
    """
    picked = y_hat[torch.arange(len(y_hat)), y]  # probability of each true class
    return -torch.log(picked).mean()
def accuray_score(y_hat, y):
    """Count (not rate) of correct top-1 predictions in the batch.

    NOTE: the name keeps the original spelling since callers use it.
    """
    predictions = y_hat.argmax(axis=1)
    correct = predictions == y
    return correct.float().sum().item()
def sgd(params, lr, batch_size):
    """One in-place minibatch SGD step; each grad is zeroed after use."""
    with torch.no_grad():  # parameter updates must not be tracked by autograd
        for p in params:
            p -= lr * p.grad / batch_size
            p.grad.zero_()
###################################################################################################################
# NOTE(review): the original rebound the name `transforms`, shadowing the
# torchvision.transforms module after this line; use a distinct name.
transform_pipeline = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # first tuple is mean, second is std
])
train_img = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=transform_pipeline, download=True)
test_img = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=transform_pipeline, download=True)
###################################################################################################################
# Wrap the datasets in DataLoaders; only the training set is reshuffled each epoch.
train_data=DataLoader(train_img,batch_size=200,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=200,num_workers=4)
#draw_fashion_mnist(row=2,col=9,num_fig=18,img_data=train_img)
#draw_fashion_mnist(row=2,col=9,num_fig=18,img_data=test_img)
# Parameter initialisation
num_inputs=784  # 28*28 pixels flattened into one vector
num_outputs=10#num_class=10
lr=0.1
num_epochs=10
# Weights drawn from N(0, 0.1); requires_grad so autograd tracks them for SGD.
w=torch.normal(0,0.1,size=(num_inputs,num_outputs),requires_grad=True)
b=torch.zeros(num_outputs,requires_grad=True)

for epoch in range(num_epochs):
    # ---- training pass ----
    total_loss = 0
    total_acc_sample = 0
    total_samples = 0
    loop = tqdm(train_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
    for X, y in loop:
        X = X.reshape(X.shape[0], -1)       # flatten (B,1,28,28) -> (B,784)
        y_hat = softmax_model(X, w, b)
        loss = crossentropyloss(y_hat, y)
        loss.backward()
        sgd([w, b], lr, batch_size=X.shape[0])  # sgd also zeroes the grads
        # accumulate per-sample loss / correct-prediction counts
        total_loss += loss.item() * X.shape[0]
        total_acc_sample += accuray_score(y_hat, y)
        total_samples += X.shape[0]

    # ---- evaluation pass ----
    # FIX(review): wrapped in no_grad — the original built autograd graphs
    # for every test batch, wasting memory/compute for nothing.
    test_acc_samples = 0
    test_samples = 0
    with torch.no_grad():
        for X, y in test_data:
            X = X.reshape(X.shape[0], -1)
            y_hat = softmax_model(X, w, b)
            test_acc_samples += accuray_score(y_hat, y)
            test_samples += X.shape[0]
    # FIX(review): "Trian" -> "Train" in the report string.
    print(f"Epoch {epoch+1}: Loss: {total_loss/total_samples:.4f},"
          f"Train Accuracy: {total_acc_sample/total_samples:.4f},"
          f"test Accuracy: {test_acc_samples/test_samples:.4f}")
###################################################################################################################

2.2 简洁实现Softmax回归

#数据集导入
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
###################################################################################################################
# NOTE(review): the original rebound the name `transforms`, shadowing the
# torchvision.transforms module after this line; use a distinct name.
transform_pipeline = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # first tuple is mean, second is std
])
train_img = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=transform_pipeline, download=True)
test_img = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=transform_pipeline, download=True)
###################################################################################################################
# Wrap the datasets in DataLoaders; only the training set is reshuffled each epoch.
train_data = DataLoader(train_img, batch_size=200, num_workers=4, shuffle=True)
test_data = DataLoader(test_img, batch_size=200, num_workers=4)
#draw_fashion_mnist(row=2,col=9,num_fig=18,img_data=train_img)
#draw_fashion_mnist(row=2,col=9,num_fig=18,img_data=test_img)
# Hyper-parameter initialisation
num_inputs = 784   # 28*28 pixels flattened into one vector
num_outputs = 10   # number of classes
lr = 0.01
num_epochs = 10
# FIX(review): removed the dead manual w/b tensors that the scratch version
# used — in this concise version nn.Linear owns and initialises its own
# parameters, so those tensors were never read.
model=nn.Sequential(nn.Flatten(),nn.Linear(num_inputs,num_outputs))
optimizer=torch.optim.SGD(model.parameters(),lr=lr)
# nn.CrossEntropyLoss applies log-softmax internally, so the model emits raw logits.
CEloss=nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    # ---- training pass ----
    total_loss = 0
    total_acc_sample = 0
    total_samples = 0
    loop = tqdm(train_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
    for X, y in loop:
        X = X.reshape(X.shape[0], -1)  # redundant with nn.Flatten, but harmless
        y_hat = model(X)
        loss = CEloss(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * X.shape[0]
        y_pred = y_hat.argmax(dim=1).detach().cpu().numpy()
        y_true = y.detach().cpu().numpy()
        # FIX(review): sklearn's signature is accuracy_score(y_true, y_pred);
        # the original passed them swapped (same value for accuracy, but
        # misleading and wrong for any asymmetric metric).
        total_acc_sample += accuracy_score(y_true, y_pred) * X.shape[0]
        total_samples += X.shape[0]

    # ---- evaluation pass ----
    # FIX(review): wrapped in no_grad — the original built autograd graphs
    # for every test batch, wasting memory/compute for nothing.
    test_acc_samples = 0
    test_samples = 0
    with torch.no_grad():
        for X, y in test_data:
            X = X.reshape(X.shape[0], -1)
            y_hat = model(X)
            y_pred = y_hat.argmax(dim=1).cpu().numpy()
            y_true = y.cpu().numpy()
            test_acc_samples += accuracy_score(y_true, y_pred) * X.shape[0]
            test_samples += X.shape[0]
    # FIX(review): "Trian" -> "Train" in the report string.
    print(f"Epoch {epoch+1}: Loss: {total_loss/total_samples:.4f},"
          f"Train Accuracy: {total_acc_sample/total_samples:.4f},"
          f"test Accuracy: {test_acc_samples/test_samples:.4f}")
###################################################################################################################
http://www.lryc.cn/news/581661.html

相关文章:

  • 算法分析与设计实验1:实现两路合并排序和折半插入排序
  • 3.8 java连接数据库
  • Vue2 day07
  • 工业相机和镜头
  • 基于Java+SpringBoot的医院信息管理系统
  • ARM 学习笔记(一)
  • 文心开源大模型ERNIE-4.5-0.3B-Paddle私有化部署保姆级教程及技术架构探索
  • 【学习笔记】4.1 什么是 LLM
  • 编程语言艺术:C语言中的属性attribute笔记总结
  • 程序员在线接单
  • 浅谈漏洞扫描与工具
  • 大型语言模型中的自动化思维链提示
  • 【数据分析】R语言多源数据的基线特征汇总
  • 玄机——第三章 权限维持-linux权限维持-隐藏练习
  • Dify+Ollama+QwQ:3步本地部署,开启AI搜索新篇章
  • 实现Spring MVC登录验证与拦截器保护:从原理到实战
  • 【机器学习深度学习】 如何解决“宏平均偏低 / 小类识别差”的问题?
  • HRDNet: High-resolution Detection Network for Small Objects论文阅读
  • mac中创建 .command 文件,执行node服务
  • Omi录屏专家 Screen Recorder by Omi 屏幕录制Mac
  • 【Linux】基础开发工具(1)
  • 开发项目时遇到的横向越权、行锁表锁与事务的关联与区别、超卖问题
  • Java学习——Lombok
  • Anaconda 常用命令
  • 【Elasticsearch】自定义评分检索
  • 【卫星语音】基于神经网络的低码率语音编解码(ULBC)方案架构分析:以SoundStream为例
  • Maven引入第三方JAR包实战指南
  • Day06- (使用asyncio进行异步编程:事件循环和协程)
  • 群晖 DS3617xs DSM 6.1.7 解决 PhotoStation 安装失败问题 PHP7.0
  • 数据结构---B+树