当前位置: 首页 > news >正文

JAX study notes[15]

文章目录

  • the symmetric difference of sets
  • limit superior and limit inferior
  • references

the symmetric difference of sets

the symmetric difference can be express as follows:
A Δ B = ( A \ B ) ∪ ( B \ A ) A \Delta B=(A\backslash B)\cup (B \backslash A) AΔB=(A\B)(B\A)
the symmetric difference of sets A and B mean the element belong to A or belong to B but not in their intersection.

import jax.numpy as jnp
from jax import jitdef symmetric_difference(a, b):"""计算两个集合的对称差集参数:a, b: 两个一维JAX数组,表示输入集合返回:一维JAX数组,包含只在a或只在b中的元素"""# 找出在a中但不在b中的元素a_only = jnp.setdiff1d(a, b)# 找出在b中但不在a中的元素b_only = jnp.setdiff1d(b, a)# 合并结果return jnp.concatenate([a_only, b_only])# 创建两个集合
set_a = jnp.array([11, 22, 26, 41, 5])
set_b = jnp.array([41, 52, 26, 7, 8])# 计算对称差集
result = symmetric_difference(set_a, set_b)
print(result)  
[ 5 11 22  7  8 52]

limit superior and limit inferior

在这里插入图片描述

  • In JAX (and NumPy), when you apply jnp.cumsum to a boolean array, the boolean values are automatically upcast to integers before the cumulative sum is computed.
    How It Works:

    True is treated as 1

    False is treated as 0

import jax.numpy as jnpbool_arr = jnp.array([True, False, True, True, False])result = jnp.cumsum(bool_arr)
print(result)  # Output: [1 1 2 3 3]

to hand the Multi-dimensional Arrays

bool_matrix = jnp.array([[True, False], [False, True]])# Cumulative sum along axis=0 (rows)
print(jnp.cumsum(bool_matrix, axis=0))
# Output:
# [[1 0]
#  [1 1]]# Cumulative sum along axis=1 (columns)
print(jnp.cumsum(bool_matrix, axis=1))
# Output:
# [[1 1]
#  [0 1]]

JAX accumulates along the downward direction by column when axis equals1 and along the rightward direction by row when it equals 0.

how to use jnp.all can be explained as follows:

matrix = jnp.array([[True, False], [True, True]])
print(jnp.all(matrix, axis=0))  # 沿列聚合 → [True, False]
print(jnp.all(matrix, axis=1))  # 沿行聚合 → [False, True]

to calculate the cumulative product of the array elements can use jnp.cumprod

import jax.numpy as jnparr = jnp.array([1, 2, 3, 4])
result = jnp.cumprod(arr)
print(result)  # 输出: [1, 2, 6, 24] (计算过程:1, 1×2=2, 2×3=6, 6×4=24)
matrix = jnp.array([[1, 2], [3, 4]])# 沿 axis=0(行方向)
print(jnp.cumprod(matrix, axis=0))
# 输出: [[1, 2], [1×3=3, 2×4=8]]# 沿 axis=1(列方向)
print(jnp.cumprod(matrix, axis=1))
# 输出: [[1, 1×2=2], [3, 3×4=12]]
import jax.numpy as jnp
from jax import vmap, jit
import jaxdef sets_to_mask(sets, all_elements):"""将集合列转换为布尔掩码矩阵"""def is_in_set(s):return jnp.isin(all_elements, s)return vmap(is_in_set)(jnp.stack(sets))def safety_concat(arr):return jnp.unique(jnp.concatenate(arr))  def limsup_jax(sets):"""计算上限集:元素属于无限多个集合"""all_elements = safety_concat(sets)mask = sets_to_mask(sets, all_elements)def is_in_limsup(j):return jnp.all(jnp.cumsum(mask[:, j][::-1]) > 0)limsup_mask = vmap(is_in_limsup)(jnp.arange(mask.shape[1]))return all_elements[limsup_mask]def liminf_jax(sets):"""计算下限集:元素从某时刻开始永远属于集合"""all_elements = safety_concat(sets)mask = sets_to_mask(sets, all_elements)liminf_masks=jnp.zeros((0, mask.shape[1]))def is_in_liminf(i,j):return jnp.all(jnp.cumprod(mask[:, j][i::1]) > 0)for ii in jnp.arange(mask.shape[0]):liminf_mask = vmap(is_in_liminf,in_axes=(None, 0))(ii,jnp.arange(mask.shape[1]))liminf_masks=jnp.vstack([liminf_masks, liminf_mask])return all_elements[jnp.any(liminf_masks[:-1,:], axis=0)]# 测试用例
sets = [jnp.array([1, 2 ]),  # A1jnp.array([2, 3]),  # A2jnp.array([1, 3]),  # A3jnp.array([ 2, 3]),  # A4
]# 预期结果分析
print("所有元素:", jnp.unique(jnp.concatenate(sets)))  # [1 2 3]
print("掩码矩阵:\n", sets_to_mask(sets, jnp.unique(jnp.concatenate(sets))))print("\nLimsup:", limsup_jax(sets))  # 正确输出应为 [2 3]print("\nLiminf:", liminf_jax(sets))   # 正确输出应为 [ 3]
所有元素: [1 2 3]
掩码矩阵:[[ True  True False][False  True  True][ True False  True][False  True  True]]Limsup: [2 3]Liminf: [3]

references

  1. 《实变函数论(周民强)》
  2. deepseek
http://www.lryc.cn/news/581514.html

相关文章:

  • 百度文心大模型 4.5 开源深度测评:技术架构、部署实战与生态协同全解析
  • 前端环境nvm/pnpm下载配置
  • 在C#中,可以不实例化一个类而直接调用其静态字段
  • 【Elasticsearch入门到落地】15、DSL排序、分页及高亮
  • 【HarmonyOS】鸿蒙应用开发Text控件常见错误
  • 深入解析Vue中v-model的双向绑定实现原理
  • D3 面试题100道之(61-80)
  • Qt实现外网双向音视频通话/支持嵌入式板子/实时性好延迟低/可以加水印
  • C++基础复习笔记
  • 【网络系列】HTTP 429 状态码
  • Debezium日常分享系列之:认识Debezium Operator
  • Go语言实现双Token登录的思路与实现
  • UNIX程序设计基本概念和术语
  • 玄机——第一章日志分析-mysql应急响应
  • docker 无法拉取镜像解决方法
  • 系统架构设计师论文分享-论软件体系结构的演化
  • Apache Iceberg数据湖基础
  • 极简的神经网络反向传播例子
  • 探寻《答案之书》:在随机中寻找生活的指引
  • 5种高效解决Maven依赖冲突的方法
  • Golang读取ZIP压缩包并显示Gin静态html网站
  • c++对象池
  • 数据库|达梦DM数据库安装步骤
  • [论文阅读] 人工智能 + 软件工程 | 自然语言驱动结构代码搜索:突破DSL学习壁垒的创新方法
  • 分布式压测
  • python高级变量XIII
  • jenkins安装
  • 分布式事务解决方案(二)
  • 探索实现C++ STL容器适配器:优先队列priority_queue
  • react当中的this指向