当前位置: 首页 > news >正文

RNN实现阿尔茨海默症的诊断识别

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

 一 导入数据

import torch.nn as nn
import torch.nn.functional as F
import torchvision,torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
import warnings
warnings.filterwarnings('ignore')# 设置硬件设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")df = pd.read_excel('dia.xls')
df

二 数据处理分析

# 删除第一列和最后一列
df = df.iloc[:,1:-1]
print(df)
 Age  Gender  Ethnicity  EducationLevel        BMI  Smoking   
0      73       0          0               2  22.927749        0  \
1      89       0          0               0  26.827681        0   
2      73       0          3               1  17.795882        0   
3      74       1          0               1  33.800817        1   
4      89       0          0               0  20.716974        0   
...   ...     ...        ...             ...        ...      ...   
2144   61       0          0               1  39.121757        0   
2145   75       0          0               2  17.857903        0   
2146   77       0          0               1  15.476479        0   
2147   78       1          3               1  15.299911        0   
2148   72       0          0               2  33.289738        0   AlcoholConsumption  PhysicalActivity  DietQuality  SleepQuality  ...   
0              13.297218          6.327112     1.347214      9.025679  ...  \
1               4.542524          7.619885     0.518767      7.151293  ...   
2              19.555085          7.844988     1.826335      9.673574  ...   
3              12.209266          8.428001     7.435604      8.392554  ...   
4              18.454356          6.310461     0.795498      5.597238  ...   
...                  ...               ...          ...           ...  ...   
2144            1.561126          4.049964     6.555306      7.535540  ...   
2145           18.767261          1.360667     2.904662      8.555256  ...   
2146            4.594670          9.886002     8.120025      5.769464  ...   
2147            8.674505          6.354282     1.263427      8.322874  ...   
2148            7.890703          6.570993     7.941404      9.878711  ...   FunctionalAssessment  MemoryComplaints  BehavioralProblems       ADL   
0                 6.518877                 0                   0  1.725883  \
1                 7.118696                 0                   0  2.592424   
2                 5.895077                 0                   0  7.119548   
3                 8.965106                 0                   1  6.481226   
4                 6.045039                 0                   0  0.014691   
...                    ...               ...                 ...       ...   
2144              0.238667                 0                   0  4.492838   
2145              8.687480                 0                   1  9.204952   
2146              1.972137                 0                   0  5.036334   
2147              5.173891                 0                   0  3.785399   
2148              6.307543                 0                   1  8.327563   Confusion  Disorientation  PersonalityChanges   
0             0               0                   0  \
1             0               0                   0   
2             0               1                   0   
3             0               0                   0   
4             0               0                   1   
...         ...             ...                 ...   
2144          1               0                   0   
2145          0               0                   0   
2146          0               0                   0   
2147          0               0                   0   
2148          0               1                   0   DifficultyCompletingTasks  Forgetfulness  Diagnosis  
0                             1              0          0  
1                             0              1          0  
2                             1              0          0  
3                             0              0          0  
4                             1              0          0  
...                         ...            ...        ...  
2144                          0              0          1  
2145                          0              0          1  
2146                          0              0          1  
2147                          0              1          1  
2148                          0              1          0  [2149 rows x 33 columns]

三 探索性数据分析 

1.得病分布

res = df.groupby('Diabetes')['Age'].count()
print(res)plt.figure(figsize=(8, 6))
plt.pie(res.values, labels=res.index, autopct='%1.1f%%', startangle=90,colors=['#ff9999','#66b3ff','#99ff99'], explode=(0.1,  0),wedgeprops={'edgecolor': 'black', 'linewidth': 1, 'linestyle': 'solid'})
plt.title('是否得阿尔茨海默症', fontsize=16, fontweight='bold')
plt.show()

 2.BMI分布直方图

# BMI分布直方图
sns.displot(df['BMI'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('BMI Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('BMI', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

3.年龄分布直方图 

# Age分布直方图
sns.displot(df['Age'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('Age Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('Age', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

 

四 构建划分数据集 

X = df.iloc[:,:-1]
y = df.iloc[:,-1]sc = StandardScaler()
X = sc.fit_transform(X)# 划分数据集
X = torch.tensor(np.array(X),dtype=torch.float32)
y = torch.tensor(np.array(y),dtype=torch.int64)X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=1)# 构建数据加载器
train_dl = DataLoader(TensorDataset(X_train,y_train),batch_size=64,shuffle=False)
test_dl = DataLoader(TensorDataset(X_test,y_test),batch_size=64,shuffle=False)

 五 训练模型

1.构建模型

# 构建模型
class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32,hidden_size=200,num_layers=1,batch_first=True)self.fc0 = nn.Linear(200,50)self.fc1 = nn.Linear(50,2)def forward(self,x):out , hidden1 = self.rnn0(x)out = self.fc0(out)out = self.fc1(out)return outmodel = model_rnn().to(device)
print(model)
model_rnn((rnn0): RNN(32, 200, batch_first=True)(fc0): Linear(in_features=200, out_features=50, bias=True)(fc1): Linear(in_features=50, out_features=2, bias=True)
)

2.训练函数 

'''
训练模型
'''
# 训练循环
def train(dataloader,model,loss_fn,optimizer):size = len(dataloader.dataset)   # 训练集的大小num_batches = len(dataloader)      # 批次数目,(size/batchsize,向上取整)train_acc,train_loss = 0,0  # 初始化训练损失和正确率for x,y in dataloader:    # 获取数据X,y = x.to(device),y.to(device)# 计算预测误差pred = model(X)   # 网络输出loss = loss_fn(pred,y)   # 计算误差# 反向传播optimizer.zero_grad()    # grad属性归零loss.backward()   # 反向传播optimizer.step()   # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc,train_loss

3.测试函数 

# 测试循环
def valid(dataloader,model,loss_fn):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)  # 批次数目,(size/batchsize,向上取整)test_loss, test_acc = 0, 0  # 初始化训练损失和正确率# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs,target in dataloader:imgs,target = imgs.to(device),target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred,target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc,test_loss

4.正式训练 

loss_fn = nn.CrossEntropyLoss()   # 创建损失函数
learn_rate = 1e-4   # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)model.eval()epoch_test_acc,epoch_test_loss = valid(test_dl,model,loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))print("="*20,'Done',"="*20)
Epoch: 1,Train_acc:52.9%,Train_loss:0.688,Test_acc:67.0%,Test_loss:0.658,lr:1.00E-04
Epoch: 2,Train_acc:68.7%,Train_loss:0.612,Test_acc:67.4%,Test_loss:0.600,lr:1.00E-04
Epoch: 3,Train_acc:68.7%,Train_loss:0.566,Test_acc:70.7%,Test_loss:0.567,lr:1.00E-04
Epoch: 4,Train_acc:74.4%,Train_loss:0.526,Test_acc:72.6%,Test_loss:0.533,lr:1.00E-04
Epoch: 5,Train_acc:77.9%,Train_loss:0.487,Test_acc:78.1%,Test_loss:0.501,lr:1.00E-04
Epoch: 6,Train_acc:81.1%,Train_loss:0.451,Test_acc:79.5%,Test_loss:0.473,lr:1.00E-04
Epoch: 7,Train_acc:82.3%,Train_loss:0.421,Test_acc:80.0%,Test_loss:0.451,lr:1.00E-04
Epoch: 8,Train_acc:83.4%,Train_loss:0.397,Test_acc:78.6%,Test_loss:0.434,lr:1.00E-04
Epoch: 9,Train_acc:84.7%,Train_loss:0.378,Test_acc:80.0%,Test_loss:0.422,lr:1.00E-04
Epoch:10,Train_acc:85.2%,Train_loss:0.365,Test_acc:80.0%,Test_loss:0.414,lr:1.00E-04
Epoch:11,Train_acc:85.6%,Train_loss:0.354,Test_acc:80.0%,Test_loss:0.408,lr:1.00E-04
Epoch:12,Train_acc:85.9%,Train_loss:0.347,Test_acc:80.0%,Test_loss:0.405,lr:1.00E-04
Epoch:13,Train_acc:86.3%,Train_loss:0.341,Test_acc:78.6%,Test_loss:0.403,lr:1.00E-04
Epoch:14,Train_acc:87.0%,Train_loss:0.335,Test_acc:78.1%,Test_loss:0.403,lr:1.00E-04
Epoch:15,Train_acc:87.1%,Train_loss:0.331,Test_acc:78.6%,Test_loss:0.404,lr:1.00E-04
Epoch:16,Train_acc:87.1%,Train_loss:0.327,Test_acc:78.1%,Test_loss:0.405,lr:1.00E-04
Epoch:17,Train_acc:87.1%,Train_loss:0.324,Test_acc:78.6%,Test_loss:0.407,lr:1.00E-04
Epoch:18,Train_acc:87.3%,Train_loss:0.321,Test_acc:78.6%,Test_loss:0.409,lr:1.00E-04
Epoch:19,Train_acc:87.4%,Train_loss:0.318,Test_acc:77.7%,Test_loss:0.412,lr:1.00E-04
Epoch:20,Train_acc:87.7%,Train_loss:0.315,Test_acc:78.1%,Test_loss:0.415,lr:1.00E-04
Epoch:21,Train_acc:87.8%,Train_loss:0.312,Test_acc:77.7%,Test_loss:0.418,lr:1.00E-04
Epoch:22,Train_acc:88.1%,Train_loss:0.309,Test_acc:78.1%,Test_loss:0.422,lr:1.00E-04
Epoch:23,Train_acc:88.6%,Train_loss:0.306,Test_acc:78.1%,Test_loss:0.425,lr:1.00E-04
Epoch:24,Train_acc:88.6%,Train_loss:0.303,Test_acc:79.1%,Test_loss:0.429,lr:1.00E-04
Epoch:25,Train_acc:88.6%,Train_loss:0.301,Test_acc:79.5%,Test_loss:0.433,lr:1.00E-04
Epoch:26,Train_acc:88.6%,Train_loss:0.298,Test_acc:79.5%,Test_loss:0.437,lr:1.00E-04
Epoch:27,Train_acc:88.8%,Train_loss:0.295,Test_acc:80.0%,Test_loss:0.440,lr:1.00E-04
Epoch:28,Train_acc:89.1%,Train_loss:0.292,Test_acc:79.5%,Test_loss:0.444,lr:1.00E-04
Epoch:29,Train_acc:89.1%,Train_loss:0.290,Test_acc:79.1%,Test_loss:0.449,lr:1.00E-04
Epoch:30,Train_acc:89.2%,Train_loss:0.287,Test_acc:79.1%,Test_loss:0.453,lr:1.00E-04
==================== Done ====================

六 结果可视化 

1.Loss和Acc图

epochs_range = range(30)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='training accuracy')
plt.plot(epochs_range,test_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='training loss')
plt.plot(epochs_range,test_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()

 2.调用模型预测

test_X = X_test[0].reshape(1,-1)
pred = model(test_X.to(device)).argmax(1).item()
print('模型预测结果:',pred)
print('=='*20)
print('0:未患病')
print('1:已患病')
模型预测结果: 0
========================================
0:未患病
1:已患病

3.绘制混淆矩阵 

'''
绘制混淆矩阵
'''
print('=============输入数据shape为==============')
print('X_test.shape:',X_test.shape)
print('y_test.shape:',y_test.shape)pred = model(X_test.to(device)).argmax(1).cpu().numpy()print('\n==========输出数据shape为==============')
print('pred.shape:',pred.shape)from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay# 计算混淆矩阵
cm = confusion_matrix(y_test,pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm,annot=True,fmt='d',cmap='Blues')
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title('Confusion Matrix',fontsize=12)
plt.xlabel('Pred Label',fontsize=10)
plt.ylabel('True Label',fontsize=10)
plt.tight_layout()
plt.show()
=============输入数据shape为==============
X_test.shape: torch.Size([215, 32])
y_test.shape: torch.Size([215])==========输出数据shape为==============
pred.shape: (215,)

 

http://www.lryc.cn/news/527374.html

相关文章:

  • 14-6-1C++STL的list
  • Redis事务机制详解与Springboot项目中的使用
  • DeepSeek-R1,用Ollama跑起来
  • Leecode刷题C语言之组合总和②
  • YOLOv8改进,YOLOv8检测头融合DynamicHead,并添加小目标检测层(四头检测),适合目标检测、分割等,全网独发
  • 【PyQt】QThread快速创建多线程任务
  • 智能码二维码的成本效益分析
  • 企业财务管理系统的需求设计和实现
  • Springboot集成Swagger和Springdoc详解
  • 类和对象(4)——多态:方法重写与动态绑定、向上转型和向下转型、多态的实现条件
  • ui-automator定位官网文档下载及使用
  • 董事会办公管理系统的需求设计和实现
  • ESP32和STM32在处理中断方面的区别
  • 零售业革命:改变行业的顶级物联网用例
  • 字符串算法笔记
  • 在Ubuntu上用Llama Factory命令行微调Qwen2.5的简单过程
  • ThinkPhp伪静态设置后,访问静态资源也提示找不到Controller
  • JavaScript赋能智能网页设计
  • 基于STM32的阿里云智能农业大棚
  • 80,【4】BUUCTF WEB [SUCTF 2018]MultiSQL
  • 深入探索imi框架:PHP Swoole的高性能协程应用实践
  • 【算法篇·更新中】C++秒入门(附练习用题目)
  • 对神经网络基础的理解
  • 存储基础 -- SCSI命令格式与使用场景
  • 从崩溃难题看 C 标准库与 Rust:线程安全问题引发的深度思考
  • 【CSS入门学习】Flex布局设置div水平、垂直分布与居中
  • 9. 神经网络(一.神经元模型)
  • R 语言 | future 包,非阻塞的执行耗时脚本
  • UE学习日志#12 Niagara特效大致了解(水文,主要是花时间读了读文档和文章)
  • 【数据结构】_链表经典算法OJ:合并两个有序数组