当前位置: 首页 > news >正文

Matlab 深度学习工具箱 案例学习与测试————求二阶微分方程

clc
clear% 定义输入变量
x = linspace(0,2,10000)';% 定义网络的层参数
inputSize = 1;
layers = [featureInputLayer(inputSize,Normalization="none")fullyConnectedLayer(10)sigmoidLayerfullyConnectedLayer(1)sigmoidLayer];
% 创建网络
net = dlnetwork(layers);% 训练轮数
numEpochs = 15;
% 每个Batch的数据个数
miniBatchSize = 100;

% SGDM优化方法设置的参数
initialLearnRate = 0.5;
learnRateDropFactor = 0.5;
learnRateDropPeriod = 5;
momentum = 0.9;
velocity = [];

% 损失函数里面考虑初始条件的系数
icCoeff = 7;% ArrayDatastore
ads = arrayDatastore(x,IterationDimension=1);
% 创建一个用于处理管理深度学习数据的对象
mbq = minibatchqueue(ads, ...MiniBatchSize=miniBatchSize, ...PartialMiniBatch="discard", ...MiniBatchFormat="BC");% 用于迭代过程监控
numObservationsTrain = numel(x);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;% 创建监控对象 
% 由于计时器在您创建监控器对象时启动,因此请确保在靠近训练循环的位置创建对象。
monitor = trainingProgressMonitor( ...Metrics="LogLoss", ...Info=["Epoch" "LearnRate"], ...XLabel="Iteration");% Train the network using a custom training loop
epoch = 0;
iteration = 0;
learnRate = initialLearnRate;
start = tic;% Loop over epochs.
while epoch < numEpochs  && ~monitor.Stopepoch = epoch + 1;% Shuffle data,打乱数据.mbq.shuffle% Loop over mini-batches.while hasdata(mbq) && ~monitor.Stopiteration = iteration + 1;% Read mini-batch of data.X = next(mbq);% Evaluate the model gradients and loss using dlfeval and the modelLoss function.[loss,gradients] = dlfeval(@modelLoss, net, X, icCoeff);% Update network parameters using the SGDM optimizer.[net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);% Update the training progress monitor.recordMetrics(monitor,iteration,LogLoss=log(loss));updateInfo(monitor,Epoch=epoch,LearnRate=learnRate);monitor.Progress = 100 * iteration/numIterations;end% Reduce the learning rate.if mod(epoch,learnRateDropPeriod)==0learnRate = learnRate*learnRateDropFactor;end
endxTest = linspace(0,4,1000)';yModel = minibatchpredict(net,xTest);yAnalytic = exp(-xTest.^2);figure;
plot(xTest,yAnalytic,"-")
hold on
plot(xTest,yModel,"--")
legend("Analytic","Model")

在深度学习中,被求导的对象(样本/输入)一般是多元的(向量x),绝大多数情况是标量y对向量x进行求导,很少向量y对向量x进行求导,否则就会得到复杂的微分矩阵。所以经常把一个样本看做一个整体,它包含多个变量(属性),对其所有属性求导后再加和,就得到了这个样本的偏导数之和。

% 损失函数
function [loss,gradients] = modelLoss(net, X, icCoeff)% 前向传播计算y = forward(net,X);% Evaluate the gradient of y with respect to x. % Since another derivative will be taken, set EnableHigherDerivatives to true.dy = dlgradient(sum(y,"all"),X,EnableHigherDerivatives=true);% Define ODE loss.eq = dy + 2*y.*X;% Define initial condition loss.ic = forward(net,dlarray(0,"CB")) - 1;% Specify the loss as a weighted sum of the ODE loss and the initial condition loss.loss = mean(eq.^2,"all") + icCoeff * ic.^2;% Evaluate model gradients.gradients = dlgradient(loss, net.Learnables);end
http://www.lryc.cn/news/491590.html

相关文章:

  • django authentication 登录注册
  • 三种蓝牙架构实现方案
  • ffmpeg 视频滤镜:高斯模糊-gblur
  • 期权懂|在期权市场中,如何用好双买期权?
  • 【Linux学习】【Ubuntu入门】2-3 make工具和makefile引入
  • 《黑神话:悟空》游戏辅助修改器工具下载指南与操作方法详解
  • C语言菜鸟入门·关键字·union的用法
  • ensp静态路由实验
  • 构建 Java Web 应用程序:从 Servlet 到数据库交互(Eclipse使用JDBC连接Mysql数据库)
  • mfc100u.dll是什么?分享几种mfc100u.dll丢失的解决方法
  • Java面试之多线程并发篇
  • 视频推拉流EasyDSS互联网直播点播平台技术特点及应用场景剖析
  • 安全加固方案
  • Linux firewall防火墙规则
  • 速盾:CDN缓存的工作原理是什么?
  • 日常开发记录-正确的prop传参,reduce搭配promise的使用
  • Hyper-V配置-cnblog
  • 运维Tips:Docker或K8s集群拉取Harbor私有容器镜像仓库配置指南
  • openssl颁发包含主题替代名的证书–SAN
  • Stable Diffusion入门教程
  • H.265流媒体播放器EasyPlayer.js无插件H5播放器关于移动端(H5)切换网络的时候,播放器会触发什么事件
  • conan2 c/c++包管理入门之--------------------------conanfile.py
  • DICOM图像深入解析:为何部分DR/CR图像默认显示为反色?
  • 重新定义社媒引流:AI社媒引流王如何为品牌赋能?
  • 【FPGA】Verilog:利用 4 个串行输入- 串行输出的 D 触发器实现 Shift_register
  • 《硬件架构的艺术》笔记(五):低功耗设计
  • Hive离线数仓结构分析
  • 鱼眼相机模型-MEI
  • GPT系列文章
  • 微软Ignite 2024:建立一个Agentic世界!