JAX study notes[16]
Contents
- Pytrees
- references
Pytrees
- In essence, JAX functions and transformations act on arrays; in practice, most operations work on collections of arrays rather than on single arrays.
- JAX uses the pytree, an abstraction over nested containers (lists, tuples, dicts, and combinations of them), so that all of these collections are handled in one uniform way instead of needing separate code for each structure.
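A small sketch of that uniformity, assuming only the public `jax.tree.leaves` function: the same call flattens a list, a tuple, a dict, and a nested mix of them. The container contents are made up for illustration. One detail worth noticing is that `None` is treated as an empty subtree, not as a leaf.

```python
import jax

# Differently shaped containers; one API handles all of them
examples = {
    "list": [1, 2, 3],
    "tuple": (1, (2, 3)),
    "dict": {"b": [2, 3], "a": 1},   # dict leaves come out in sorted-key order
    "with_none": [1, None, 2],       # None is an empty subtree, so it contributes no leaf
}

# jax.tree.leaves walks each container recursively and returns its leaves in order
flat = {name: jax.tree.leaves(tree) for name, tree in examples.items()}
for name, leaves in flat.items():
    print(name, leaves)
```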
import jax
import jax.numpy as jnp
params = [11,120,1000000000,"abcd",jnp.ones(3),{'n': 5, 'W': jnp.zeros(2)}]
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/learn1.py
PyTreeDef([*, *, *, *, *, {'W': *, 'n': *}])
[11, 120, 1000000000, 'abcd', Array([1., 1., 1.], dtype=float32), Array([0., 0.], dtype=float32), 5]
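The structure and the leaves printed above are the two halves of a pytree, and JAX can split and rejoin them. A minimal sketch using `jax.tree.flatten` and `jax.tree.unflatten`; the `params` dict and the "double every leaf" step are illustrative choices, not part of the notes:

```python
import jax
import jax.numpy as jnp

params = {"W": jnp.zeros(2), "n": 5, "layers": [1.0, 2.0]}

# flatten returns the list of leaves plus a treedef describing the container layout
leaves, treedef = jax.tree.flatten(params)

# Work on the flat list (here: double each leaf), then rebuild the original structure
doubled = jax.tree.unflatten(treedef, [x * 2 for x in leaves])
print(treedef)
print(doubled)
```

This flatten, transform, unflatten round trip is the pattern that helpers such as `jax.tree.map` wrap for you.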
- JAX provides many utilities for working with pytrees, collected in the `jax.tree` module.
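- jax.tree.map
Builds a new pytree by calling a function on every leaf of the input pytree. Any additional pytrees passed in `*rest` must match the structure of `tree`; the function then receives the corresponding leaf from each of them as extra arguments.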
jax.tree.map(f, tree, *rest, is_leaf=None)
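import jax
import jax.numpy as jnp
# two list pytrees with the same structure (five scalar leaves each)
params1 = [x for x in jnp.arange(1, 10, 2)]   # 1, 3, 5, 7, 9
params2 = [x for x in jnp.arange(10, 1, -2)]  # 10, 8, 6, 4, 2
# f gets one leaf from each tree; use ** for squaring (^ is bitwise XOR in Python)
print(jax.tree.map(lambda a, b: jnp.sqrt(a**2 + b**2), params1, params2))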
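A common use of multi-tree `jax.tree.map` is combining two structured values leaf by leaf. A minimal sketch, assuming a hypothetical parameter dict `params`, matching gradients `grads`, and a made-up learning rate `lr` (none of these come from the notes), applies one plain gradient-descent step to every leaf at once:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters and gradients with identical structure
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(0.5)}
lr = 0.1  # hypothetical learning rate

# f receives the matching leaf from each tree: p from params, g from grads
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
print(new_params)
```

The result keeps the dict layout of `params`. If `grads` had a different structure, `jax.tree.map` would raise an error instead of silently misaligning leaves.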
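- jax.tree.reduce
Folds `function` over all leaves of the pytree from left to right and returns the single reduced value, similar to `functools.reduce` applied to `jax.tree.leaves(tree)`.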
jax.tree.reduce(function: Callable[[T, Any], T], tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) → T
import jax
import operator
params1 = [1, 2, 3]
params2 = [4, 5]
# the leaves of [params1, params2] are 1, 2, 3, 4, 5; operator.add sums them
result = jax.tree.reduce(operator.add, [params1, params2])
print(result)
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/learn1.py
15
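`jax.tree.reduce` also accepts an initializer as its third positional argument, which lets the accumulator have a different type from the leaves. A minimal sketch, assuming a hypothetical `params` dict of arrays, counts the total number of elements across all leaves:

```python
import jax
import jax.numpy as jnp

# Hypothetical nested parameters: a 3x2 weight matrix, a bias vector, and a scalar
params = {"W": jnp.ones((3, 2)), "b": jnp.zeros(2), "scale": jnp.array(1.0)}

# The fold starts from the initializer 0 and adds each leaf's element count
total = jax.tree.reduce(lambda acc, leaf: acc + leaf.size, params, 0)
print(total)  # 6 + 2 + 1 = 9
```

Without the initializer, the first leaf (an array) would become the starting accumulator, so `acc.size` style logic would break; starting from `0` keeps the accumulator a plain integer.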
references
https://docs.jax.dev/