Machine Learning: Improving a Model (Code)

If a model is underfitting, the next step is to improve it.

Model improvement techniques and what they do (a few of them are sketched in code right after this list):

  • Add more layers: each layer potentially increases the learning capability of the model, since each layer can learn some new kind of pattern in the data. Adding more layers is often referred to as making your neural network deeper.
  • Add more hidden units: more hidden units per layer means a potential increase in the learning capability of the model. Adding more hidden units is often referred to as making your neural network wider.
  • Fit for longer (more epochs): your model might learn more if it gets more opportunities to look at the data.
  • Change the activation functions: some data just can't be fit with only straight lines; non-linear activation functions can help with this.
  • Change the learning rate: less model-specific, but still related. The learning rate of the optimizer decides how much the model should change its parameters at each step: too much and the model overcorrects, too little and it doesn't learn enough.
  • Change the loss function: less model-specific, but still important. Different problems require different loss functions; for example, a binary cross-entropy loss function won't work with a multi-class classification problem.
  • Use transfer learning: take a pretrained model from a problem domain similar to yours and adjust it to your own problem.
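
Before the full example, here is a minimal sketch (my own illustration, not from the original post) of how changing the learning rate, changing the loss function, and transfer learning typically look in PyTorch. The placeholder model, the lr value of 0.01, and the ResNet-18 head sizes are assumptions made only for illustration.

import torch
from torch import nn
import torchvision

model = nn.Linear(in_features=2, out_features=1)  # placeholder model, just for illustration

# Change the learning rate: a smaller lr makes each parameter update more cautious.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # e.g. 0.01 instead of 0.1

# Change the loss function: binary and multi-class problems need different losses.
binary_loss_fn = nn.BCEWithLogitsLoss()      # binary classification on raw logits
multiclass_loss_fn = nn.CrossEntropyLoss()   # multi-class classification on raw logits

# Use transfer learning: start from a pretrained backbone and replace its output layer.
pretrained = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
for param in pretrained.parameters():
    param.requires_grad = False              # freeze the pretrained weights
pretrained.fc = nn.Linear(in_features=512, out_features=2)  # new head for a 2-class problem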

The example below applies the first three techniques: more layers, more hidden units, and more training epochs.
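
The listing relies on accuracy_fn and on the X_train / y_train / X_test / y_test tensors prepared earlier in the series and not repeated here. As an assumption (the original post does not show it), a minimal accuracy_fn could look like this:

def accuracy_fn(y_true, y_pred):
    # Percentage of predictions that exactly match the labels
    correct = torch.eq(y_true, y_pred).sum().item()
    return (correct / len(y_pred)) * 100

With that in place, the example code is: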

import torch
from torch import nn

class CircleModelV1(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(in_features=2, out_features=10)
        self.layer_2 = nn.Linear(in_features=10, out_features=10)
        self.layer_3 = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        # Three linear layers stacked back to back (no activation functions in between)
        return self.layer_3(self.layer_2(self.layer_1(x)))

model_1 = CircleModelV1().to("cpu")
print(model_1)

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model_1.parameters(), lr=0.1)

torch.manual_seed(42)
epochs = 1000

# Data tensors prepared earlier in the series
X_train, y_train = X_train.to("cpu"), y_train.to("cpu")
X_test, y_test = X_test.to("cpu"), y_test.to("cpu")

for epoch in range(epochs):
    ### Training
    model_1.train()
    # 1. Forward pass
    y_logits = model_1(X_train).squeeze()
    y_pred = torch.round(torch.sigmoid(y_logits))  # logits -> probabilities -> prediction labels

    # 2. Calculate loss/accuracy
    loss = loss_fn(y_logits, y_train)
    acc = accuracy_fn(y_true=y_train, y_pred=y_pred)

    # 3. Optimizer zero grad
    optimizer.zero_grad()

    # 4. Loss backwards
    loss.backward()

    # 5. Optimizer step
    optimizer.step()

    ### Testing
    model_1.eval()
    with torch.inference_mode():
        # 1. Forward pass
        test_logits = model_1(X_test).squeeze()
        test_pred = torch.round(torch.sigmoid(test_logits))
        # 2. Calculate loss/accuracy
        test_loss = loss_fn(test_logits, y_test)
        test_acc = accuracy_fn(y_true=y_test, y_pred=test_pred)

    if epoch % 100 == 0:
        print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {acc:.2f}%")

The output is as follows:
CircleModelV1(
  (layer_1): Linear(in_features=2, out_features=10, bias=True)
  (layer_2): Linear(in_features=10, out_features=10, bias=True)
  (layer_3): Linear(in_features=10, out_features=1, bias=True)
)
Epoch: 0 | Loss: 0.69528, Accuracy: 51.38%
Epoch: 100 | Loss: 0.69325, Accuracy: 47.88%
Epoch: 200 | Loss: 0.69309, Accuracy: 49.88%
Epoch: 300 | Loss: 0.69303, Accuracy: 50.50%
Epoch: 400 | Loss: 0.69300, Accuracy: 51.38%
Epoch: 500 | Loss: 0.69299, Accuracy: 51.12%
Epoch: 600 | Loss: 0.69298, Accuracy: 51.50%
Epoch: 700 | Loss: 0.69298, Accuracy: 51.38%
Epoch: 800 | Loss: 0.69298, Accuracy: 51.50%
Epoch: 900 | Loss: 0.69298, Accuracy: 51.38%
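
The loss gets stuck around 0.693 (roughly ln 2, the binary cross-entropy of random guessing) and the accuracy hovers around 50%, so even with more layers, more hidden units, and more epochs the model does no better than a coin flip. The reason is that stacking nn.Linear layers with nothing in between still produces a single linear mapping, and circle-shaped data cannot be separated by a straight line. Following the "change the activation functions" item above, one possible next step (sketched here as an assumption; the class name CircleModelV2 is only an illustrative choice, not from this post) is to interleave a non-linearity such as ReLU:

class CircleModelV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(in_features=2, out_features=10)
        self.layer_2 = nn.Linear(in_features=10, out_features=10)
        self.layer_3 = nn.Linear(in_features=10, out_features=1)
        self.relu = nn.ReLU()  # non-linear activation

    def forward(self, x):
        # ReLU between the linear layers lets the model learn curved decision boundaries
        return self.layer_3(self.relu(self.layer_2(self.relu(self.layer_1(x)))))

model_2 = CircleModelV2().to("cpu")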

If you've read this far, how about giving it a like~
