JAX study notes[15]
文章目录
- the symmetric difference of sets
- limit superior and limit inferior
- references
the symmetric difference of sets
The symmetric difference can be expressed as follows:
$$A \,\Delta\, B = (A \setminus B) \cup (B \setminus A)$$
The symmetric difference of sets A and B consists of the elements that belong to A or to B, but not to their intersection.
import jax.numpy as jnp
from jax import jit


def symmetric_difference(a, b):
    """Return the symmetric difference of two sets given as 1-D arrays.

    The result holds the elements that occur in ``a`` or in ``b`` but not in
    both. It is built as the sorted unique elements of ``a`` that are not in
    ``b``, followed by the sorted unique elements of ``b`` that are not in
    ``a``. The two parts are concatenated, not merged, so the result as a
    whole is not necessarily sorted.

    Args:
        a, b: array-likes (JAX/NumPy arrays, lists or tuples) holding the
            input sets. Duplicate entries are ignored.

    Returns:
        A 1-D JAX array with the elements found in exactly one of the inputs.
    """
    # jnp.setdiff1d rejects plain Python sequences, so normalise the inputs.
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    a_only = jnp.setdiff1d(a, b)  # elements of a that are not in b
    b_only = jnp.setdiff1d(b, a)  # elements of b that are not in a
    return jnp.concatenate([a_only, b_only])


# Create two sets
set_a, set_b = (
    jnp.array([11, 22, 26, 41, 5]),
    jnp.array([41, 52, 26, 7, 8]),
)

# Compute the symmetric difference of the two sets and show it.
result = symmetric_difference(set_a, set_b)
print(result)
[ 5 11 22 7 8 52]
limit superior and limit inferior
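- For a sequence of sets $A_1, A_2, \dots$: $\limsup_{n\to\infty} A_n = \bigcap_{n=1}^{\infty}\bigcup_{k=n}^{\infty} A_k$ (the elements that belong to infinitely many $A_k$), and $\liminf_{n\to\infty} A_n = \bigcup_{n=1}^{\infty}\bigcap_{k=n}^{\infty} A_k$ (the elements that belong to every $A_k$ from some index on).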
In JAX (and NumPy), when you apply `jnp.cumsum` to a boolean array, the boolean values are automatically upcast to integers before the cumulative sum is computed.
How it works: `True` is treated as 1 and `False` is treated as 0.
import jax.numpy as jnp

# Boolean input: True contributes 1 and False contributes 0 to the running sum.
bool_arr = jnp.array([True, False, True, True, False])
result = jnp.cumsum(bool_arr)
print(result)  # Output: [1 1 2 3 3]
To handle multi-dimensional arrays, pass an `axis` argument:
bool_matrix = jnp.array([
    [True, False],
    [False, True],
])

# axis=0: running total down each column (row i sums rows 0..i)
print(jnp.cumsum(bool_matrix, axis=0))
# Output:
# [[1 0]
#  [1 1]]

# axis=1: running total across each row (column j sums columns 0..j)
print(jnp.cumsum(bool_matrix, axis=1))
# Output:
# [[1 1]
#  [0 1]]
With `axis=0`, JAX accumulates downward along each column (across the rows); with `axis=1`, it accumulates rightward along each row (across the columns).
The behaviour of `jnp.all` along an axis can be illustrated as follows:
matrix = jnp.array([
    [True, False],
    [True, True],
])
# axis=0 reduces down each column -> [True, False]
print(jnp.all(matrix, axis=0))
# axis=1 reduces across each row -> [False, True]
print(jnp.all(matrix, axis=1))
To calculate the cumulative product of array elements, use `jnp.cumprod`:
import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4])
result = jnp.cumprod(arr)
print(result)  # Output: [ 1  2  6 24]  (computed as 1, 1×2=2, 2×3=6, 6×4=24)

matrix = jnp.array([[1, 2], [3, 4]])
# axis=0: running product down each column
print(jnp.cumprod(matrix, axis=0))
# Output: [[1, 2], [1×3=3, 2×4=8]]
# axis=1: running product across each row
print(jnp.cumprod(matrix, axis=1))
# Output: [[1, 1×2=2], [3, 3×4=12]]
import jax.numpy as jnp
from jax import vmap, jit
import jax


def sets_to_mask(sets, all_elements):
    """Build a boolean membership matrix for a sequence of sets.

    Args:
        sets: sequence of 1-D array-likes A_1, ..., A_N. The sets may have
            different sizes.
        all_elements: 1-D array of candidate elements.

    Returns:
        Boolean array of shape (N, len(all_elements)) whose entry [i, j] is
        True iff all_elements[j] is a member of sets[i].
    """
    # One jnp.isin call per set rather than vmap over jnp.stack(sets):
    # stacking would require every set to have the same length.
    return jnp.stack([jnp.isin(all_elements, jnp.asarray(s)) for s in sets])


def safety_concat(arr):
    """Return the sorted unique elements of the union of the arrays in ``arr``."""
    return jnp.unique(jnp.concatenate([jnp.asarray(a) for a in arr]))


def limsup_jax(sets):
    """Upper limit of a finite sequence of sets A_1, ..., A_N.

    Computes the intersection over n of (A_n ∪ ... ∪ A_N): an element is
    kept iff every tail A_n, ..., A_N contains it at least once. This is the
    finite analogue of "belongs to infinitely many A_k".

    NOTE(review): for a finite sequence this condition is equivalent to
    membership in the last set A_N.

    Returns:
        1-D array of the qualifying elements, sorted ascending.
    """
    all_elements = safety_concat(sets)
    mask = sets_to_mask(sets, all_elements).astype(jnp.int32)
    # Cumulative sum over the reversed rows: row r counts how many of the
    # last r + 1 sets contain each element.
    tail_counts = jnp.cumsum(mask[::-1], axis=0)
    limsup_mask = jnp.all(tail_counts > 0, axis=0)
    return all_elements[limsup_mask]


def liminf_jax(sets):
    """Lower limit of a finite sequence of sets A_1, ..., A_N.

    An element is kept iff, from some starting set A_n on, it belongs to
    every set A_n, ..., A_N. The starting point n = N (the last set on its
    own) is deliberately not considered, so being in A_N alone is not
    enough. With a single set the result is empty.

    NOTE(review): presumably the last tail is excluded so the result does not
    collapse to A_N, which the exact finite formula would give; confirm this
    is intended.

    Returns:
        1-D array of the qualifying elements, sorted ascending.
    """
    all_elements = safety_concat(sets)
    mask = sets_to_mask(sets, all_elements).astype(jnp.int32)
    # Cumulative product over the reversed rows, flipped back: row i is
    # non-zero iff the element is in every set from row i to the last row.
    in_every_tail_set = jnp.cumprod(mask[::-1], axis=0)[::-1] > 0
    # Skip the final row, i.e. the tail consisting of the last set only.
    liminf_mask = jnp.any(in_every_tail_set[:-1], axis=0)
    return all_elements[liminf_mask]


# Test case
sets = [
    jnp.array([1, 2]),  # A1
    jnp.array([2, 3]),  # A2
    jnp.array([1, 3]),  # A3
    jnp.array([2, 3]),  # A4
]
# Expected results
print("所有元素:", jnp.unique(jnp.concatenate(sets)))  # [1 2 3]
print("掩码矩阵:\n", sets_to_mask(sets, jnp.unique(jnp.concatenate(sets))))
print("\nLimsup:", limsup_jax(sets))  # expected output: [2 3]
print("\nLiminf:", liminf_jax(sets))  # expected output: [3]
所有元素: [1 2 3]
掩码矩阵:
 [[ True  True False]
 [False  True  True]
 [ True False  True]
 [False  True  True]]

Limsup: [2 3]

Liminf: [3]
references
- 《实变函数论(周民强)》
- deepseek