当前位置: 首页 > news >正文

深度学习模型预测值集中在某一个值

深度学习模型,训练过程中,经常遇到预测的结果集中在某个值,而且在学习的过程中会变,样例如下。

主要有如下解决方案

1、更换relu ->tanh 或者其他激活函数

2、更改随机种子,估计是没有初始化好,或者调整学习率试试

3、输入的数据没有标准化,考虑对输入的特征进行分bin or标准化处理

4、增加drop out ,增加 batch normal,增加正则等

class ModelBNDropout(nn.Module):
    def __init__(self, input_size, class_nums=2):
        super(ModelBNDropout, self).__init__()
        
        self.model=nn.Sequential() #序列化模块构造的神经网络
        
        # 第一层 
        self.model.add_module('linear1',nn.Linear(input_size, 1024 )) #卷积层
        self.model.add_module('relu1', nn.ReLU()) #激活函数使用了ReLu
        self.model.add_module('bnorm1', nn.BatchNorm1d(1024))
        self.model.add_module('drop1', nn.Dropout()) 
 

#             L1 = 0
#             L2 = 0
#             for name,param in model.named_parameters():
#                 if 'bias' not in name:
#                     L1 += torch.norm(param, p=1) * 1e-5
#                     L2 += torch.norm(param, p=2) * 1e-3

5、使用其他模型的参数,进行权重初始化

model = torch.load('data/ckpt_xxx.model')

model_drop.fc1 = model.fc1
model_drop.fc2 = model.fc2
model_drop.fc3 = model.fc3
model_drop.fc4 = model.fc4
model_drop.fc5 = model.fc5
 

6、设置初始化函数

# # for m in model_drop.modules():
# #     if isinstance(m, nn.Linear):
# #         print('before',m.weight)
# #         torch.nn.init.kaiming_uniform_(m.weight)
# #         print('after',m.weight)
# #         nn.init.normal_(m.weight, mean=0, std=1)
# #         nn.init.zeros_(m.bias)
# print(model_res)

7、模型是不是在输出的时候加了一层sigmoid激活函数

8、终极大法:

获取模型的子模块,进行预测,看看哪一步出现 or 在每一层layer打印日志,看看在哪个层出现数据集中的情况,修该对应的层的网络结构or激活函数。

model_drop = ModelDropOut(input_size, class_nums=2)
model_drop = model_drop.to(device)

for m in model_drop.modules():
    print(m)

http://www.lryc.cn/news/514388.html

相关文章:

  • Sqoop的使用
  • OpenGL ES 04 图片数据是怎么写入到对应纹理单元的
  • C# 设计模式的六大原则(SOLID)
  • 数据库自增 id 过大导致前端时数据丢失
  • 第二十六天 自然语言处理(NLP)词嵌入(Word2Vec、GloVe)
  • MongoDB 固定集合
  • 数据结构9.3 - 文件基础(C++)
  • Leetcode 1254 Number of Closed Islands + Leetcode 1020 Number of Enclaves
  • Junit4单元测试快速上手
  • U盘提示格式化?原因、恢复方案与预防措施全解析
  • HTML——13.超链接
  • vue中的设计模式
  • 利用python将图片转换为pdf格式的多种方法,实现批量转换,内置模板代码,全网最全,超详细!!!
  • tcpdump的常见方法
  • 工控主板ESM7000/6800E支持远程桌面控制
  • wamp php7.4 运行dm8
  • HTML5 进度条(Progress Bar)详解
  • LabVIEW开发中常见硬件通讯接口快速识别
  • 高频 SQL 50 题(基础版)_1068. 产品销售分析 I
  • 笔记:一次mysql主从复制延迟高的处理尝试
  • 基于C语言的卡丁车管理系统【控制台应用程序】
  • Docker 搭建 Gogs
  • PostgreSQL的备份方式
  • Springboot 3项目整合Knife4j接口文档(接口分组详细教程)
  • 深入解析 Conda 安装的默认依赖包及其作用:conda create安装了哪些包(中英双语)
  • Redis核心技术知识点全集
  • 【Unity3D】ECS入门学习(九)SystemBase
  • 【Triton-ONNX】如何使用 ONNX 模型服务与 Triton 通信执行推理任务上-Triton快速开始
  • CertiK《Hack3d:2024年度安全报告》(附报告全文链接)
  • TIOBE 指数 12 月排行榜公布,VB.Net排行第九