MLIR TableGen
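Introduction
TableGen is a domain-specific language (DSL). Its design goal is to allow flexible descriptions to be written while factoring out the features that records have in common, which reduces duplicated code and improves maintainability.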
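The TableGen workflow:
Frontend parsing:
- The TableGen frontend parses .td files, which contain declarations and definitions written in the TableGen language.
- The frontend instantiates these declarations and definitions and produces an intermediate representation (IR) that holds all defined records and classes.
Backend processing:
- The generated IR is handed to a domain-specific backend.
- The backend generates output from the IR, usually C++ code. Different backends generate different kinds of code, for example LLVM instruction set descriptions or MLIR operation definitions.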
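The two stages above can be pictured with a small Python model. This is only a sketch under stated assumptions: the record layout (a dict with `name`, `superclasses`, `fields`) and the functions `frontend` and `cpp_backend` are hypothetical names invented for illustration. Real TableGen backends are C++ programs, not Python functions.

```python
# Toy model of the TableGen pipeline: frontend -> records -> backend.
# All names here are hypothetical; this is not how llvm-tblgen is implemented.

def frontend(definitions):
    """Pretend parse step: turn (name, superclasses, fields) tuples into records."""
    return [
        {"name": name, "superclasses": list(supers), "fields": dict(fields)}
        for name, supers, fields in definitions
    ]

def cpp_backend(records):
    """Pretend backend: emit one C++ class declaration per Dialect record."""
    lines = []
    for rec in records:
        if "Dialect" not in rec["superclasses"]:
            continue  # a backend only handles the records it cares about
        lines.append(f"// namespace: {rec['fields']['name']}")
        lines.append(f"class {rec['name']} : public ::mlir::Dialect {{}};")
    return "\n".join(lines)

records = frontend([
    ("TritonDialect", ["Dialect"], {"name": "tt"}),
    ("TT_Float", ["Type"], {}),
])
generated = cpp_backend(records)
print(generated)
```

The point of the split is that the same set of records can be fed to several backends, each picking out the records it understands.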
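The main current uses of the TableGen DSL:
- LLVM Target-Independent Code Generator
- Clang diagnostics and attributes
- MLIR Dialects Code Generator
Because this article is part of a series on the Triton source code, the rest of it focuses on the use in the MLIR Dialects Code Generator. In MLIR, TableGen is mainly used for code generation, cutting down the code that has to be written by hand when adding a new dialect, pass, and so on.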
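TableGen Basic Concepts
TableGen's syntax is loosely based on C++ templates, with built-in types and specification. In addition, the syntax introduces some automation concepts such as multiclass, foreach, and let.
A TableGen file contains two key kinds of item, classes and definitions, and both are records.
A TableGen record consists of:
- a unique name
- a list of values
- a list of superclasses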
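The three parts of a record listed above map naturally onto a small data structure. The following Python sketch is an analogy only: the `Record` dataclass and the `instantiate` helper are hypothetical, and the field names `mnemonic` and `summary` are borrowed from the Triton example later in this article purely for illustration.

```python
from dataclasses import dataclass, field

@dataclass
class Record:
    name: str                                         # unique name
    values: dict = field(default_factory=dict)        # field name -> value
    superclasses: list = field(default_factory=list)  # names of parent classes

def instantiate(name, parent, **overrides):
    """Create a record that inherits the parent's values, then sets its own."""
    values = dict(parent.values)   # start from the inherited values
    values.update(overrides)       # values set by the new record win
    return Record(name, values, parent.superclasses + [parent.name])

# An abstract "class" record and a concrete "def" record derived from it.
type_def = Record("TypeDef", {"mnemonic": None, "summary": ""})
ptr = instantiate("TT_PtrType", type_def, mnemonic="ptr")
print(ptr)
```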
TableGen definition
A TableGen definition is a concrete record. It normally contains no undefined values and is introduced with the def keyword.
Example: the float type in Triton IR:
// Floating-point Type
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2,
F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
TableGen class
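A TableGen class is an abstract record used to build and describe other records, which lets users construct domain abstractions. A class can be instantiated with the def keyword to produce a definition.
Example: TritonTypeDef in Triton IR: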
class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
    : TypeDef<Triton_Dialect, name, traits> {
  // Used by printer/parser
  let mnemonic = _mnemonic;
}

// Pointer Type in C++ (corresponding to `TT_PtrOf`)
def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> {
  let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system";

  let description = [{
    Pointer type in Triton IR type system, which could be pointing to scalars or tensors.
  }];

  let parameters = (ins "Type":$pointeeType, "int":$addressSpace);

  let builders = [
    TypeBuilderWithInferredContext<(ins
      "Type":$pointeeType,
      "int":$addressSpace
    ), [{
      return $_get(pointeeType.getContext(), pointeeType, addressSpace);
    }]>
  ];

  let hasCustomAssemblyFormat = 1;

  let skipDefaultBuilders = 1;
}
TableGen multiclass
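A multiclass is a special kind of class that represents a group of related abstract records. It is instantiated with defm, which produces a group of definitions.
Example (not used in Triton or MLIR):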
multiclass ro_signed_pats<string T, string Rm, dag Base, dag Offset, dag Extend,
                          dag address, ValueType sty> {
  def : Pat<(i32 (!cast<SDNode>("sextload" # sty) address)),
            (!cast<Instruction>("LDRS" # T # "w_" # Rm # "_RegOffset")
              Base, Offset, Extend)>;

  def : Pat<(i64 (!cast<SDNode>("sextload" # sty) address)),
            (!cast<Instruction>("LDRS" # T # "x_" # Rm # "_RegOffset")
              Base, Offset, Extend)>;
}

defm : ro_signed_pats<"B", Rm, Base, Offset, Extend,
                      !foreach(decls.pattern, address,
                               !subst(SHIFT, imm_eq0, decls.pattern)),
                      i8>;
TableGen Syntax
Literals
Numeric literals and string literals are supported.
Identifiers
Similar to C++, except that an identifier may begin with a digit. The reserved keywords are:
assert bit bits class code
dag def dump else false
foreach defm defset defvar field
if in include int let
list multiclass string then true
Bang Operators
Bang operators cover arithmetic, logical operations, type conversion and checking, list and set operations, DAG operations, string operations, and more.
BangOperator ::= one of
                 !add        !and         !cast          !con         !dag
                 !div        !empty       !eq            !exists      !filter
                 !find       !foldl       !foreach       !ge          !getdagarg
                 !getdagname !getdagop    !getdagopname  !gt          !head
                 !if         !initialized !instances     !interleave  !isa
                 !le         !listconcat  !listflatten   !listremove  !listsplat
                 !logtwo     !lt          !match         !mul         !ne
                 !not        !or          !range         !repr        !setdagarg
                 !setdagname !setdagop    !setdagopname  !shl         !size
                 !sra        !srl         !strconcat     !sub         !subst
                 !substr     !tail        !tolower       !toupper     !xor
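To make a few of these operators concrete, here is a rough Python analogy. It is a sketch, not TableGen: the `bang_*` function names are hypothetical stand-ins, the semantics are simplified, and in TableGen `!foldl` takes the form `!foldl(init, list, acc, var, expr)` rather than a callable.

```python
from functools import reduce

# Python stand-ins for a handful of bang operators (simplified semantics).

def bang_add(*xs):
    """!add(a, b, ...): sum of the integer arguments."""
    return sum(xs)

def bang_strconcat(*parts):
    """!strconcat(s1, s2, ...): concatenate strings."""
    return "".join(parts)

def bang_head(lst):
    """!head(list): first element."""
    return lst[0]

def bang_tail(lst):
    """!tail(list): everything except the first element."""
    return lst[1:]

def bang_listconcat(*lists):
    """!listconcat(l1, l2, ...): concatenate lists."""
    return [x for lst in lists for x in lst]

def bang_foldl(init, lst, fn):
    """!foldl(init, list, acc, var, expr): left fold, with expr modeled as fn(acc, var)."""
    return reduce(fn, lst, init)

total = bang_foldl(0, [1, 2, 3, 4], lambda acc, var: bang_add(acc, var))
asm = bang_strconcat("add", " $dst, $src1, $src2")
print(total, asm)
```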
Include
Similar to #include in C++:
IncludeDirective ::= "include" TokString
Preprocess
Similar to the C++ preprocessor, with a small set of directives:
PreprocessorDirective ::= "#define" | "#ifdef" | "#ifndef"
Types
TableGen is statically typed. The supported types are:
Type    ::= "bit" | "int" | "string" | "dag" | "code"
          | "bits" "<" TokInteger ">"
          | "list" "<" Type ">"
          | ClassID
ClassID ::= TokIdentifier
Value & Expression
SimpleValue  ::= SimpleValue1
               | SimpleValue2
               | SimpleValue3
               | SimpleValue4
               | SimpleValue5
               | SimpleValue6
               | SimpleValue7
               | SimpleValue8
               | SimpleValue9
SimpleValue1 ::= TokInteger | TokString+ | TokCode
Statement
A TableGen file is a sequence of statements, include directives, and preprocessor directives:
TableGenFile ::= (Statement | IncludeDirective | PreprocessorDirective)*
Statement    ::= Assert | Class | Def | Defm | Defset | Deftype
               | Defvar | Dump | Foreach | If | Let | MultiClass
class
Defines an abstract record that other records can inherit from.
Class ::= "class" ClassID [TemplateArgList] RecordBody
TemplateArgList ::= "<" TemplateArgDecl ("," TemplateArgDecl)* ">"
TemplateArgDecl ::= Type TokIdentifier ["=" Value]
Record bodies
A record body follows the header of a class or a definition:
RecordBody ::= ParentClassList Body
ParentClassList ::= [":" ParentClassListNE]
ParentClassListNE ::= ClassRef ("," ClassRef)*
ClassRef ::= (ClassID | MultiClassID) ["<" [ArgValueList] ">"]
ArgValueList ::= PostionalArgValueList [","] NamedArgValueList
PostionalArgValueList ::= [Value {"," Value}*]
NamedArgValueList ::= [NameValue "=" Value {"," NameValue "=" Value}*]
Body ::= ";" | "{" BodyItem* "}"
BodyItem ::= Type TokIdentifier ["=" Value] ";"
           | "let" TokIdentifier ["{" RangeList "}"] "=" Value ";"
           | "defvar" TokIdentifier "=" Value ";"
           | Assert
def
Defines a new concrete record:
Def ::= "def" [NameValue] RecordBody
NameValue ::= Value (parsed in a special mode)
let
A let statement collects a set of field values and applies them to every class and record defined within the scope of the let statement:
Let     ::= "let" LetList "in" "{" Statement* "}"
          | "let" LetList "in" Statement
LetList ::= LetItem ("," LetItem)*
LetItem ::= TokIdentifier ["<" RangeList ">"] "=" Value
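The semantics of let are to set a default value or to override an inherited value. It cannot override the value of a template argument.
When only a few fields of a group of records need to be overridden, a top-level let cuts down on repetition. let statements can also be nested. In the example below, isCall and Defs override the fields of the same names in the three enclosed records (CALLpcrel32, CALL32r, CALL32m):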
let isCall = true in
  // All calls clobber the non-callee saved registers...
  let Defs = [EAX, ECX, EDX, FP0, FP1, FP2, FP3, FP4, FP5, FP6, ST0,
              MM0, MM1, MM2, MM3, MM4, MM5, MM6, MM7, XMM0, XMM1, XMM2,
              XMM3, XMM4, XMM5, XMM6, XMM7, EFLAGS] in {
    def CALLpcrel32 : Ii32<0xE8, RawFrm, (outs), (ins i32imm:$dst, variable_ops),
                           "call\t${dst:call}", []>;
    def CALL32r     : I<0xFF, MRM2r, (outs), (ins GR32:$dst, variable_ops),
                        "call\t{*}$dst", [(X86call GR32:$dst)]>;
    def CALL32m     : I<0xFF, MRM2m, (outs), (ins i32mem:$dst, variable_ops),
                        "call\t{*}$dst", []>;
  }
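The override order of nested let statements can be sketched with a small Python model. Everything in it is an assumption made for illustration: the `define` helper, the default field values in `inst_defaults`, and the shortened `Defs` list are hypothetical. The model applies inherited fields first, then the enclosing lets from outermost to innermost, then the record's own body, and it treats a let on a field the record does not inherit as an error.

```python
def define(name, inherited, let_stack, body=None):
    """Toy model of how a def inside nested top-level lets gets its fields."""
    fields = dict(inherited)              # 1. inherit fields from parent classes
    for bindings in let_stack:            # 2. apply enclosing lets, outermost first
        for key, value in bindings.items():
            if key not in fields:
                # Assumption: a let may only name fields the record inherits.
                raise KeyError(f"{name}: unknown field {key!r}")
            fields[key] = value
    fields.update(body or {})             # 3. the record's own body is applied last
    return {"name": name, **fields}

# Hypothetical defaults an instruction class might provide.
inst_defaults = {"isCall": False, "Defs": [], "opcode": None}

# let isCall = true in let Defs = [...] in { ... }  (register list shortened)
let_stack = [{"isCall": True}, {"Defs": ["EAX", "ECX", "EDX"]}]

calls = [
    define(name, inst_defaults, let_stack, {"opcode": opcode})
    for name, opcode in [("CALLpcrel32", 0xE8), ("CALL32r", 0xFF), ("CALL32m", 0xFF)]
]
print(calls)
```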
multiclasses
A multiclass makes it convenient to instantiate several definitions at once.
MultiClass          ::= "multiclass" TokIdentifier [TemplateArgList]
                        ParentClassList
                        "{" MultiClassStatement+ "}"
MultiClassID ::= TokIdentifier
MultiClassStatement ::= Assert | Def | Defm | Defvar | Foreach | If | Let
defm
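Used together with multiclasses to instantiate several definitions at once.
Example:
Suppose that, in some ISA, every concrete instruction comes in two forms:
reg = reg op reg
reg = reg op imm
A multiclass can then describe both forms together, and defm can be used to define each concrete instruction: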
def ops;
def GPR;
def Imm;
class inst <int opc, string asmstr, dag operandlist>;

multiclass ri_inst <int opc, string asmstr> {
  def _rr : inst<opc, !strconcat(asmstr, " $dst, $src1, $src2"),
                 (ops GPR:$dst, GPR:$src1, GPR:$src2)>;
  def _ri : inst<opc, !strconcat(asmstr, " $dst, $src1, $src2"),
                 (ops GPR:$dst, GPR:$src1, Imm:$src2)>;
}

// Define records for each instruction in the RR and RI formats.
defm ADD : ri_inst<0b111, "add">;
defm SUB : ri_inst<0b101, "sub">;
defm MUL : ri_inst<0b100, "mul">;
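The effect of the three defm lines can be sketched in Python: each defm name is prepended to the def names inside the multiclass, giving six records in total. The functions `ri_inst` and `defm` below are hypothetical models of the TableGen constructs of the same name, and the operand lists are simplified to plain strings.

```python
def ri_inst(opc, asmstr):
    """Model of `multiclass ri_inst`: returns (suffix, fields) pairs."""
    asm = asmstr + " $dst, $src1, $src2"   # mirrors the !strconcat call
    return [
        ("_rr", {"opc": opc, "asmstr": asm, "operands": ["GPR", "GPR", "GPR"]}),
        ("_ri", {"opc": opc, "asmstr": asm, "operands": ["GPR", "GPR", "Imm"]}),
    ]

def defm(prefix, multiclass_defs):
    """Model of `defm`: prepend the defm name to every def in the multiclass."""
    return {prefix + suffix: fields for suffix, fields in multiclass_defs}

records = {}
for name, opc in [("ADD", 0b111), ("SUB", 0b101), ("MUL", 0b100)]:
    records.update(defm(name, ri_inst(opc, name.lower())))

for record_name in records:
    print(record_name)
```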
defset
Collects a group of records into a global list:
Defset ::= "defset" Type TokIdentifier "=" "{" Statement* "}"
Example:
class MyRecord<string Name, int Value> {
  string name = Name;
  int value = Value;
}

defset list<MyRecord> MyRecords = {
  def R1 : MyRecord<"Record1", 10>;
  def R2 : MyRecord<"Record2", 20>;
  def R3 : MyRecord<"Record3", 30>;
}
deftype
Defines a type alias, similar to using in C++. The right-hand side can only be a primitive type or another type alias:
Deftype ::= "deftype" TokIdentifier "=" Type ";"
defvar
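Defines a variable. Once defined, a variable cannot be set to another value:
Defvar ::= "defvar" TokIdentifier "=" Value ";"
Note that, according to the TableGen Programmer's Reference, a variable defined inside a top-level foreach goes out of scope at the end of each iteration, so a defvar like the following does not work:
defvar i = !add(i, 1);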
foreach
A for loop:
Foreach         ::= "foreach" ForeachIterator "in" "{" Statement* "}"
                  | "foreach" ForeachIterator "in" Statement
ForeachIterator ::= TokIdentifier "=" ("{" RangeList "}" | RangePiece | Value)
Example:
foreach i = [0, 1, 2, 3] in {
  def R#i : Register<...>;
  def F#i : Register<...>;
}
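The loop above defines eight records, because the paste operator `#` joins the prefix with the current value of `i`. A minimal Python sketch of that expansion follows. The `foreach` helper is a hypothetical stand-in, and only the record names are modeled, since the `Register<...>` arguments are elided in the original.

```python
def foreach(values, body):
    """Model of `foreach i = values in { body }`: run the body once per value."""
    names = []
    for i in values:
        names.extend(body(i))
    return names

# `def R#i` and `def F#i` become R0, F0, R1, F1, ...
names = foreach([0, 1, 2, 3], lambda i: [f"R{i}", f"F{i}"])
print(names)
```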
dump
Prints a value to stderr, for debugging:
Dump ::= "dump" Value ";"
At the top level, the message is printed directly. Inside a record, it is printed each time that record is instantiated.
Example:
multiclass MC<dag s> {
  dump "s = " # !repr(s);
}
if
Selects one of two groups of statements based on a condition:
If     ::= "if" Value "then" IfBody
         | "if" Value "then" IfBody "else" IfBody
IfBody ::= "{" Statement* "}" | Statement
assert
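An assertion. The first value is the condition to check and the second is the message reported when the condition is false: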
Assert ::= "assert" Value "," Value ";"
The mlir-tblgen tool
During the MLIR build, the mlir-tblgen tool compiles the .td file of a dialect or pass into the corresponding C++ code:
set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls) # generate declarations
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) # generate definitions
add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) # generate documentation
The .td file of the Triton dialect looks like this:
#ifndef TRITON_DIALECT
#define TRITON_DIALECT

include "mlir/IR/OpBase.td"

def Triton_Dialect : Dialect {
  let name = "tt";

  let cppNamespace = "::mlir::triton";

  let summary = "The Triton IR in MLIR";

  let description = [{
    Triton Dialect.

    Dependent Dialects:
      * Arith:
        * addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
      * Math:
        * exp, sin, cos, log, ...
      * StructuredControlFlow:
        * for, if, while, yield, condition
      * ControlFlow:
        * br, cond_br
  }];

  let dependentDialects = [
    "arith::ArithDialect",
    "math::MathDialect",
    "scf::SCFDialect",
    "cf::ControlFlowDialect",
    "ub::UBDialect"
  ];

  let extraClassDeclaration = [{
    void registerTypes();

    static TritonDialect *getLoaded(MLIRContext *ctx) {
      return ctx->getLoadedDialect<TritonDialect>();
    }
    static TritonDialect *getLoaded(Operation *op) {
      return getLoaded(op->getContext());
    }
  }];

  let discardableAttrs = (ins
    "::mlir::IntegerAttr":$num_stages,
    "::mlir::IntegerAttr":$latency,
    "::mlir::IntegerAttr":$self_latency
  );

  let hasConstantMaterializer = 1;
  let useDefaultTypePrinterParser = 1;
  let usePropertiesForAttributes = 1;
}

include "triton/Dialect/Triton/IR/TritonTypes.td"

#endif // TRITON_DIALECT
The generated declaration file is as follows:
/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
|* *|
|* Dialect Declarations *|
|* *|
|* Automatically generated file, do not edit! *|
|* From: TritonDialect.td *|
|* *|
\*===----------------------------------------------------------------------===*/

namespace mlir {
namespace triton {

/// The Triton IR in MLIR
/// Triton Dialect.
///
/// Dependent Dialects:
/// * Arith:
/// * addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
/// * Math:
/// * exp, sin, cos, log, ...
/// * StructuredControlFlow:
/// * for, if, while, yield, condition
/// * ControlFlow:
/// * br, cond_br
class TritonDialect : public ::mlir::Dialect {
  explicit TritonDialect(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~TritonDialect() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {
    return ::llvm::StringLiteral("tt");
  }

  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;

  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;

  /// Helper to manage the discardable attribute `num_stages`.
  class NumStagesAttrHelper {
    ::mlir::StringAttr name;
  public:
    static constexpr ::llvm::StringLiteral getNameStr() {
      return "tt.num_stages";
    }
    constexpr ::mlir::StringAttr getName() {
      return name;
    }

    NumStagesAttrHelper(::mlir::MLIRContext *ctx)
      : name(::mlir::StringAttr::get(ctx, getNameStr())) {}

    ::mlir::IntegerAttr getAttr(::mlir::Operation *op) {
      return op->getAttrOfType<::mlir::IntegerAttr>(name);
    }
    void setAttr(::mlir::Operation *op, ::mlir::IntegerAttr val) {
      op->setAttr(name, val);
    }
    bool isAttrPresent(::mlir::Operation *op) {
      return op->hasAttrOfType<::mlir::IntegerAttr>(name);
    }
    void removeAttr(::mlir::Operation *op) {
      assert(op->hasAttrOfType<::mlir::IntegerAttr>(name));
      op->removeAttr(name);
    }
  };
  NumStagesAttrHelper getNumStagesAttrHelper() {
    return numStagesAttrName;
  }
private:
  NumStagesAttrHelper numStagesAttrName;
public:

  /// Helper to manage the discardable attribute `latency`.
  class LatencyAttrHelper {
    ::mlir::StringAttr name;
  public:
    static constexpr ::llvm::StringLiteral getNameStr() {
      return "tt.latency";
    }
    constexpr ::mlir::StringAttr getName() {
      return name;
    }

    LatencyAttrHelper(::mlir::MLIRContext *ctx)
      : name(::mlir::StringAttr::get(ctx, getNameStr())) {}

    ::mlir::IntegerAttr getAttr(::mlir::Operation *op) {
      return op->getAttrOfType<::mlir::IntegerAttr>(name);
    }
    void setAttr(::mlir::Operation *op, ::mlir::IntegerAttr val) {
      op->setAttr(name, val);
    }
    bool isAttrPresent(::mlir::Operation *op) {
      return op->hasAttrOfType<::mlir::IntegerAttr>(name);
    }
    void removeAttr(::mlir::Operation *op) {
      assert(op->hasAttrOfType<::mlir::IntegerAttr>(name));
      op->removeAttr(name);
    }
  };
  LatencyAttrHelper getLatencyAttrHelper() {
    return latencyAttrName;
  }
private:
  LatencyAttrHelper latencyAttrName;
public:

  /// Helper to manage the discardable attribute `self_latency`.
  class SelfLatencyAttrHelper {
    ::mlir::StringAttr name;
  public:
    static constexpr ::llvm::StringLiteral getNameStr() {
      return "tt.self_latency";
    }
    constexpr ::mlir::StringAttr getName() {
      return name;
    }

    SelfLatencyAttrHelper(::mlir::MLIRContext *ctx)
      : name(::mlir::StringAttr::get(ctx, getNameStr())) {}

    ::mlir::IntegerAttr getAttr(::mlir::Operation *op) {
      return op->getAttrOfType<::mlir::IntegerAttr>(name);
    }
    void setAttr(::mlir::Operation *op, ::mlir::IntegerAttr val) {
      op->setAttr(name, val);
    }
    bool isAttrPresent(::mlir::Operation *op) {
      return op->hasAttrOfType<::mlir::IntegerAttr>(name);
    }
    void removeAttr(::mlir::Operation *op) {
      assert(op->hasAttrOfType<::mlir::IntegerAttr>(name));
      op->removeAttr(name);
    }
  };
  SelfLatencyAttrHelper getSelfLatencyAttrHelper() {
    return selfLatencyAttrName;
  }
private:
  SelfLatencyAttrHelper selfLatencyAttrName;
public:

  void registerTypes();

  static TritonDialect *getLoaded(MLIRContext *ctx) {
    return ctx->getLoadedDialect<TritonDialect>();
  }
  static TritonDialect *getLoaded(Operation *op) {
    return getLoaded(op->getContext());
  }
};
} // namespace triton
} // namespace mlir
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::triton::TritonDialect)
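The three helper classes in the listing follow a visible naming pattern derived from the entries in discardableAttrs. The Python sketch below reproduces that pattern. It is inferred purely from the generated output shown above, not from the mlir-tblgen source, and `helper_names` is a hypothetical function written for illustration.

```python
def helper_names(dialect_namespace, attr_name):
    """Derive the names seen in the generated listing from a snake_case attribute."""
    camel = "".join(part.capitalize() for part in attr_name.split("_"))
    lower_camel = camel[0].lower() + camel[1:]
    return {
        "class": f"{camel}AttrHelper",                # nested helper class
        "getter": f"get{camel}AttrHelper",            # accessor on the dialect
        "member": f"{lower_camel}AttrName",           # private data member
        "attr_key": f"{dialect_namespace}.{attr_name}",  # value of getNameStr()
    }

for attr in ["num_stages", "latency", "self_latency"]:
    print(helper_names("tt", attr))
```

Each attribute key is prefixed with the dialect namespace ("tt"), which is why the helpers look up "tt.num_stages" rather than "num_stages".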
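This file is included by the Triton IR header Dialect.h.
The generated definitions are written to Dialect.cpp.inc. The listing below is not that generated file itself but the source file that includes it: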
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Interfaces.h"
#include "triton/Dialect/Triton/IR/Types.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"

#include "triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc"
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc"

using namespace mlir;
using namespace mlir::triton;

//===----------------------------------------------------------------------===//
// TritonDialect Dialect Interfaces
//===----------------------------------------------------------------------===//

bool TritonInlinerInterface::isLegalToInline(Operation *call,
                                             Operation *callable,
                                             bool wouldBeCloned) const {
  auto funcOp = dyn_cast<triton::FuncOp>(callable);
  if (!funcOp)
    return true;
  if (funcOp->hasAttr("noinline"))
    return !funcOp->getAttrOfType<BoolAttr>("noinline").getValue();
  return true;
}

/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void TritonInlinerInterface::handleTerminator(Operation *op,
                                              Block *newDest) const {
  // Only return needs to be handled here.
  auto returnOp = dyn_cast<triton::ReturnOp>(op);
  if (!returnOp)
    return;

  // Replace the return with a branch to the dest.
  OpBuilder builder(op);
  builder.create<mlir::cf::BranchOp>(op->getLoc(), newDest,
                                     returnOp.getOperands());
  op->erase();
}

/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void TritonInlinerInterface::handleTerminator(Operation *op,
                                              ValueRange valuesToRepl) const {
  // Only return needs to be handled here.
  auto returnOp = cast<triton::ReturnOp>(op);

  // Replace the values directly with the return operands.
  assert(returnOp.getNumOperands() == valuesToRepl.size());
  for (const auto &it : llvm::enumerate(returnOp.getOperands()))
    valuesToRepl[it.index()].replaceAllUsesWith(it.value());
}

void TritonDialect::initialize() {
  registerTypes();

  addOperations<
#define GET_OP_LIST
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
      >();

  // We can also add interface here.
  addInterfaces<TritonInlinerInterface>();
}

Operation *TritonDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
In other words, the generated Dialect.cpp.inc is pulled into the build through the #include in Triton IR's Dialect.cpp.
References:
TableGen Overview, LLVM 22.0.0git documentation
TableGen Programmer's Reference, LLVM 22.0.0git documentation