当前位置: 首页 > news >正文

模型选择拟合

1.通过多项式拟合交互探索概念

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

2.使用三阶多项式来生成训练和测试数据的标签

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)

3.查看样本

true_w, features, poly_features, labels = [torch.tensor(x, dtype=torch.float32) for x in [true_w, features, poly_features, labels]]features[:2], poly_features[:2, :], labels[:2]

4.实现函数来评估模型在给定数据集的损失

def evaluate_loss(net, data_iter, loss):"""评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)for X, y in data_iter:out = net(X)y = y.reshape(out.shape)l = loss(out, y)metric.add(l.sum(), l.numel())return metric[0] / metric[1]

5.定义训练函数

def train(train_features, test_features, train_labels, test_labels,num_epochs=400):loss = nn.MSELoss()input_shape = train_features.shape[-1]net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])for epoch in range(num_epochs):d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())

http://www.lryc.cn/news/469004.html

相关文章:

  • 文案语音图片视频管理分析系统-视频矩阵
  • ArcGIS计算落入面图层中的线的长度或面的面积
  • ctfshow-web入门-web172
  • Keep健身TV版 3.3.0 | 针对智能电视的健身塑形软件
  • 推荐一些关于计算机网络和 TCP/IP 协议的书籍
  • 生成式AI浪潮下的商业机遇与经济展望 —— 与互联网时代的比较
  • Go 标准库
  • AUTOSAR_EXP_ARAComAPI的6章笔记(5)
  • Photoshop中的混合模式公式详解
  • Vue 自定义指令 Directive 的高级使用与最佳实践
  • 万字图文实战:从0到1构建 UniApp + Vue3 + TypeScript 移动端跨平台开源脚手架
  • 在WebStorm遇到Error: error:0308010C:digital envelope routines::unsupported报错时的解决方案
  • 数据库产品中SQL注入防护功能应该包含哪些功能
  • Ribbon客户端负载均衡策略测试及其改进
  • linux网络编程5——Posix API和网络协议栈,使用TCP实现P2P通信
  • 低代码平台中的功能驱动开发:模块化与领域设计
  • HTTP和HTTPS基本概念,主要区别,应用场景
  • node.js使用Sequelize ORM操作数据库
  • STM32-Modbus协议(一文通)
  • 100. 不同方向的投影视图
  • Appium中的api(三)
  • 踩坑:关于使用ceph pg repair引发的业务阻塞
  • 瞬间升级!电子文档华丽变身在线题库,效率翻倍✨
  • 如何动态改变本地的ip
  • Spring Boot框架在中小企业设备管理中的创新应用
  • Ceph入门到精通-Osd db扩容
  • windows msvc2017 x64编译AWS SDK CPP库
  • 铜业机器人剥片 - SNK施努卡
  • 非接触式竖向位移、水平位移视频实时在线监测的设备分类及选型
  • Svelte 5 正式发布:新一代前端框架!