深度学习入门:用pytorch跑通GitHub的UNET-ZOO项目
一、环境配置
首先打开Anaconda prompt,创建一个虚拟环境,python版本号选3.9
conda create -n unet python=3.9
激活
conda activate unet
在当前环境下下载pytorch
建议在官网安装,镜像容易装错版本。
(1)打开pytorch官网:PyTorch,点击Get Started
要选择pip,我装的是11.8,兼容的版本多一点,不容易出错。
把复制链接到Anaconda Prompt创建的虚拟环境中安装。
然后下载CUDA,版本根据自己的cuda version来,不知道的就查一下,
win+R,输入CMD,输入以下代码查看
nvidia-smi
查看CUDA Version,我的是12.9,所以要去官网装12.9以下的版本,我装的11.8.0
直接官网装:CUDA Toolkit Archive | NVIDIA Developer
安装好一会,在Anaconda Prompt创建的虚拟环境中测试一下,依次输入以下代码:
python
import torch
torch.cuda.is_available()
结果返回true,则表示torch安装成功,并且可以调用了。
二、下载UNET-ZOO项目到本地
项目地址为:Andy-zhujunwen/UNET-ZOO: including unet,unet++,attention-unet,r2unet,cenet,segnet ,fcn.
新建一个文件夹,下载项目ZIP到这个文件夹
再下载项目提供的数据集,这里建议下载Liver数据集
把压缩包解压(数据集压缩包解压到项目目录下),用pycharm打开这个项目
三、跑通项目
记得要把pycharm环境切换至虚拟环境,步骤为
File->settings
点击Add interpreter
use existing environment选择你配置第一步配置的虚拟环境名字,点ok就行了。
然后,在pycharm找到dataset.py,修改路径为你的数据集路径
然后在terminal中输入以下代码
python main.py --action train --arch UNet --epoch 21 --batch_size 21
模型就可以跑通了