当前位置: 首页 > news >正文

PyTorch-线性回归

已经进入大模微调的时代,但是学习pytorch,对后续学习rasa框架有一定帮助吧。

<!--  给出一系列的点作为线性回归的数据,使用numpy来存储这些点。 -->
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],[9.779], [6.182], [7.59], [2.167], [7.042],[10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],[3.366], [2.596], [2.53], [1.221], [2.827],[3.465], [1.65], [2.904], [1.3]], dtype=np.float32)<!--  转化tensor格式。 -->
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)<!--  这里的nn.Linear表示的是 y=w*x b,里面的两个参数都是1,表示的是x是1维,y也是1维。当然这里是可以根据你想要的输入输出维度来更改的。 -->
class linearRegression(nn.Module):def __init__(self):super(linearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # input and output is 1 dimensiondef forward(self, x):out = self.linear(x)return out
model = linearRegression()<!-- 定义loss和优化函数,这里使用的是最小二乘loss,之后我们做分类问题更多的使用的是cross entropy loss,交叉熵。优化函数使用的是随机梯度下降,注意需要将model的参数model.parameters()传进去让这个函数知道他要优化的参数是那些。 -->
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)<!-- 开始训练 -->
num_epochs = 1000
for epoch in range(num_epochs):inputs = Variable(x_train)target = Variable(y_train)# forwardout = model(inputs) # 前向传播loss = criterion(out, target) # 计算loss# backwardoptimizer.zero_grad() # 梯度归零loss.backward() # 反向传播optimizer.step() # 更新参数if (epoch 1) % 20 == 0:print(f'Epoch[{epoch+1}/{num_epochs}], loss: {loss.item():.6f}')<!--训练完成之后我们就可以开始测试模型了-->
model.eval()
predict = model(Variable(x_train))
predict = predict.data.numpy()<!-- 显示图例 -->
fig = plt.figure(figsize=(10, 5))
plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')plt.legend() 
plt.show()<!-- 保存模型 -->
torch.save(model.state_dict(), './linear.pth')

http://www.lryc.cn/news/300950.html

相关文章:

  • C++数据结构与算法——栈与队列
  • 掌上新闻随心播控,HarmonyOS SDK助力新浪新闻打造精致易用的资讯服务新体验
  • 2024年危险化学品经营单位主要负责人证模拟考试题库及危险化学品经营单位主要负责人理论考试试题
  • C/C++如何把指针所指向的指针设为空指针?
  • 第三节:基于 InternLM 和 LangChain 搭建你的知识库(课程笔记)
  • qt-C++笔记之打印所有发生的事件
  • pytorch 实现线性回归(深度学习)
  • [Doris] Doris的安装和部署 (二)
  • 【QT+QGIS跨平台编译】之三十五:【cairo+Qt跨平台编译】(一套代码、一套框架,跨平台编译)
  • MySQL(基础)
  • STM32F1 - 中断系统
  • 【Linux系统化学习】缓冲区
  • 基于BP算法的SAR成像matlab仿真
  • 【C++ STL】你真的了解string吗?浅谈string的底层实现
  • 17.3.1.3 灰度
  • 基于CAS操作的atomic原子类型
  • Rust HashMap详解及单词统计示例
  • 命令执行讲解和函数
  • 外包实在是太坑了,划水三年,感觉人都废了
  • 代码随想录算法训练营第19天
  • 树莓派5 EEPROM引导加载程序恢复镜像
  • 循序渐进-讲解Markdown进阶(Mermaid绘图)-附使用案例
  • 寒假作业2月6号
  • ChatGPT绘图指南:DALL.E3玩法大全(一)
  • 计算机服务器中了_locked勒索病毒怎么办?Encrypted勒索病毒解密数据恢复
  • VueCLI核心知识3:全局事件总线、消息订阅与发布
  • Redis中缓存问题
  • 数码管扫描显示-单片机通用模板
  • IDEA中的神仙插件——Smart Input (自动切换输入法)
  • shell编程:求稀疏数组中元素的和(下标不连续)