PyTorch 训练神经网络模型,并集成到springboot项目中
使用 PyTorch 训练一个 5 输入单输出的神经网络模型(含隐藏层),导出为 ONNX 格式,再用 Java 加载推理。
一、PyTorch 训练 5 输入神经网络并导出 ONNX
- 模型设计
输入:5 个特征(in_features=5)
网络结构:2 层隐藏层(带 ReLU 激活)+ 输出层(线性激活,适合回归任务)
任务:预测一个连续值(例如根据 5 个特征预测某个指标) - 完整代码(PyTorch 部分)
python
运行
import torch
import torch.nn as nn
import torch.optim as optim
1. 定义神经网络模型(5输入 → 隐藏层 → 输出)
class NeuralNetwork(nn.Module):
def init(self):
super().init()
# 隐藏层1:5→16,ReLU激活
self.layer1 = nn.Linear(in_features=5, out_features=16)
# 隐藏层2:16→8,ReLU激活
self.layer2 = nn.Linear(in_features=16, out_features=8)
# 输出层:8→1(单输出)
self.output_layer = nn.Linear(in_features=8, out_features=1)
self.relu = nn.ReLU() # 激活函数
def forward(self, x):x = self.relu(self.layer1(x)) # 第一层 + ReLUx = self.relu(self.layer2(x)) # 第二层 + ReLUx = self.output_layer(x) # 输出层(无激活,适合回归)r