当前位置: 首页 > news >正文

TensorFlow与Pytorch的转换——1简单线性回归

import numpy as np# 生成随机数据
# 生成随机数据
x_train = np.random.rand(100000).astype(np.float32)
y_train = 0.5 * x_train + 2 import tensorflow as tf# 定义模型
W = tf.Variable(tf.random.normal([1]))
b = tf.Variable(tf.zeros([1]))
y = W * x_train + b
# 定义损失函数
loss = tf.reduce_mean(tf.square(y - y_train))
# 定义优化器
optimizer = tf.optimizers.SGD(0.5)
# 训练模型
for i in range(100):with tf.GradientTape() as tape:y = W * x_train + bloss = tf.reduce_mean(tf.square(y - y_train))gradients = tape.gradient(loss, [W, b])optimizer.apply_gradients(zip(gradients, [W, b]))if (i+1) % 50 == 0:print("Epoch [{}/{}], loss: {:.3f}, W: {:.3f}, b: {:.3f}".format(i+1, 1000, loss.numpy(), W.numpy()[0], b.numpy()[0]))# 预测新数据
x_test = np.array([0.1, 0.2, 0.3], dtype=np.float32)
y_pred = W * x_test + b
print("Predictions:", y_pred.numpy())
import matplotlib.pyplot as plt# 绘制结果
plt.scatter(x_train, y_train)
plt.plot(x_train, W * x_train + b, c='r')
plt.show()

Pytorch

import torch
import numpy as np
import matplotlib.pyplot as plt# 生成随机数据
x_train = torch.from_numpy(np.random.rand(100000).astype(np.float32))
y_train = 0.5 * x_train + 2# 定义模型参数
W = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)# 定义损失函数
loss_fn = torch.nn.MSELoss()# 定义优化器
optimizer = torch.optim.SGD([W, b], lr=0.5)# 训练模型
for i in range(100):y = W * x_train + bloss = loss_fn(y, y_train)optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 50 == 0:print(f"Epoch [{i + 1}/{100}], loss: {loss.item():.3f}, W: {W.item():.3f}, b: {b.item():.3f}")# 预测新数据
x_test = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32)
y_pred = W * x_test + b
print("Predictions:", y_pred.detach().numpy())# 绘制结果
plt.scatter(x_train.numpy(), y_train.numpy())
plt.plot(x_train.numpy(), (W * x_train + b).detach().numpy(), c='r')
plt.show()
http://www.lryc.cn/news/457924.html

相关文章:

  • 短剧小程序短剧APP在线追剧APP网剧推广分销微短剧小剧场小程序集师知识付费集师短剧小程序集师小剧场小程序集师在线追剧小程序源码
  • AI与物理学的交汇:Hinton与Hopfield获诺贝尔物理学奖
  • 六西格玛设计DFSS方法论在消费级无人机设计中的应用——张驰咨询
  • 按分类调用标签 调用指定分类下的TAG
  • 报错 - llama-index pydantic error | arbitrary_types_allowed | PydanticUserError
  • PostgreSQL Docker Error – 5432: 地址已被占用
  • 【LeetCode】动态规划—646. 最长数对链(附完整Python/C++代码)
  • 数字媒体产业园区:创新资源集聚,助力企业成长
  • 【Linux】来查看当前系统的架构
  • QT中的信号槽
  • 域名怎么转让给别人?
  • 计算机网络思维导图
  • 07.useDefault
  • git更加详细和灵活的提交过程,附带如何配置. gitignore来忽略部分文件的提交。
  • 使用正则表达式删除文本的奇数行或者偶数行
  • YOLOv10改进策略【注意力机制篇】| CVPR2024 CAA上下文锚点注意力机制
  • Unity修改鼠标图片【超简单】
  • windows C++-创建数据流代理(三)
  • C语言学习-循环嵌套打印字母金字塔
  • 探索CI/CD:持续集成与持续部署的基本概念
  • 大厂面试真题:说一说CMS和G1
  • 使用Qt Creator创建项目
  • C++ 与 C 的那些事儿:深度剖析两者区别
  • 学习​Redis 高可用性​
  • 【含开题报告+文档+PPT+源码】基于springBoot+vue超市仓库管理系统的设计与实现
  • 美发店管理革新:SpringBoot系统的应用
  • C++从0到1
  • VMware Tools 安装和配置
  • 云原生化 - 基础镜像(简约版)
  • 云计算相关