当前位置: 首页 > news >正文

python 代码使用 DeepXDE 库实现了一个求解二维非线性偏微分方程(PDE)的功能

import deepxde as dde
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf# 设置时空计算域
Lx = 1  # x 范围从 0 到 1
Ly = 1  # y 范围从 0 到 1
Lt = 0.05  # t 范围从 0 到 0.05
geom = dde.geometry.Rectangle([0, 0], [Lx, Ly])  # 空间域
timedomain = dde.geometry.TimeDomain(0, Lt)  # 时间域
geomtime = dde.geometry.GeometryXTime(geom, timedomain)# 设置 PDE 方程
def pde(x, y):u = y[:, 0:1]  # 提取 u(x, y, t)u_x = dde.grad.jacobian(y, x, i=0, j=0)  # u 对 x 的一阶导数u_y = dde.grad.jacobian(y, x, i=0, j=1)  # u 对 y 的一阶导数u_t = dde.grad.jacobian(y, x, i=0, j=2)  # u 对 t 的一阶导数# 计算 u 对 x 和 y 的梯度grad_u = tf.concat([u_x, u_y], axis=1)  # 使用 TensorFlow 的 concat 函数拼接张量# 计算 (u^2 - u + 1) * 梯度项term = u ** 2 - u + 1  # 计算 (u^2 - u + 1)A = term * grad_u  # 乘以梯度# 计算散度:对 (A) 进行求导,即计算 (A_x + A_y)A_x = dde.grad.jacobian(A, x, i=0, j=0)  # A 对 x 的导数A_y = dde.grad.jacobian(A, x, i=0, j=1)  # A 对 y 的导数# 散度 = A_x + A_ydiv_A = A_x + A_y# 返回 PDE 方程 u_t - div((u^2 - u + 1) * grad(u)) = 0return u_t - div_A# 边界条件:u = 0,在边界上
def boundary(x, on_boundary):return on_boundarybc = dde.icbc.DirichletBC(geomtime, lambda x: 0, boundary, component=0)# 初始条件:u = sin(pi * x) * sin(pi * y)
def ic_func(x):return np.sin(np.pi * x[:, 0:1]) * np.sin(np.pi * x[:, 1:2])ic = dde.icbc.IC(geomtime, ic_func, lambda x, on_initial: on_initial, component=0)# 创建数据对象
data = dde.data.TimePDE(geomtime,pde,[bc, ic],num_domain=10000,  # 训练样本数量num_boundary=8000,  # 边界上的训练样本数量num_initial=5000,  # 初始条件上的训练样本数量num_test=10000,  # 测试样本数量
)# 设置神经网络架构
layer_size = [3] + [50] * 3 + [1]  # 输入层(3维:x, y, t),4个隐藏层,每层80个神经元,输出层(u)
activation = "tanh"  # 激活函数
initializer = "Glorot uniform"  # 权重初始化方法net = dde.nn.FNN(layer_size, activation, initializer)# 创建模型并训练
model = dde.Model(data, net)
model.compile("adam", lr=1e-3)  # 使用 Adam 优化器,学习率为 1e-3
losshistory, train_state = model.train(iterations=3000, display_every=200)# 保存训练历史和状态
dde.saveplot(losshistory, train_state, issave=False, isplot=True)# 可视化结果,绘制 t=0.05 时刻的 u(x, y)
xx, yy = np.meshgrid(np.linspace(0, 1, 28), np.linspace(0, 1, 28))
xy = np.vstack((xx.ravel(), yy.ravel())).T
t = np.ones((xy.shape[0], 1)) * 0.05  # 设置 t=0.05
xy_t = np.hstack((xy, t))  # 合并 x, y, tu_pred = model.predict(xy_t)  # 预测 u(x, y, t) 在 t=0.05 时的值
u_pred = u_pred.reshape(xx.shape)  # 重塑为网格形状# 筛选 u >= 0
u_pred = np.maximum(u_pred, 0)# 绘制 u(x, y) 的 3D 图
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(xx, yy, u_pred, cmap="viridis")
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('u(x, y, t=0.05)')
ax.set_title('u(x, y, t=0.05)')plt.show()

这段代码使用 DeepXDE 库实现了一个求解二维非线性偏微分方程(PDE)的功能。以下是对代码功能的详细解释:

  1. 设置时空计算域

    • 定义了空间范围 LxLy 分别为 1,时间范围 Lt 为 0.05。
    • 创建了空间域 geom(一个矩形)、时间域 timedomain 以及时空域 geomtime
  2. 设置 PDE 方程

    • 定义了一个函数 pde 来描述 PDE 方程。
    • 计算了函数 uxyt 的一阶导数。
    • 计算了 u 的梯度和 (u^2 - u + 1) * 梯度 项。
    • 计算了上述项的散度,并构建了 PDE 方程 u_t - div((u^2 - u + 1) * grad(u)) = 0
  3. 定义边界条件和初始条件

    • 边界条件:定义了一个函数 boundary 来判断点是否在边界上,并设置边界条件为 u = 0
    • 初始条件:定义了一个函数 ic_func 来描述初始条件 u = sin(pi * x) * sin(pi * y)
  4. 创建数据对象

    • 使用 dde.data.TimePDE 创建了一个数据对象 data,包含了 PDE 方程、边界条件和初始条件。
    • 定义了训练样本、边界样本、初始条件样本和测试样本的数量。
  5. 设置神经网络架构

    • 定义了神经网络的层结构,包括输入层(3 维:xyt)、4 个隐藏层(每层 50 个神经元)和输出层(1 维:u)。
    • 选择了激活函数 tanh 和权重初始化方法 Glorot uniform
  6. 创建模型并训练

    • 使用 dde.Model 创建了一个模型,将数据对象和神经网络传入。
    • 使用 Adam 优化器,学习率为 1e-3 对模型进行编译。
    • 训练模型,迭代 3000 次,并每 200 次迭代显示一次训练信息。
  7. 保存和可视化结果

    • 使用 dde.saveplot 保存训练历史和状态,并绘制损失曲线。
    • t = 0.05 时刻,生成网格点 xxyy,并将其与 t 合并为 xy_t
    • 使用训练好的模型预测 u(x, y, t)t = 0.05 时的值,并重塑为网格形状。
    • 筛选出 u >= 0 的值。
    • 绘制 u(x, y, t = 0.05) 的 3D 图。

综上所述,这段代码实现了使用深度学习方法求解二维非线性 PDE 的功能,并对结果进行了可视化展示。

http://www.lryc.cn/news/519849.html

相关文章:

  • 【Go】:深入解析 Go 1.24:新特性、改进与最佳实践
  • VUE3 一些常用的 npm 和 cnpm 命令,涵盖了修改源、清理缓存、修改 SSL 协议设置等内容。
  • 【SpringBoot】@Value 没有注入预期的值
  • 【STM32-学习笔记-6-】DMA
  • js实现一个可以自动重链的websocket客户端
  • 企业总部和分支通过GRE VPN互通
  • 油猴支持阿里云自动登陆插件
  • 【2024年华为OD机试】(C卷,100分)- 字符串筛选排序 (Java JS PythonC/C++)
  • iOS - runtime总结
  • 第33 章 - ES 实战篇 - MySQL 与 Elasticsearch 的一致性问题
  • Artec Leo 3D扫描仪与Ray助力野生水生动物法医鉴定【沪敖3D】
  • PythonQT5打包exe线程使用
  • 【Powershell】Windows大法powershell好(二)
  • 前端学习-环境this对象以及回调函数(二十七)
  • Element-plus、Element-ui之Tree 树形控件回显Bug问题。
  • 互联网全景消息(10)之Kafka深度剖析(中)
  • Oracle Dataguard(主库为双节点集群)配置详解(5):将主库复制到备库并启动同步
  • pytorch小记(一):pytorch矩阵乘法:torch.matmul(x, y)
  • PyTorch环境配置常见报错的解决办法
  • 罗永浩再创业,这次盯上了 AI?
  • VUE3 provide 和 inject,跨越多层级组件传递数据
  • git打补丁
  • 机械燃油车知识图谱、知识大纲、知识结构(持续更新...)
  • Vue3学习总结
  • Type-C双屏显示器方案
  • 【读书与思考】焦虑与内耗
  • 基于python的网页表格数据下载--转excel
  • Vue.js开发入门:从零开始搭建你的第一个项目
  • LS1046+XILINX XDMA PCIE调通
  • HarmonyOS:@LocalBuilder装饰器: 维持组件父子关系