当前位置: 首页 > news >正文

Triton IR

Triton IR语法

Triton IR的语句遵从MLIR Dialect的语法定义规范,示例如下:

%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)

其中:

%0:右边expression的结果值的名字(Value的name)

tt:表示Dialect名称空间为tt(Triton)

splat:operation的名字

%1:operation的输入

i32:%1的类型

tensor<1024*i32>:operation的结果类型(即3%的类型)

loc(%loc5):对应源码的行号,调试信息。

 如下是一个pytorch cat算子的Triton DSL(inductor产生)

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):xnumel = 3645440xoffset = tl.program_id(0) * XBLOCKxindex = xoffset + tl.arange(0, XBLOCK)[:]xmask = tl.full([XBLOCK], True, tl.int1)x0 = xindex % 890x1 = (xindex // 890)x2 = xindextmp0 = x0tmp1 = tl.full([1], 0, tl.int64)tmp2 = tmp0 >= tmp1tmp3 = tl.full([1], 390, tl.int64)tmp4 = tmp0 < tmp3tmp5 = tl.load(in_ptr0 + ((390*x1) + x0), tmp4, eviction_policy='evict_last', other=0.0)tmp6 = tmp0 >= tmp3tmp7 = tl.full([1], 890, tl.int64)tmp8 = tmp0 < tmp7tmp9 = tl.load(in_ptr1 + ((500*x1) + ((-390) + x0)), tmp6, eviction_policy='evict_last', other=0.0)tmp10 = tl.where(tmp4, tmp5, tmp9)tl.store(out_ptr0 + (x2), tmp10, None)
''', device_str='cuda')

编译生成的Triton IR如下::

#loc = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0)
module {tt.func public @triton_(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0)) attributes {noinline = false} {%cst = arith.constant dense<-390> : tensor<1024xi32> loc(#loc1)%cst_0 = arith.constant dense<500> : tensor<1024xi32> loc(#loc1)%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024xf32> loc(#loc1)%cst_2 = arith.constant dense<390> : tensor<1024xi32> loc(#loc1)%cst_3 = arith.constant dense<390> : tensor<1024xi64> loc(#loc1)%cst_4 = arith.constant dense<890> : tensor<1024xi32> loc(#loc1)%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)%0 = tt.get_program_id x : i32 loc(#loc2)%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)%4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)%5 = arith.remsi %4, %cst_4 : tensor<1024xi32> loc(#loc6)%6 = arith.divsi %4, %cst_4 : tensor<1024xi32> loc(#loc7)%7 = arith.extsi %5 : tensor<1024xi32> to tensor<1024xi64> loc(#loc8)%8 = arith.cmpi slt, %7, %cst_3 : tensor<1024xi64> loc(#loc8)%9 = arith.muli %6, %cst_2 : tensor<1024xi32> loc(#loc9)%10 = arith.addi %9, %5 : tensor<1024xi32> loc(#loc10)%11 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc11)%12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc11)%13 = tt.load %12, %8, %cst_1 evictionPolicy = evict_last : tensor<1024x!tt.ptr<f32>> loc(#loc12)%14 = arith.cmpi sge, %7, %cst_3 : tensor<1024xi64> loc(#loc13)%15 = arith.muli %6, %cst_0 : tensor<1024xi32> loc(#loc14)%16 = arith.addi %5, %cst : tensor<1024xi32> loc(#loc15)%17 = arith.addi %15, %16 : tensor<1024xi32> loc(#loc16)%18 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc17)%19 = tt.addptr %18, %17 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc17)%20 = tt.load %19, %14, %cst_1 evictionPolicy = evict_last : tensor<1024x!tt.ptr<f32>> loc(#loc18)%21 = arith.select %8, %13, %20 : tensor<1024xi1>, tensor<1024xf32> loc(#loc19)%22 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc20)%23 = tt.addptr %22, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc20)tt.store %23, %21 : tensor<1024x!tt.ptr<f32>> loc(#loc21)tt.return loc(#loc22)} loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":20:28)
#loc3 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":20:33)
#loc4 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":21:36)
#loc5 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":21:23)
#loc6 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":23:18)
#loc7 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":24:20)
#loc8 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":30:18)
#loc9 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":31:35)
#loc10 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":31:41)
#loc11 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":31:30)
#loc12 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":31:46)
#loc13 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":32:19)
#loc14 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:35)
#loc15 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:51)
#loc16 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:42)
#loc17 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:30)
#loc18 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:57)
#loc19 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":36:33)
#loc20 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":37:25)
#loc21 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":37:37)
#loc22 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":37:4)

Triton IR依赖的Dialects

编写完triton程序后,导出的IR中,可以看到不止有triton IR,还包含其他的MLIR Dialects,其中包含:

  • Arith: addf, addi, andi, cmpf, cmpi, divf, fptosi, …

  • Math: exp, sin, cos, log, …

  • StructuredControlFlow(scf): for, if, while, yield, condition

  • ControlFlow(cf): br, cond_br

Triton IR Operations

tt.call (triton::CallOp)

语法:

operation ::= `tt.call` $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)

tt.call表示对同一个符号作用域内的函数的直接调用。

示例:

%2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32

tt.func (triton::FuncOp)

function声明或定义,function是一个SSACFG region。

function内的Operation不能隐式地捕获function外定义的值。所有外部引用都必须通过arguments或者attribute来传递。在MLIR中,function的arguments是通过第一个block的block arguments来表达的。

示例:

// External function definitions.
tt.func @abort()
tt.func @scribble(i32, i64, memref<? x 128 x f32, #layout_map0>) -> f64// A function that returns its argument twice:
tt.func @count(%x: i64) -> (i64, i64)attributes {fruit: "banana"} {return %x, %x: i64, i64
}// A function with an argument attribute
tt.func @example_fn_arg(%x: i32 {swift.self = unit})// A function with a result attribute
tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64})// A function with an attribute
tt.func @example_fn_attr() attributes {dialectName.attrName = false}

SSACFG region

SSACFG region内的语句满足SSA形式,且不包含子Region(既不能包含scf.if/scf.for等),如下就是一个SSACFG region:

func.func @example(%a : i32) -> i32 {// 这是一个 SSACFG Region%cmp = arith.cmpi slt, %a, %c10 : i32cond_br %cmp, ^bb1, ^bb2^bb1:%x = arith.addi %a, %c1 : i32br ^exit(%x : i32)^bb2:%y = arith.subi %a, %c1 : i32br ^exit(%y : i32)^exit(%result : i32):return %result : i32
}

如下不是一个SSACFG Region:

scf.if %cond {// 这里是一个新的 Region(嵌套)scf.yield
}

Block Arguments

对如下函数:

func.func @foo(%arg0: i32, %arg1: f32) -> f32 {// 函数体直接使用 %arg0, %arg1%result = arith.addf %arg1, %arg1 : f32return %result : f32
}

在MLIR的内部实现里,是把function的arguments作为function内第一个基本块(entry block)的 block arguments 来存储:

func.func @foo() -> f32 {
^bb0(%arg0: i32, %arg1: f32):   // ← 参数实际属于 entry block%result = arith.addf %arg1, %arg1 : f32return %result : f32
}

这是因为MLIR的设计要求所有 SSA 值都由某个 block 或 op 产生,这样做也解决了LLVM IR中的phi node的问题。

在LLVM IR中,通过phi node来汇聚不同前驱路径的值,示例如下:

entry:br i1 %cond, label %left, label %rightleft:br label %mergeright:br label %mergemerge:%x = phi i32 [ %v1, %left ], [ %v2, %right ]   ; ← φ 节点ret i32 %x

在MLIR中,通过block arguments,可以实现同等的效果:

func.func @foo(%cond: i1, %v1: i32, %v2: i32) -> i32 {cf.cond_br %cond, ^left, ^right^left:cf.br ^merge(%v1 : i32)     // 把 %v1 作为参数传给 merge^right:cf.br ^merge(%v2 : i32)     // 把 %v2 作为参数传给 merge^merge(%x : i32):             // ← block argument 取代 φreturn %x : i32

tt.return (triton::ReturnOp)

语法:

operation ::= `tt.return` attr-dict ($srcs^ `:` type($srcs))?

表达function的返回操作,拥有变长个数的操作数,操作数的个数和类型必须和function的签名匹配。

示例:

tt.func @foo() : (i32, f8) {...tt.return %0, %1 : i32, f8
}

tt.addptr (triton::AddPtrOp)

语法:

operation ::= `tt.addptr` $ptr `,` $offset attr-dict `:` type($result) `,` type($offset)

张量或标量指针地址线性偏移计算。

示例:

%base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%idx  = tt.make_range {start = 0, end = 1024} : tensor<1024xi32>// 生成偏移地址
%ptrs = tt.addptr %base, %idx: tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>// 加载数据
%vals = tt.load %ptrs : tensor<1024xf32>

tt.advance (triton::AdvanceOp)

语法:

operation ::= `tt.advance` $ptr `,` `[` $offsets `]` attr-dict `:` type($result)

!tt.ptr<tensor<...>> 类型的指针按给定的 多维偏移量 进行偏移计算,返回一个新的张量指针。

示例:

scf.for %i = %c0 to %c128 step %c32iter_args(%tile_ptr = %base_ptr) -> (!tt.ptr<tensor<32x32xf16>>) {// 使用当前 tile%vals = tt.load %tile_ptr : !tt.ptr<tensor<32x32xf16>>// 推进到下一个 tile(第1个维度上推进 32,第2个维度保持不变)%next_ptr = tt.advance %tile_ptr, [%c32_i32, %c0_i32] : !tt.ptr<tensor<32x32xf16>>scf.yield %next_ptr : !tt.ptr<tensor<32x32xf16>>
}

tt.assert (triton::AssertOp)

语法:

operation ::= `tt.assert` $condition `,` $message attr-dict `:` type($condition)

tt.assert作用在device侧,接收1个condition(i1 类型的标量或张量)和一个string. 如果condition为false,则打印message并终止程序。

示例:

%in_bounds = arith.cmpi slt, %idx, %size : i32
tt.assert %in_bounds, "index out of bounds"

TODO

参考资料:

TritonOps — Triton documentation

http://www.lryc.cn/news/601450.html

相关文章:

  • Python折线图
  • Java面试新趋势:云原生与新兴框架实战解析
  • 零基础学习性能测试第五章:Tomcat的性能分析与调优-Tomcat原理,核心配置项,性能瓶颈分析,调优
  • MySQL ROUTER安装部署
  • Java面试实战:安全框架与大数据技术深度解析
  • 深度解析 inaSpeechSegmenter:高效音频语音分割与检测开源工具
  • 基于 LSTM 与 SVM 融合的时间序列预测模型:理论框架与协同机制—实践算法(1)
  • maven命令详解
  • Redis C++客户端——命令使用
  • 《不只是接口:GraphQL与RESTful的本质差异》
  • Libevent(4)之使用教程(3)配置
  • PHP框架之Laravel框架教程:3. 数据库操作(简要)
  • net8.0一键创建支持(RabbitMQ)
  • 积分兑换小程序Java
  • Torchv Unstrustured 文档解析库
  • Matplotlib(二)- Matplotlib简单绘图
  • 在docker中安装frp实现内网穿透
  • 【数据结构与算法】数据结构初阶:详解排序(二)——交换排序中的快速排序
  • 【51单片机和数码管仿真显示问题共阴共阳代码】2022-9-24
  • 算法竞赛阶段二-数据结构(36)数据结构双向链表模拟实现
  • hackthebox-Pwn-Restaurant(ret2libc)
  • MySQL 8.4 Windows 版安装记录与步骤参考
  • STM32-USART串口实现接收数据三种方法(1.根据\r\n标志符、2.空闲帧中断、3.根据定时器辅助接收)
  • 数据结构第1问:什么是数据结构?
  • 三、构建一个Agent
  • 栈----5.柱状图中最大的矩形
  • RabbitMq 常用命令和REST API
  • 基于分组规则的Excel数据分组优化系统设计与实现
  • 阿里 Qwen3 四模型齐发,字节 Coze 全面开源,GPT-5 8 月初发布!| AI Weekly 7.21-7.27
  • GPT 生成一个打字练习页面