当前位置: 首页 > news >正文

4.权重衰减(weight decay)

4.1 手动实现权重衰减

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w, b, num_inputs):
    """Generate a synthetic linear-regression dataset: y = X @ w + b + noise.

    Args:
        w: weight column vector of shape (d, 1).
        b: scalar bias.
        num_inputs: number of examples to generate (note: despite the
            name, this is the sample count, not the feature count).

    Returns:
        (X, y): X of shape (num_inputs, d), y of shape (num_inputs, 1),
        with Gaussian noise of std 0.1 added to the labels.
    """
    X = torch.normal(0, 1, size=(num_inputs, w.shape[0]))
    y = X @ w + b
    y += torch.normal(0, 0.1, size=y.shape)
    return X, y
def load_array(data, batch_size, is_train=True):
    """Wrap a tuple of tensors in a PyTorch DataLoader.

    Args:
        data: tuple of tensors (e.g. (X, y)) with equal first dimension.
        batch_size: minibatch size.
        is_train: shuffle each epoch when True (training); keep order
            when False (evaluation).

    Returns:
        A DataLoader yielding minibatches of the wrapped tensors.
    """
    dataset = TensorDataset(*data)
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
def init_params(num_inputs):
    """Initialize linear-model parameters with gradients enabled.

    Args:
        num_inputs: number of input features.

    Returns:
        [w, b]: weight of shape (num_inputs, 1) drawn from N(0, 1) and a
        zero bias of shape (1,), both with requires_grad=True.
    """
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]
def l2_penalty(w):
    """Return the L2 penalty (1/2) * sum(w^2) used for weight decay."""
    return 0.5 * torch.sum(w.pow(2))


def linear_reg(X, w, b):
    """Linear regression forward pass: X @ w + b."""
    return torch.matmul(X, w) + b
def mse_loss(y_hat, y):
    """Elementwise halved squared error; no reduction (same shape as input)."""
    diff = y_hat - y
    return diff ** 2 / 2
def sgd(params, lr, batch_size):
    """Minibatch SGD update: p <- p - lr * p.grad / batch_size, then zero grads.

    Args:
        params: iterable of tensors with populated .grad.
        lr: learning rate.
        batch_size: minibatch size used to average the summed gradient.

    Note: the original collapsed source reused the name `params` as the
    loop variable, shadowing the argument; fixed here.
    """
    with torch.no_grad():  # parameter updates must not be tracked by autograd
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()
def evaluate_loss(net, data_iter, loss):
    """Return the average per-sample loss of `net` over `data_iter`.

    Args:
        net: callable mapping a feature batch X to predictions.
        data_iter: iterable of (X, y) minibatches.
        loss: elementwise loss function; its output is summed per batch.

    Returns:
        Total loss divided by the total number of label elements.
    """
    total_loss, total_samples = 0.0, 0
    for X, y in data_iter:
        l = loss(net(X), y)
        total_loss += l.sum().item()
        total_samples += y.numel()
    return total_loss / total_samples
# Experiment setup: few training examples (20) vs many features (200)
# deliberately invites overfitting, which weight decay should curb.
n_train,n_test,num_inputs,batch_size=20,100,200,5
# Ground-truth model: all weights 0.01, bias 0.05.
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
train_data=synthetic_data(true_w,true_b,n_train)
test_data=synthetic_data(true_w,true_b,n_test)
train_iter=load_array(train_data,batch_size)
# Evaluation loader: no shuffling.
test_iter=load_array(test_data,batch_size,is_train=False)
w,b=init_params(num_inputs)
# Model is a plain linear regression closure over the parameters above.
net=lambda X:linear_reg(X,w,b)
loss=mse_loss
# lambd is the L2-penalty coefficient added manually in the training loop.
num_epochs,lr,lambd=10,0.05,3
#animator=SimpleAnimator()
# Training loop with a manually added L2 penalty (weight decay).
# NOTE(review): the original line was collapsed; the periodic evaluation is
# reconstructed at epoch level, the conventional placement — confirm intent.
for epoch in range(num_epochs):
    for X, y in train_iter:
        # Penalize only the weights, not the bias.
        l = loss(net(X), y) + lambd * l2_penalty(w)
        l.sum().backward()
        # sgd zeroes the gradients after each update.
        sgd([w, b], lr, batch_size)
    if (epoch + 1) % 5 == 0:
        train_loss = evaluate_loss(net, train_iter, loss)
        test_loss = evaluate_loss(net, test_iter, loss)
        # animator.add(epoch + 1, train_loss, test_loss)
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范数是:', torch.norm(w).item())
plt.show()

4.2 简单实现权重衰减

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w, b, num_inputs):
    """Generate a synthetic linear-regression dataset: y = X @ w + b + noise.

    Args:
        w: weight column vector of shape (d, 1).
        b: scalar bias.
        num_inputs: number of examples to generate (note: despite the
            name, this is the sample count, not the feature count).

    Returns:
        (X, y): X of shape (num_inputs, d), y of shape (num_inputs, 1),
        with Gaussian noise of std 0.1 added to the labels.
    """
    X = torch.normal(0, 1, size=(num_inputs, w.shape[0]))
    y = X @ w + b
    y += torch.normal(0, 0.1, size=y.shape)
    return X, y
def load_array(data, batch_size, is_train=True):
    """Wrap a tuple of tensors in a PyTorch DataLoader.

    Args:
        data: tuple of tensors (e.g. (X, y)) with equal first dimension.
        batch_size: minibatch size.
        is_train: shuffle each epoch when True; keep order when False.

    Returns:
        A DataLoader yielding minibatches of the wrapped tensors.
    """
    dataset = TensorDataset(*data)
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
def init_params(num_inputs):
    """Initialize linear-model parameters with gradients enabled.

    Args:
        num_inputs: number of input features.

    Returns:
        [w, b]: weight of shape (num_inputs, 1) drawn from N(0, 1) and a
        zero bias of shape (1,), both with requires_grad=True.
    """
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]
def l2_penalty(w):
    """Half the sum of squared weights — the classic L2 / weight-decay term."""
    return torch.sum(w ** 2) / 2
def linear_reg(X, w, b):
    """Linear model forward pass: X @ w + b."""
    return X @ w + b
def mse_loss(y_hat, y):
    """Halved squared error summed over the whole batch (returns a scalar)."""
    residual = y_hat - y
    return torch.sum(residual ** 2) / 2
def evaluate_loss(net, data_iter, loss):
    """Return the average loss of `net` over `data_iter`.

    Args:
        net: callable mapping a feature batch X to predictions.
        data_iter: iterable of (X, y) minibatches.
        loss: loss function returning a scalar per batch (batch sum here).

    Returns:
        Accumulated loss divided by the total number of label elements.
    """
    total_loss, total_samples = 0.0, 0
    for X, y in data_iter:
        l = loss(net(X), y)
        # NOTE(review): `loss` in this section already returns a batch *sum*;
        # multiplying by y.shape[0] weights each batch by its size, which
        # differs from section 4.1's plain sum. Preserved as written — verify
        # the intended averaging against the original article.
        total_loss += l.item() * y.shape[0]
        total_samples += y.numel()
    return total_loss / total_samples
# Same overfitting-prone setup as section 4.1: 20 train samples, 200 features.
n_train,n_test,num_inputs,batch_size=20,100,200,5
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
train_data=synthetic_data(true_w,true_b,n_train)
test_data=synthetic_data(true_w,true_b,n_test)
train_iter=load_array(train_data,batch_size)
test_iter=load_array(test_data,batch_size,is_train=False)
w,b=init_params(num_inputs)
net=lambda X:linear_reg(X,w,b)
loss=mse_loss
# NOTE(review): lambd=3 is declared but unused below — weight decay is set
# via the optimizer's weight_decay=0.001 instead; confirm which was intended.
num_epochs,lr,lambd=100,0.001,3
# Built-in weight decay: applied to BOTH w and b here (unlike 4.1, which
# penalized only w).
optimizer=torch.optim.SGD([w,b],lr=lr,weight_decay=0.001)
# Training loop using torch.optim.SGD's built-in weight_decay.
# NOTE(review): the original line was collapsed; the periodic evaluation is
# reconstructed at epoch level, the conventional placement — confirm intent.
for epoch in range(num_epochs):
    for X, y in train_iter:
        optimizer.zero_grad()
        l = loss(net(X), y)
        l.backward()
        # sgd([w, b], lr, batch_size)  # manual update replaced by optimizer
        optimizer.step()
    if (epoch + 1) % 5 == 0:
        train_loss = evaluate_loss(net, train_iter, loss)
        test_loss = evaluate_loss(net, test_iter, loss)
        # animator.add(epoch + 1, train_loss, test_loss)
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范数是:', torch.norm(w).item())
plt.show()
http://www.lryc.cn/news/582516.html

相关文章:

  • NumPy-随机数生成详解
  • 初识单例模式
  • 【网络安全】服务间身份认证与授权模式
  • 【Flutter】面试记录
  • Next.js 实战笔记 2.0:深入 App Router 高阶特性与布局解构
  • 算法训练营DAY29 第八章 贪心算法 part02
  • ubuntu 操作记录
  • Python语言+pytest框架+allure报告+log日志+yaml文件+mysql断言实现接口自动化框架
  • 机制、形式、周期、内容:算法备案抽检复审政策讲解
  • 探索下一代云存储技术:对象存储、文件存储与块存储的区别与选择
  • 光流 | 当前光流算法还存在哪些缺点及难题?
  • ReactNative【实战系列教程】我的小红书 4 -- 首页(含顶栏tab切换,横向滚动频道,频道编辑弹窗,瀑布流布局列表等)
  • 闲庭信步使用图像验证平台加速FPGA的开发:第五课——HSV转RGB的FPGA实现
  • Java连接Emqx实现订阅发布消息
  • 恒创科技:香港站群服务器做seo站群优化效果如何
  • ReactNative【实战】瀑布流布局列表(含图片自适应、点亮红心动画)
  • Rust DevOps框架管理实例
  • ffmpeg下编译tsan
  • iOS 性能测试工具全流程:主流工具实战对比与适用场景
  • cocos2dx3.x项目升级到xcode15以上的iconv与duplicate symbols报错问题
  • CSP-S模拟赛二总结(实际难度大于CSP-S)
  • 力扣 239 题:滑动窗口最大值的两种高效解法
  • Android kotlin 协程的详细使用指南
  • C++--AVL树
  • 微前端框架对比
  • (16)Java+Playwright自动化测试-iframe操作-监听事件和执行js脚本
  • 精益管理与数字化转型的融合:中小制造企业降本增效的双重引擎
  • Nexus zkVM 3.0 及未来:迈向模块化、分布式的零知识证明
  • 生成PDF文件(基于 iText PDF )
  • Android framework修改解决偶发开机时有两个launcher入口的情况