当前位置: 首页 > news >正文

JAX study notes[16]

文章目录

  • Pytrees
  • references

Pytrees

  1. in essence, JAX function and transform act on arrays,actually most opeartion handling arrays base on the collection of arrays.
  2. JAX use the Pytree which is an abstract object to control a lot of collections with consolidated former instead of make various structures for different cases.
import jax
import jax.numpy as jnp
params = [11,120,1000000000,"abcd",jnp.ones(3),{'n': 5, 'W': jnp.zeros(2)}]
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/learn1.py
PyTreeDef([*, *, *, *, *, {'W': *, 'n': *}])
[11, 120, 1000000000, 'abcd', Array([1., 1., 1.], dtype=float32), Array([0., 0.], dtype=float32), 5]
  1. JAX provide plenty of facilities to work with PyTrees.
  • tree.map
    to make a new pytree through puting the input some arguments formed as pytree into a function.
jax.tree.map(f, tree, *rest, is_leaf=None)
import jax
import jax.numpy as jnp
import math
params1 = [x for x in jnp.arange(1,10,2)]
params2 = [x for x in jnp.arange(10,1,-2)]
print(jax.tree.map(lambda a,b: math.sqrt(a^2+b^2),params1,params2))
  • tree.reduce
    to achieve reduce manipulation and get reduced value.
jax.tree.reduce(function: Callable[[T, Any], T], tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) → T
import jax
import operatorparams1 = [1,2,3]
params2 = [4,5]result=jax.tree.reduce(operator.add, [params1, params2])
print(result)
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/learn1.py
15

references

https://docs.jax.dev/

http://www.lryc.cn/news/585649.html

相关文章:

  • Java项目中图片加载路径问题解析
  • Python Day10
  • LLM场景下的强化学习【GRPO】
  • Spring Boot整合MyBatis+MySQL实战指南(Java 1.8 + 单元测试)
  • 上位机知识篇---端口
  • latex格式中插入eps格式的图像的编译命令
  • 异步复习(线程)
  • 【第四节】ubuntu server安装docker
  • 从0开始学习R语言--Day44--LR检验
  • 文章发布易优CMS(Eyoucms)网站技巧
  • 企业IT管理——医院数据备份与存储制度模板
  • Linux自动化构建工具(一)
  • 多表查询-2-多表查询概述
  • 蔚来测开一面:HashMap从1.7开始到1.8的过程,既然都解决不了并发安全问题,为什么还要进一步解决环形链表的问题?
  • 前端面试专栏-算法篇:23. 图结构与遍历算法
  • USB一线连多屏?Display Link技术深度解析
  • React中Redux基础和路由介绍
  • 适配多场景,工业显示器让操作更高效
  • 前端八股-promise
  • Spring的事务控制——学习历程
  • C++设计秘籍:为什么所有参数都需类型转换时,非成员函数才是王道?
  • Python-正则表达式-信息提取-滑动窗口-数据分发-文件加载及分析器-浏览器分析-学习笔记
  • (补充)RS422
  • Qt 实现新手引导
  • 分布式推客系统全栈开发指南:SpringCloud+Neo4j+Redis实战解析
  • 【世纪龙科技】几何G6新能源汽车结构原理教学软件
  • 【龙泽科技】新能源汽车维护与动力蓄电池检测仿真教学软件【吉利几何G6】
  • 重构下一代智能电池“神经中枢”:GCKontrol定义高性能BMS系统级设计标杆
  • Java :T extends Comparable<? super T> 和 T extends Comparable<T>的区别
  • 李沐动手学深度学习Pytorch-v2笔记【07自动求导代码实现】