当前位置: 首页 > news >正文

TensorFlow简单的线性回归任务

如何使用 TensorFlow 和 Keras 创建、训练并进行预测

1. 数据准备与预处理

2. 构建模型

3. 编译模型

4. 训练模型

5. 评估模型

6. 模型应用与预测

7. 保存与加载模型

8.完整代码


1. 数据准备与预处理

我们将使用一个简单的线性回归问题,其中输入特征 x 和标签 y 之间存在线性关系。我们创建一个训练数据集,并将标签设置为输入特征的两倍加上一些噪声。

import numpy as np
import tensorflow as tf# 创建训练数据,x 是输入特征,y 是标签(y = 2 * x + 噪声)
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=float)  # 输入数据
y = 2 * x + np.random.normal(0, 1, size=x.shape)  # 标签数据,加一些噪声

2. 构建模型

我们使用一个简单的神经网络来进行线性回归。这个网络只有一个全连接层,激活函数是线性的。

model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_dim=1, activation='linear')  # 线性激活函数
])

3. 编译模型

使用 SGD 优化器和均方误差损失函数,适合线性回归问题。

model.compile(optimizer='sgd', loss='mean_squared_error')

4. 训练模型

训练模型时,我们设置 1000 个训练周期,并传入数据 x 和标签 y

model.fit(x, y, epochs=1000)

5. 评估模型

训练结束后,我们评估模型的表现,使用 evaluate 函数来查看损失值。

loss = model.evaluate(x, y)
print(f"模型的损失值:{loss}")

6. 模型应用与预测

训练完成后,我们使用 model.predict() 来进行预测。你可以将新的输入数据传入模型,得到预测结果。

# 使用模型进行预测
new_x = np.array([11, 12, 13, 14, 15], dtype=float)
predictions = model.predict(new_x)print("新的输入数据预测结果:")
print(predictions)

7. 保存与加载模型

你还可以保存和加载训练好的模型,以便在未来使用。\

# 保存模型
model.save('linear_model.keras')# 加载模型
loaded_model = tf.keras.models.load_model('linear_model.keras')# 使用加载的模型进行预测
loaded_predictions = loaded_model.predict(new_x)
print("加载的模型预测结果:")
print(loaded_predictions)

8.完整代码

import numpy as np
import tensorflow as tf# 创建训练数据,x 是输入特征,y 是标签(y = 2 * x + 噪声)
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=float)
y = 2 * x + np.random.normal(0, 1, size=x.shape)# 构建模型
model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_dim=1, activation='linear')  # 线性激活函数
])# 编译模型
model.compile(optimizer='sgd', loss='mean_squared_error')# 训练模型
model.fit(x, y, epochs=1000)# 评估模型
loss = model.evaluate(x, y)
print(f"模型的损失值:{loss}")# 使用模型进行预测
new_x = np.array([11, 12, 13, 14, 15], dtype=float)
predictions = model.predict(new_x)print("新的输入数据预测结果:")
print(predictions)# 保存模型
model.save('linear_model.keras')# 加载模型
loaded_model = tf.keras.models.load_model('linear_model.keras')# 使用加载的模型进行预测
loaded_predictions = loaded_model.predict(new_x)
print("加载的模型预测结果:")
print(loaded_predictions)

 

http://www.lryc.cn/news/530476.html

相关文章:

  • 【视频+图文详解】HTML基础4-html标签的基本使用
  • 在Arm芯片苹果Mac系统上通过homebrew安装多版本mysql并解决各种报错,感谢deepseek帮助解决部分问题
  • c++可变参数详解
  • 【深度分析】DeepSeek 遭暴力破解,攻击 IP 均来自美国,造成影响有多大?有哪些好的防御措施?
  • CMake项目编译与开源项目目录结构
  • 完全卸载mysql server步骤
  • C#方法(练习)
  • Unity游戏(Assault空对地打击)开发(3) 摄像机的控制
  • ChatGPT-4o和ChatGPT-4o mini的差异点
  • SQL进阶实战技巧:某芯片工厂设备任务排产调度分析 | 间隙分析技术应用
  • 【力扣】438.找到字符串中所有字母异位词
  • 2024具身智能模型汇总:从训练数据、动作预测、训练方法到Robotics VLM、VLA
  • Day33【AI思考】-函数求导过程 的优质工具和网站
  • 【URL】一个简单基于Gym的2D随机游走环境,用于无监督强化学习(URL)
  • 【VM】VirtualBox安装ubuntu22.04虚拟机
  • MySQL的GROUP BY与COUNT()函数的使用问题
  • C# 精炼题18道题(类,三木运算,Switch,计算器)
  • 96,【4】 buuctf web [BJDCTF2020]EzPHP
  • 数据库 - Sqlserver - SQLEXPRESS、由Windows认证改为SQL Server Express认证进行连接 (sa登录)
  • 2025年02月02日Github流行趋势
  • 【数据分析】案例03:当当网近30日热销图书的数据采集与可视化分析(scrapy+openpyxl+matplotlib)
  • 如何使用 DeepSeek 和 Dexscreener 构建免费的 AI 加密交易机器人?
  • buu-jarvisoj_level0-好久不见30
  • 深度学习查漏补缺:1.梯度消失、梯度爆炸和残差块
  • 【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.2 多维数组切片:跨步访问与内存布局
  • ResNet--深度学习中的革命性网络架构
  • TypeScript语言的语法糖
  • 17.2 图形绘制4
  • tomcat核心组件及原理概述
  • 本地部署DeepSeek教程(Mac版本)