当前位置: 首页 > news >正文

机器学习洞察 | 一文带你“讲透” JAX

上篇文章中,我们详细分享了 JAX 这一新兴的机器学习模型的发展和优势,本文我们将通过 Amazon SageMaker 示例展示如何部署并使用 JAX。

JAX 的工作机制


JAX 的完整工作机制可以用下面这幅图详细解释:

图片来源:“Intro to JAX” video on YouTube by Jake VanderPlas, Tech leader from JAX team

在图片左侧是开发者自己编写的 Python 代码,JAX 会追踪并变换成 JAX IR 的中间表示,并按照 Python 代码,通过 jax.jit 将其编译成 HLO (High Level Optimized) 代码,代表高级的优化代码,提供给 XLA 进行读取。XLA 在获取编译的 HLO 代码之后,会分配到对应的 CPU、GPU、TPU 或者 ASIC。

对于开发者来说,只需完成您的 Python 代码即可实现这一流程。开发者可以将 JAX 转换视为首先对 Python 函数进行跟踪专门化,然后将其转换为一个小而行为良好的中间形式,然后使用特定于转换的解释规则进行解释。

为什么 JAX 可以在如此小的软件包中提供如此强大的功能呢?

首先,它从熟悉且灵活的编程接口(使用 NumPy 的 Python)开始,并且使用实际的 Python 解释器来完成大部分繁重的工作;其次,它将计算的本质提炼成一个静态具有高阶功能的类型表达式语言,即 Jaxpr 语言。

JAX 应用场景


自 2019 年 JAX 出现之后,使用它的开发者逐年增多。在 2022 年更是达到了非常火热的状态,甚至有人认为它有可能会取代其他的机器学习框架。

支持 JAX 生态的应用场景包括:

  • 深度学习 (Deep Learning):JAX 在深度学习场景下应用很广泛,很多团队基于 JAX 开发了更加高级的 API 支持不同的场景,方便开发者使用。

  • 科学模拟 (Scientific Simulation):JAX 的出现不仅仅是针对于深度学习,其实也拥有很多其他的使命,如科学模拟。

  • 机器人与控制系统 (Robotics and Control Systems)

  • 概率编程 (Probabilistic Programming)

训练和部署深度学习模型


我们用下面这个具体例子展示使用 JAX 来和 Amazon SageMaker 训练和部署深度学习模型,会用到 Amazon SageMaker 的 BYOC 这种模式。

如上图所示,在这个 Amazon SageMaker 的示例中提供了 JAX 的代码示例:https://sagemaker-examples.readthedocs.io/en/latest/advanced_functionality/jax_bring_your_own/train_deploy_jax.html

在 Amazon SageMaker 上基于 JAX 的框架可使用自定义的容器来训练神经网络。

如图的 Amazon SageMaker Examples 提供的 JAX 示例中,我们使用自定义容器在 SageMaker 上 基于 JAX 框架或库训练神经网络。这在单个容器上是可能的,因为我们使用了 sagemaker-training-toolkit,它允许你在自己的自定义容器中使用脚本模式。自定义容器可以使用内置的 SageMaker 训练作业功能,如竞价训练和超参数调整。

训练模型后,您可以将经过训练的模型部署到托管端点。如前所述,SageMaker 具有推理容器,这些容器已针对亚马逊云科技的硬件和常用深度学习框架进行了优化。其中一项优化是针对 TensorFlow 框架的优化。由于 JAX 支持将模型导出为 TensorFlow SavedModel 格式,因此我们使用该功能来展示如何在优化的 SageMaker TensorFlow 推理端点上部署经过训练的模型。

整个训练和部署主要分为以下五个步骤:

  1. 创建 Docker 镜像并将其推送到 Amazon ECR。

  1. 使用 SageMaker 开发工具包传教自定义框架估算器,以便将模型输出归类为 TensorFlowModel。

  1. 代码仓库中有训练估算器的脚本。

  1. 使用 GPU 上的 SageMaker 训练作业来训练每个模型。

  1. 将模型部署到完全托管的终端节点。

下面我们来看看详细步骤:

  1. 创建 Docker 镜像并将其推送到 Amazon ECR。

*创建使用 JAX 训练模型容器的 Dockerfile

Docker 映像是在 NVIDIA 提供的支持 CUDA 的容器之上构建的。为了确保作为 JAX 中功能基础的 jaxlibpackage 支持 CUDA,请从 jax_releases 存储库中下载 jaxlib 软件包。

  • AX releases

https://storage.googleapis.com/jax-releases/jax_releases.html

这里需要注意的是:为了确保作为 JAX 中的功能基础的 JAX library package 能够支持 cuda,建议在去做这个创建自定义容器时,去看一下目前 JAX release 这个存储库中,它下载的这个 JAX library 包的版本号或者相关注意事项等等。

2、使用 SageMaker 开发工具包创建自定义框架估算器,以便将模型输出归类为 TensorFlowModel。

创建基本 SageMaker 框架估算器的子类,将估算器的模型类型指定为 TensorFlow 模型。为此,我们指定了一个自定义 create_model 方法,该方法使用现有的 TensorFlowModel 类来启动推理容器。

3、通过代码仓库训练估算器的脚本。

您可以通过传统的 SageMaker Python SDK 工作流通过模型执行训练、部署和运行推理。我们确保导入并初始化自定义框架估算器的代码片段中定义的 JaxEstimator,然后运行标准的 .fit () 和 .deploy () 调用。

对于 JAX ,可以调用 jax2tf 函数来执行相同的操作。代码在存储库中可用。设置正确的路径 /opt/ml/model/1 非常重要,这是 SageMaker wrapper(封装器) 假定模型已存储的地方。、

前面提到的 JAX 和 TF 的互操作性,目前 JAX 是通过 JAX to TF 这样的一个软件包,来为 JAX 和 TF 的互操作性提供支持,那 jax2tf.convert 是用于在 TensorFlow 的上下文中使用 JAX 函数,那 jax2tf.call_tf 是用于在 JAX 的上下文中使用的 TensorFlow 函数互操作来完成的。

4、使用 GPU 上的 SageMaker 训练作业来训练每个模型。

  1. 将模型部署到完全托管的终端节点。

vanilla_jax_predictor = vanilla_jax_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge"
)
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
def test_image(predictor, test_images, test_labels, image_number):np_img 
= np.expand_dims(np.expand_dims(test_images[image_number], axis=
-1
), axis=
0
)result = predictor.predict(np_img)pred_y = np.argmax(result["predictions"])print("True Label:", test_labels[image_number])print("Predicted Label:", pred_y)plt.imshow(test_images[image_number]

*部署和准备输入的测试图像

*进行推理

有关在 Amazon SageMaker 上使用 JAX 训练和部署深度学习模型的详细过程和代码,请参考亚马逊云科技官方博客

如图所示,上面的两张图是一个部署模型的例子,下面的图是进行推理的例子。由于我们的 Framework Estimator 知道模型将使用 TensorFlowModel 提供服务,因此部署这些端点只是对 estimator.deploy () 方法做调用即可。

参考资料

  • Training and Deploying ML Models using JAX on SageMaker

  • Train and deploy deep learning models using JAX with Amazon SageMaker

  • AX core from scratch

  • Building JAX from source

JAX 是一种越来越流行的库,它支持原生 Python 或 NumPy 函数的可组合函数转换,可用于高性能数值计算和机器学习研究。JAX 提供了编写 NumPy 程序的能力,这些程序可以使用 GPU/TPU 自动差分和加速,从而形成了更灵活的框架来支持现代深度学习架构。在这两篇文章中我们讨论了有关 JAX 的一些主题,希望对您用使用 JAX 这一框架进行深度学习研究有所帮助。

往期推荐


  • 机器学习洞察 | JAX,机器学习领域的“新面孔”

  • 机器学习洞察 | 降本增效,无服务器推理是怎么做到的?

  • 机器学习洞察 | 分布式训练让机器学习更加快速准确

http://www.lryc.cn/news/10575.html

相关文章:

  • OpenFaaS介绍
  • 【算法设计与分析】STL容器、递归算法、分治法、蛮力法、回溯法、分支限界法、贪心法、动态规划;各类算法代码汇总
  • vue初识
  • 火山引擎入选《2022 爱分析 · DataOps 厂商全景报告》,旗下 DataLeap 产品能力获认可
  • java-spring_bean的生命周期
  • 微服务相关概念
  • 论文解读:(TransA)TransA: An Adaptive Approach for Knowledge Graph Embedding
  • js将数字转十进制+十六进制(联动el-ui下拉选择框)
  • 关于RedissonLock的一些所思
  • C++:倒牛奶问题
  • MySQL8.x group_by报错的4种解决方法
  • 具有非线性动态行为的多车辆列队行驶问题的基于强化学习的方法
  • TrueNas篇-硬盘直通
  • 手机子品牌的“性能战事”:一场殊途同归的大混战
  • dockerfile自定义镜像安装jdk8,nginx,后端jar包和前端静态文件,并启动容器访问
  • MongoDB 全文检索
  • JS中声明变量,使用 var、let、const的区别
  • 汽车改装避坑指南:大尾翼
  • 【Unity资源下载】POLYGON Dungeon Realms - Low Poly 3D Art by Synty
  • 知识汇总:Python办公自动化应该学习哪些内容
  • 软件架构知识5-架构设计流程
  • 【银河麒麟V10操作系统】修改屏幕分辨率的方法
  • pdf生成为二维码
  • Yaklang websocket劫持教程
  • 基于AIOT技术的智慧校园空调集中管控系统设计与实现
  • 【每日一题】 将一句话单词倒置,标点不倒置
  • 宽刈幅干涉雷达高度计SWOT(Surface Water and Ocean Topography)卫星进展(待完善)
  • openjdk源码==类加载过程
  • vue2的后台管理系统 迁移到 vue3后台管理系统
  • 2023年美赛F题