diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -44,6 +44,18 @@ add_mlir_dialect(LinalgOps linalg) +set(LLVM_TARGET_DEFINITIONS LinalgOps.td) +mlir_tablegen(LinalgOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(LinalgOpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRLinalgOpsEnumsIncGen) +add_dependencies(mlir-headers MLIRLinalgOpsEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS LinalgOps.td) +mlir_tablegen(LinalgOpsAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(LinalgOpsAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRLinalgOpsAttributesIncGen) +add_dependencies(mlir-headers MLIRLinalgOpsAttributesIncGen) + add_mlir_doc(LinalgDoc LinalgOps Dialects/ -gen-op-doc) add_dependencies(LinalgOpsDocGen LinalgOdsGen) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -104,6 +104,19 @@ #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc" +//===----------------------------------------------------------------------===// +// Linalg Enums +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc" + +//===----------------------------------------------------------------------===// +// Linalg Attributes +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc" + //===----------------------------------------------------------------------===// // Linalg Interfaces //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -13,6 +13,7 @@ #ifndef LINALG_BASE #define LINALG_BASE +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" def Linalg_Dialect : Dialect { @@ -57,4 +58,15 @@ }]; } +// Define a TypeFn enum matching the OpDSL TypeFn class. 
+def TypeFn : I32EnumAttr<"TypeFn", "", [ + I32EnumAttrCase<"cast", 0>, + I32EnumAttrCase<"cast_unsigned", 1> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::linalg"; +} + +def TypeFnAttr : EnumAttr; + #endif // LINALG_BASE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -28,6 +28,10 @@ usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> + - !LinalgOperandDefConfig + name: cast + usage: TypeFnAttr + default_fn: cast indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> @@ -52,18 +56,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + attr_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + attr_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matmul_unsigned @@ -116,18 +120,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast_unsigned - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast_unsigned --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_matmul @@ -194,36 +198,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: AZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: BZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: mmt4d @@ -287,18 +291,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: AccumType operands: - !ScalarExpression scalar_arg: lhs + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: AccumType operands: - !ScalarExpression scalar_arg: rhs + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matmul @@ -352,18 +356,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_batch_matmul @@ -431,36 +435,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: AZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: BZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matvec @@ -512,18 +516,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U 
operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: y + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: vecmat @@ -575,18 +579,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: y + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matvec @@ -639,18 +643,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: dot @@ -701,18 +705,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d @@ -764,18 +768,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d @@ -829,18 +833,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d @@ -897,18 +901,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d_nwc_wcf @@ -977,18 +981,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf @@ -1071,18 +1075,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf_q @@ -1182,36 +1186,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nchw_fchw @@ -1294,18 +1298,18 @@ 
operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d_ndhwc_dhwcf @@ -1390,18 +1394,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_1d_nwc_wc @@ -1469,18 +1473,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwc @@ -1558,18 +1562,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwc_q @@ -1662,36 +1666,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwcm @@ -1770,18 +1774,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwcm_q @@ -1876,36 +1880,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_sum @@ -1978,11 +1982,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_max @@ -2055,11 +2059,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_max_unsigned @@ -2132,11 +2136,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: U 
operands: - !ScalarExpression scalar_arg: I + fn_name: cast_unsigned --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nchw_max @@ -2209,11 +2213,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_min @@ -2286,11 +2290,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_min_unsigned @@ -2363,11 +2367,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast_unsigned --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_sum @@ -2446,11 +2450,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_max @@ -2529,11 +2533,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_min @@ -2612,11 +2616,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_tensor @@ -2648,11 +2652,11 @@ arg: O value: !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: value + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_rng_2d @@ -2700,7 +2704,6 @@ arg: O value: !ScalarExpression type_fn: - fn_name: cast type_var: T operands: - !ScalarExpression @@ -2717,14 +2720,13 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: F64 operands: - !ScalarExpression scalar_const: '2147483647 : i64' + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: F64 operands: - !ScalarExpression @@ -2741,11 +2743,11 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_index: 1 + fn_name: cast - !ScalarExpression arith_fn: fn_name: add @@ -2760,41 +2762,42 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_index: 0 + fn_name: cast - !ScalarExpression scalar_arg: seed - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '1103515245 : i64' + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '1103515245 : i64' + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' + fn_name: cast + fn_name: cast - !ScalarExpression arith_fn: fn_name: mul @@ -2809,13 +2812,14 @@ scalar_arg: min - !ScalarExpression type_fn: - fn_name: cast type_var: F64 operands: - !ScalarExpression scalar_const: '2.3283063999999999E-10 : f64' + fn_name: cast - !ScalarExpression scalar_arg: min + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: soft_plus_2d @@ -2857,19 +2861,19 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression 
scalar_const: '1.000000e+00 : f64' + fn_name: cast - !ScalarExpression arith_fn: fn_name: exp operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -37,7 +37,7 @@ LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, + return llvm::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } }]; diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -8,6 +8,8 @@ DEPENDS MLIRLinalgInterfacesIncGen + MLIRLinalgOpsAttributesIncGen + MLIRLinalgOpsEnumsIncGen MLIRLinalgOpsIncGen MLIRLinalgStructuredOpsIncGen diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -95,6 +96,10 @@ } void mlir::linalg::LinalgDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc" + >(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" @@ -144,3 +149,10 @@ return op->emitError() << "attribute '" << attr.getName() << "' not supported by the linalg dialect"; } + +#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc" + +#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc" diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -35,8 +35,6 @@ using namespace mlir; using namespace mlir::linalg; -#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc" - /// Forward declarations. /// Generic entry point to create the block for the region of a LinalgOp. @@ -231,14 +229,14 @@ return operand; } - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value typefn__cast(Type toType, Value operand) { - return cast(toType, operand, false); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value typefn__cast_unsigned(Type toType, Value operand) { - return cast(toType, operand, true); + Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { + switch (typeFn) { + case TypeFn::cast: + return cast(toType, operand, false); + case TypeFn::cast_unsigned: + return cast(toType, operand, true); + } + llvm_unreachable("unsupported type conversion function"); } // NOLINTNEXTLINE(*-identifier-naming): externally called. 
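With the dialect and helper changes above, the casting behavior is carried by a real `TypeFn` enum attribute registered on the Linalg dialect, and `RegionBuilderHelper::buildTypeFn` switches on that enum instead of exposing one `typefn__*` entry point per conversion. In the textual IR the attribute is spelled `#linalg<"type_fn cast">`, which is the form the OpDSL emitter and the tests later in this patch rely on. A minimal sketch, not part of the patch, of round-tripping that spelling through the Python bindings (assuming an in-tree build of the MLIR Python packages):

```python
# Sketch only; assumes the MLIR Python bindings are built and importable.
from mlir.ir import Attribute, Context
from mlir.dialects import linalg  # makes the Linalg dialect available

with Context():
  # Same spelling the OpDSL emitter below generates for the new attribute.
  cast = Attribute.parse('#linalg<"type_fn cast_unsigned">')
  print(cast)  # expected to round-trip as #linalg<"type_fn cast_unsigned">
```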
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -129,7 +129,8 @@ return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) def __repr__(self): - return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" + return (f"{self.operand_def.name}" + f"[{', '.join([repr(i) for i in self.indices])}]") class TensorArithFn(TensorExpression): @@ -156,14 +157,24 @@ class TensorTypeFn(TensorExpression): """Application of a type conversion function.""" - def __init__(self, type_fn: "TypeFn", type_var: TypeVar, + def __init__(self, + type_fn: Optional["TypeFn"], + operand_def: Optional["OperandDef"], + type_var: TypeVar, arg: TensorExpression): + if bool(type_fn) + bool(operand_def) != 1: + raise ValueError("Either 'type_fn' or 'operand_def' must be specified") self.type_fn = type_fn + self.operand_def = operand_def self.type_var = type_var self.arg = arg def to_scalar_expression(self) -> ScalarExpression: - return ScalarTypeFn(self.type_fn.fn_name, self.type_var, + if self.operand_def: + assert self.operand_def.name, "TypeFnAttr not attached" + fn_name = self.type_fn.fn_name if self.type_fn else None + attr_name = self.operand_def.name if self.operand_def else None + return ScalarTypeFn(fn_name, attr_name, self.type_var, self.arg.to_scalar_expression()).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): @@ -171,7 +182,8 @@ self.arg.visit_tensor_exprs(callback) def __repr__(self): - return f"{repr(self.type_fn)}({self.type_var}, {self.arg})" + return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]" + f"({self.type_var}, {self.arg})") class TensorReduceFn(TensorExpression): @@ -260,7 +272,7 @@ self.fn_name = fn_name def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType": - return TensorTypeFn(self, type_var, arg) + return TensorTypeFn(self, None, type_var, arg) def __repr__(self): return f"{self.fn_name}" @@ -374,6 +386,7 @@ Scalar = 1 OutputTensor = 2 IndexAttr = 3 + TypeFnAttr = 4 class OperandDef: @@ -388,7 +401,8 @@ type_var: Optional[TypeVar] = None, size_exprs: Optional[Sequence[AffineExprDef]] = None, index_dims: Optional[Sequence[DimDef]] = None, - default_vals: Optional[Sequence[int]] = None): + default_vals: Optional[Sequence[int]] = None, + default_fn: Optional[str] = None): if type_var and not isinstance(type_var, TypeVar): raise ValueError( f"OperandDef requires a TypeVar but got {repr(type_var)}") @@ -397,6 +411,7 @@ self.size_exprs = size_exprs self.index_dims = index_dims self.default_vals = default_vals + self.default_fn = default_fn self.kind = kind self.name = None # type: Optional[str] self.registered_index = -1 # type: int @@ -408,13 +423,26 @@ self.name = name self.owner = owner + def is_input(self) -> bool: + return (self.kind == OperandKind.Scalar or + self.kind == OperandKind.InputTensor) + + def is_tensor(self) -> bool: + return (self.kind == OperandKind.InputTensor or + self.kind == OperandKind.OutputTensor) + + def is_attribute(self) -> bool: + return (self.kind == OperandKind.IndexAttr or + self.kind == OperandKind.TypeFnAttr) + def __hash__(self): return hash(id(self)) def __repr__(self): return (f"{self.name}:OperandDef(kind={self.kind.name}, " - f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), " - f"index_dims={self.index_dims}, 
default_vals={self.default_vals})") + f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, " + f"index_dims={self.index_dims}, default_vals={self.default_vals}, " + f"default_fn={self.default_fn})") class TensorDef: @@ -520,6 +548,25 @@ OperandKind.IndexAttr, size_exprs=sizes, default_vals=default) +class TypeFnAttrDef: + """Type conversion function attribute definition. + + Type conversion function attributes provide a way to make type conversions + parametrizable. Every attribute specifies a default type conversion function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "TypeFnType"): + if not isinstance(default, TypeFnType): + raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType " + f"but got {default}") + self.operand_def = OperandDef( + OperandKind.TypeFnAttr, default_fn=default.fn_name) + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn: + return TensorTypeFn(None, self.operand_def, type_var, arg) + + ############################################################################### # Operation definition. ############################################################################### @@ -617,15 +664,12 @@ f"to {self.registered_operands['name']}") # Ensure output tensors are registered after input tensors and scalars and # attributes are registered after all other operand types. - registered_kinds = [ - operand.kind.value for operand in self.registered_operands.values() - ] - if registered_kinds: - maximum = max(registered_kinds) - if maximum > operand.kind.value and maximum > OperandKind.Scalar.value: - raise ValueError( - f"The operand {name} of kind {operand.kind.name} is registered " - f"after an operand of kind {OperandKind(maximum).name}") + if operand.is_input() and any( + not op_def.is_input() for op_def in self.registered_operands.values()): + raise ValueError(f"Input {name} registered after an output or attribute") + if operand.kind == OperandKind.OutputTensor and any( + op_def.is_attribute() for op_def in self.registered_operands.values()): + raise ValueError(f"Output {name} registered after an attribute") operand.attach(len(self.registered_operands), name, self) self.registered_operands[name] = operand diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -63,6 +63,8 @@ def usage(self) -> str: if self.operand_def.kind == OperandKind.IndexAttr: return "IndexAttr" + if self.operand_def.kind == OperandKind.TypeFnAttr: + return "TypeFnAttr" if self.operand_def.kind == OperandKind.OutputTensor: return "Output" return "Input" @@ -77,6 +79,8 @@ self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map) if self.operand_def.default_vals: self_dict["default_vals"] = self.operand_def.default_vals + if self.operand_def.default_fn: + self_dict["default_fn"] = self.operand_def.default_fn return self_dict def __repr__(self): @@ -166,7 +170,7 @@ # Collect all attribute definitions. collected_attr_defs = list() for operand in registered_operands: - if operand.kind == OperandKind.IndexAttr: + if operand.is_attribute(): collected_attr_defs.append(operand) # Collect all tensors with manual indexing annotation. @@ -249,9 +253,10 @@ # Check all registered tensor and scalar operands have an indexing map. 
for operand in registered_operands: - if operand.kind == OperandKind.IndexAttr: + if operand.is_attribute(): continue - if not (operand in self.operands and self.operands[operand].indexing_map): + if not (operand in self.operands and + self.operands[operand].indexing_map): raise ValueError(f"Failed to compute an indexing map for operand " f"{operand.name}") @@ -311,7 +316,8 @@ def add_operand(self, operand_def: OperandDef): if operand_def in self.operands: return - if operand_def.kind == OperandKind.Scalar: + if (operand_def.kind == OperandKind.Scalar or + operand_def.kind == OperandKind.TypeFnAttr): self.operands[operand_def] = OperandDefConfig(operand_def) return with self.context: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -129,7 +129,8 @@ sig = inspect.signature(dsl_func) for param_name, param in sig.parameters.items(): param_default = param.default - if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)): + if isinstance(param_default, + (TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)): op_def.add_operand(param_name, param_default.operand_def) else: raise ValueError( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -37,11 +37,12 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, - **attrs: Sequence[int]): + **attrs: Union[Sequence[int], TypeFnType]): all_arg_defs = op_config.ordered_operands in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"] out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"] index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"] + type_fn_attr_arg_defs = [d for d in all_arg_defs if d.usage == "TypeFnAttr"] # Verify outs is a sequence or a list of results. if not isinstance(outs, (Sequence, OpResultList)): @@ -56,7 +57,7 @@ raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " f"{len(outs)} for {op_config}") - # Compute a replacement list for all attribute symbols. + # Compute a replacement list for all index attribute symbols. expressions = [] # type: Sequence[AffineExpr] replacements = [] # type: Sequence[AffineExpr] for index_attr in index_attr_arg_defs: @@ -125,15 +126,29 @@ array = np.array(index_attr_vals, dtype=np.int64) index_attrs[index_attr.name] = DenseElementsAttr.get(array) + # Compute the type function attribute mapping.
+ type_fn_attr_mapping = {} + for type_fn_attr in type_fn_attr_arg_defs: + attr_val = type_fn_attr.operand_def.default_fn + if type_fn_attr.name in attrs: + type_fn = attrs.get(type_fn_attr.name) + if not isinstance(type_fn, TypeFnType): + raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type " + f"TypeFnType but got {type(attr_val)}") + attr_val = type_fn.fn_name + assert attr_val, "Type function attribute has no value" + type_fn_attr_mapping[type_fn_attr.name] = attr_val + return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, - type_mapping, indexing_maps_attr, iterator_types_attr, - index_attrs, block_arg_types) + type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs, + type_fn_attr_mapping, block_arg_types) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \ + indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # An operation that accesses only scalars and scalar/rank zero tensors is @@ -149,8 +164,7 @@ for arg_def in all_arg_defs: if arg_def.operand_def.kind == OperandKind.Scalar: indexing_maps.append(scalar_map) - if (arg_def.operand_def.kind == OperandKind.InputTensor or - arg_def.operand_def.kind == OperandKind.OutputTensor): + if arg_def.operand_def.is_tensor(): indexing_maps.append(tensor_map) indexing_maps_attr = ArrayAttr.get( [AffineMapAttr.get(am) for am in indexing_maps]) @@ -169,7 +183,8 @@ block = generic_op.regions[0].blocks.append(*block_arg_types) block_arg_mapping = dict(zip(block_arg_names, block.arguments)) with InsertionPoint(block): - body_builder = _BodyBuilder(type_mapping, block_arg_mapping) + body_builder = _BodyBuilder(type_mapping, block_arg_mapping, + type_fn_attr_mapping) for assignment in op_config.assignments: body_builder.assign(assignment) body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) @@ -184,7 +199,8 @@ op_class_name: str, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \ + indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # If we get here, there must exist a builtin class `op_class_name`. @@ -200,6 +216,11 @@ for name, value in index_attrs.items(): named_op.operation.attributes[name] = value + # Set the type function attributes. 
+ for name, value in type_fn_attr_mapping.items(): + named_op.operation.attributes[name] = Attribute.parse( + f"#linalg<\"type_fn {value}\">") + linalg.fill_builtin_region(named_op.operation) if len(result_types) == 1: @@ -212,9 +233,11 @@ """Constructs a structured op body by evaluating assignments.""" def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value]): + block_arg_mapping: Dict[str, Value], + type_fn_attr_mapping: Dict[str, str]): self.type_mapping = type_mapping self.block_arg_mapping = block_arg_mapping + self.type_fn_attr_mapping = type_fn_attr_mapping self.yield_mapping = dict() # type: Dict[str, Value] def assign(self, assignment: ScalarAssign): @@ -245,7 +268,10 @@ ] return fn(*operand_values) elif expr.type_fn: - fn = self._get_function(f"_typefn_{expr.type_fn.fn_name}") + fn_name = expr.type_fn.fn_name + if not fn_name: + fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name] + fn = self._get_function(f"_typefn_{fn_name}") operand = self.expression(expr.type_fn.operand) return fn(expr.type_fn.type_var.name, operand) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -46,9 +46,10 @@ class ScalarTypeFn: """A type of ScalarExpression that applies a type conversion function.""" - def __init__(self, fn_name: str, type_var: TypeVar, - operand: "ScalarExpression"): + def __init__(self, fn_name: Optional[str], attr_name: Optional[str], + type_var: TypeVar, operand: "ScalarExpression"): self.fn_name = fn_name + self.attr_name = attr_name self.type_var = type_var self.operand = operand @@ -56,7 +57,8 @@ return ScalarExpression(type_fn=self) def __repr__(self): - return f"ScalarTypeFn<{self.fn_name}>({self.type_var}, {self.operand})" + return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>" + f"({self.type_var}, {self.operand})") class ScalarArg: @@ -138,12 +140,15 @@ # Note that even though operands must be arity 1, we write it the # same way as for apply because it allows handling code to be more # generic vs having a special form. - return dict( - type_fn=dict( - fn_name=self.type_fn.fn_name, - type_var=self.type_fn.type_var.name, - operands=[self.type_fn.operand], - )) + type_fn_dict = dict( + type_var=self.type_fn.type_var.name, + operands=[self.type_fn.operand], + ) + if self.type_fn.fn_name: + type_fn_dict["fn_name"] = self.type_fn.fn_name + if self.type_fn.attr_name: + type_fn_dict["attr_name"] = self.type_fn.attr_name + return dict(type_fn=type_fn_dict) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) elif self.scalar_const: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -10,7 +10,8 @@ def matmul( A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast)): """Performs a matrix multiplication of two 2D inputs. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -18,7 +19,7 @@ """ domain(D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) @linalg_structured_op diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -21,7 +21,8 @@ // Verifies that different argument types is legal. func @generalize_matmul_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>) + %0 = linalg.matmul {cast = #linalg<"type_fn cast">} + ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>) outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> return %0: tensor<16x32xi32> } @@ -36,6 +37,20 @@ // CHECK-NEXT: linalg.yield %[[ADD]] : i32 // CHECK-NEXT: -> tensor<16x32xi32> + +// ----- + +// Verifies the cast attribute controls the cast operations used. +func @generalize_matmul_tensor_i16i64i32_unsigned(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { + %0 = linalg.matmul {cast = #linalg<"type_fn cast_unsigned">} + ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>) + outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> + return %0: tensor<16x32xi32> +} + +// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32_unsigned +// CHECK: = arith.extui + // ----- func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -2,7 +2,8 @@ # RUN: mlir-linalg-ods-yaml-gen %s --o-impl=- | FileCheck %s --check-prefix=IMPL # @linalg_structured_op -# def test1(O=TensorDef(T, S.M, S.N, output=True)): +# def test1(O=TensorDef(T, S.M, S.N, output=True), +# cast=TypeFnAttrDef(default=TypeFn.cast)): # """Title. # Detailed description. 
@@ -24,6 +25,10 @@ usage: Output type_var: T shape_map: affine_map<()[s0, s1] -> (s0, s1)> + - !LinalgOperandDefConfig + name: cast + usage: TypeFnAttr + default_fn: cast indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> @@ -39,18 +44,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: T operands: - !ScalarExpression scalar_const: '42 : i64' + attr_name: cast - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: T operands: - !ScalarExpression scalar_index: 1 + attr_name: cast # ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1" @@ -61,16 +66,22 @@ # ODS: let arguments = # ODS-NEXT: Variadic:$inputs, -# ODS-NEXT: Variadic:$outputs +# ODS-NEXT: Variadic:$outputs, +# ODS-NEXT: DefaultValuedAttr:$cast # ODS: let builders = # ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, # ODS-NEXT: "ValueRange":$outputs, # ODS-NEXT: CArg<"ArrayRef", "{}">:$attributes), +# ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, +# ODS-NEXT: "ValueRange":$outputs, "Attribute":$cast, +# ODS-NEXT: CArg<"ArrayRef", "{}">:$attributes), + # ODS: $_state.addOperands(inputs); # ODS-NEXT: $_state.addOperands(outputs); # ODS-NEXT: $_state.addTypes(resultTensorTypes); +# ODS-NEXT: $_state.addAttribute("cast", cast) # ODS-NEXT: $_state.addAttributes(attributes); # ODS-NEXT: $_state.addAttribute( # ODS-NEXT: "operand_segment_sizes", @@ -85,10 +96,18 @@ # IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, # IMPL-NEXT: Block &block, ArrayRef attrs) +# IMPL: TypeFn castVal = TypeFn::cast; +# IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) { +# IMPL-NEXT: return attr.getName() == "cast"; }); +# IMPL-NEXT: if (castIter != attrs.end()) { +# IMPL-NEXT: if (auto attr = castIter->getValue().dyn_cast()) +# IMPL-NEXT: castVal = attr.getValue(); +# IMPL-NEXT: } + # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); -# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]); +# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); -# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]); +# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]); # IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]); diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py @@ -24,19 +24,10 @@ def matmul_poly( A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast)): domain(D.m, D.n, D.k) - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) - - -@linalg_structured_op -def matmul_unsigned_poly( - A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - domain(D.m, D.n, D.k) - C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( - U, B[D.k, D.n]) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) with Context() as ctx, Location.unknown(): @@ -92,7 +83,8 @@ RankedTensorType.get((4, 16), i8), 
RankedTensorType.get((16, 8), i8), RankedTensorType.get((4, 8), i32)) def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result): - return matmul_unsigned_poly(lhs, rhs, outs=[init_result]) + return matmul_poly( + lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned) # CHECK-LABEL: @test_i8i16i32_matmul # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32) @@ -143,7 +135,8 @@ RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8), RankedTensorType.get((4, 8), f32)) def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result): - return matmul_unsigned_poly(lhs, rhs, outs=[init_result]) + return matmul_poly( + lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned) # CHECK-LABEL: @test_f16f16f32_matmul # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32) diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py @@ -16,12 +16,13 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw]( - TypeFn.cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c])) @linalg_structured_op @@ -29,12 +30,13 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_unsigned), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( - TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c])) @linalg_structured_op @@ -42,12 +44,13 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw]( - TypeFn.cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c])) @linalg_structured_op @@ -55,12 +58,13 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_unsigned), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( - TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c])) with Context() as ctx, Location.unknown(): @@ -150,5 +154,20 @@ return pooling_min_poly( 
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + # CHECK-LABEL: @test_f32i32_min_pooling_cast_attr + # CHECK: = arith.fptoui + # CHECK: = arith.minsi + @builtin.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2), f32), + RankedTensorType.get((1, 2, 4, 1), i32)) + def test_f32i32_min_pooling_cast_attr(input, shape, init_result): + return pooling_min_poly( + input, + shape, + outs=[init_result], + cast=TypeFn.cast_unsigned, + strides=[2, 4], + dilations=[1, 2]) print(module) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -6,6 +6,7 @@ from mlir.dialects import std from mlir.dialects import arith +from mlir.dialects.linalg.opdsl.lang import * def run(f): print("\nTEST:", f.__name__) @@ -98,12 +99,14 @@ init_result = linalg.InitTensorOp([4, 8], f32) # First check the named form with custom format # CHECK: linalg.matmul + # CHECK: cast = #linalg<"type_fn cast_unsigned"> # CHECK-NOT: linalg.memoized_indexing_maps # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>) # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>) # CHECK-SAME: -> tensor<4x8xf32> # CHECK-NEXT: return - return linalg.matmul(lhs, rhs, outs=[init_result.result]) + return linalg.matmul( + lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned) print(module) diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -9,6 +9,7 @@ from mlir.passmanager import * from mlir.execution_engine import * +from mlir.dialects.linalg.opdsl.lang import * # Log everything to stderr and flush so that we have a unified stream to match # errors/info emitted by MLIR to stderr. @@ -20,21 +21,28 @@ matmul_boiler = """ func @main() -> f32 attributes {llvm.emit_c_interface} { %v0 = arith.constant 0.0 : f32 - %v1 = arith.constant 1.0 : f32 + %v1 = arith.constant -1 : i8 %v2 = arith.constant 2.0 : f32 - %A = memref.alloc() : memref<4x16xf32> + %A = memref.alloc() : memref<4x16xi8> %B = memref.alloc() : memref<16x8xf32> - %C = memref.alloc() : memref<4x8xf32> - linalg.fill(%v1, %A) : f32, memref<4x16xf32> + %C0 = memref.alloc() : memref<4x8xf32> + %C1 = memref.alloc() : memref<4x8xf32> + linalg.fill(%v1, %A) : i8, memref<4x16xi8> linalg.fill(%v2, %B) : f32, memref<16x8xf32> - linalg.fill(%v0, %C) : f32, memref<4x8xf32> + linalg.fill(%v0, %C0) : f32, memref<4x8xf32> + linalg.fill(%v0, %C1) : f32, memref<4x8xf32> - call @matmul_on_buffers(%A, %B, %C) : - (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> () + call @matmul_signed_on_buffers(%A, %B, %C0) : + (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> () + call @matmul_unsigned_on_buffers(%A, %B, %C1) : + (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> () %c0 = arith.constant 0 : index - %0 = memref.load %C[%c0, %c0] : memref<4x8xf32> + %res0 = memref.load %C0[%c0, %c0] : memref<4x8xf32> + %res1 = memref.load %C1[%c0, %c0] : memref<4x8xf32> + + %0 = arith.addf %res0, %res1 : f32 // TODO: FFI-based solution to allow testing and printing with python code. 
return %0 : f32 @@ -168,14 +176,21 @@ with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() + i8 = IntegerType.get_signless(8) with InsertionPoint(module.body): @builtin.FuncOp.from_py_func( - MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32), + MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), MemRefType.get((4, 8), f32)) - def matmul_on_buffers(lhs, rhs, out): + def matmul_signed_on_buffers(lhs, rhs, out): linalg.matmul(lhs, rhs, outs=[out]) + @builtin.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32)) + def matmul_unsigned_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned) + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) # TODO: FFI-based solution to allow testing and printing with python code. @@ -186,7 +201,9 @@ execution_engine.invoke("main", res) log("RESULT: ", res[0]) - # CHECK: RESULT: 32.0 + # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32 + # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160 + # CHECK: RESULT: 8128 test_matmul_builtin() @@ -196,14 +213,22 @@ with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() + i8 = IntegerType.get_signless(8) with InsertionPoint(module.body): @builtin.FuncOp.from_py_func( - MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32), + MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), MemRefType.get((4, 8), f32)) - def matmul_on_buffers(lhs, rhs, out): + def matmul_signed_on_buffers(lhs, rhs, out): linalg.matmul(lhs, rhs, outs=[out], emit_generic=True) + @builtin.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32)) + def matmul_unsigned_on_buffers(lhs, rhs, out): + linalg.matmul( + lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True) + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) # TODO: FFI-based solution to allow testing and printing with python code. @@ -214,7 +239,9 @@ execution_engine.invoke("main", res) log("RESULT: ", res[0]) - # CHECK: RESULT: 32.0 + # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32 + # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160 + # CHECK: RESULT: 8128 test_matmul_generic() diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -61,7 +61,7 @@ AffineMap affineMap() { return affineMapAttr.getValue(); } }; -enum class LinalgOperandDefUsage { Input, Output, IndexAttr }; +enum class LinalgOperandDefUsage { Input, Output, IndexAttr, TypeFnAttr }; struct LinalgOperandDef { std::string name; @@ -70,6 +70,7 @@ Optional shapeMap; Optional indexAttrMap; Optional> defaultVals; + Optional defaultFn; }; enum class LinalgIteratorTypeDef { @@ -91,11 +92,12 @@ }; struct ScalarTypeFn { - std::string fnName; std::string typeVar; // NOTE: This must be of arity 1, but to break the self-referential cycle, // we use a heap allocated vector. std::vector operands; + Optional fnName; + Optional attrName; }; struct ScalarExpression { @@ -182,6 +184,8 @@ /// arguments have an `index_attr_map`. /// - `default_vals`: An optional default initialization for index attribute /// arguments. +/// - `default_fn`: An optional default initialization for function attribute +/// arguments. 
template <> struct MappingTraits { static void mapping(IO &io, LinalgOperandDef &info) { @@ -191,6 +195,7 @@ io.mapOptional("shape_map", info.shapeMap); io.mapOptional("index_attr_map", info.indexAttrMap); io.mapOptional("default_vals", info.defaultVals); + io.mapOptional("default_fn", info.defaultFn); } }; @@ -201,6 +206,7 @@ io.enumCase(value, "Input", LinalgOperandDefUsage::Input); io.enumCase(value, "Output", LinalgOperandDefUsage::Output); io.enumCase(value, "IndexAttr", LinalgOperandDefUsage::IndexAttr); + io.enumCase(value, "TypeFnAttr", LinalgOperandDefUsage::TypeFnAttr); } }; @@ -281,9 +287,10 @@ template <> struct MappingTraits { static void mapping(IO &io, ScalarTypeFn &info) { - io.mapRequired("fn_name", info.fnName); io.mapRequired("type_var", info.typeVar); io.mapRequired("operands", info.operands); + io.mapOptional("fn_name", info.fnName); + io.mapOptional("attr_name", info.attrName); } }; @@ -552,6 +559,8 @@ $_state.addOperands(inputs); $_state.addOperands(outputs); $_state.addTypes(resultTensorTypes); + {2} + $_state.addAttributes(attributes); $_state.addAttribute( "operand_segment_sizes", $_builder.getI32VectorAttr({{ @@ -562,8 +571,6 @@ $_state, TypeRange(inputs), TypeRange(outputs)); - {2} - $_state.addAttributes(attributes); }]> )FMT"; @@ -681,42 +688,56 @@ interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); - // Assemble the attribute specific logic required for the op definition. if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { - return arg.usage == LinalgOperandDefUsage::IndexAttr; + return arg.usage == LinalgOperandDefUsage::IndexAttr || + arg.usage == LinalgOperandDefUsage::TypeFnAttr; })) { SmallVector attrDefs; SmallVector attrParams; SmallVector attrStmts; for (LinalgOperandDef &arg : opConfig.structuredOp->args) { - if (arg.usage != LinalgOperandDefUsage::IndexAttr) - continue; - assert(arg.indexAttrMap.hasValue()); - assert(arg.defaultVals.hasValue()); - size_t size = arg.indexAttrMap->affineMap().getNumResults(); - assert(arg.defaultVals.getValue().size() == size); - static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>"; - static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}"; static const char paramFmt[] = "\"Attribute\":${0}"; static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; - std::string defaultVals; - llvm::raw_string_ostream ss(defaultVals); - ss << "{ "; - llvm::interleave( - arg.defaultVals.getValue(), ss, - [&](int64_t val) { ss << "static_cast(" << val << ")"; }, - ", "); - ss << " }"; - attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size), - ss.str(), arg.name)); - attrParams.push_back(llvm::formatv(paramFmt, arg.name)); - attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); + // Add the type conversion attributes to the op definition and builders. + if (arg.usage == LinalgOperandDefUsage::TypeFnAttr) { + assert(arg.defaultFn.hasValue()); + static const char typeFmt[] = "TypeFn::{0}"; + static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}"; + attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr", + llvm::formatv(typeFmt, arg.defaultFn), + arg.name)); + attrParams.push_back(llvm::formatv(paramFmt, arg.name)); + attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); + } + // Add the index attributes to the op definition and builders. 
+ if (arg.usage == LinalgOperandDefUsage::IndexAttr) { + assert(arg.indexAttrMap.hasValue()); + assert(arg.defaultVals.hasValue()); + size_t size = arg.indexAttrMap->affineMap().getNumResults(); + assert(arg.defaultVals.getValue().size() == size); + static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>"; + static const char defFmt[] = "DefaultValuedAttr<{0}, \"{ {1} }\">:${2}"; + std::string defaultVals; + llvm::raw_string_ostream ss(defaultVals); + llvm::interleave( + arg.defaultVals.getValue(), ss, + [&](int64_t val) { ss << "static_cast(" << val << ")"; }, + ", "); + attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size), + ss.str(), arg.name)); + attrParams.push_back(llvm::formatv(paramFmt, arg.name)); + attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); + } + } + if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { + return arg.usage == LinalgOperandDefUsage::IndexAttr; + })) { + attrMethods = R"( + bool hasDynamicIndexingMaps(); + LogicalResult verifyIndexingMapRequiredAttributes(); + )"; } attrList = ",\n" + llvm::join(attrDefs, ",\n"); - attrMethods = R"( - bool hasDynamicIndexingMaps(); - LogicalResult verifyIndexingMapRequiredAttributes(); - )"; attrBuilder = llvm::formatv( structuredOpBuilderFormat, opConfig.metadata->cppClassName, llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n")); @@ -746,7 +767,8 @@ // Compute the number of scalar and tensor arguments. int64_t numOfArgs = llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { - return arg.usage != LinalgOperandDefUsage::IndexAttr; + return arg.usage == LinalgOperandDefUsage::Input || + arg.usage == LinalgOperandDefUsage::Output; }); // An operation that accesses only scalars and scalar/rank zero tensors is @@ -952,7 +974,8 @@ // Generates a regionBuilder method. Parameters. // {0}: Class name // {1}: Number of args - // {2}: Statements + // {2}: Attributes + // {3}: Statements static const char structuredOpRegionBuilderFormat[] = R"FMT( void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs) {{ @@ -961,6 +984,7 @@ RegionBuilderHelper helper(block.getArgument(0).getContext(), block); SmallVector yields; {2} + {3} helper.yieldOutputs(yields); } )FMT"; @@ -968,7 +992,25 @@ auto &assignments = opConfig.structuredOp->assignments; size_t generatedAssignmentCount = 0; int localCounter = 0; + SmallVector attrs; SmallVector stmts; + for (LinalgOperandDef &arg : args) { + if (arg.usage != LinalgOperandDefUsage::TypeFnAttr) + continue; + // Obtain the type function attribute values. Parameters. + // {0}: attribute name + // {1}: default type function name + static const char attrDef[] = R"FMT( +TypeFn {0}Val = TypeFn::{1}; +auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ + return attr.getName() == "{0}"; }); +if ({0}Iter != attrs.end()) {{ + if (auto attr = {0}Iter->getValue().dyn_cast()) + {0}Val = attr.getValue(); +} +)FMT"; + attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn)); + } for (LinalgOperandDef &arg : args) { if (arg.usage != LinalgOperandDefUsage::Output) continue; @@ -1048,11 +1090,25 @@ << "an argument type but it does not"; return None; } + + // Use the function name or the attribute to build the type function. 
+ std::string typeFunc = llvm::formatv( + "TypeFn::{0}", expression.typeFn->fnName.getValueOr("")); + if (expression.typeFn->attrName) { + if (llvm::none_of(args, [&](LinalgOperandDef &arg) { + return arg.usage == LinalgOperandDefUsage::TypeFnAttr && + arg.name == expression.typeFn->attrName.getValue(); + })) { + emitError(genContext.getLoc()) + << "missing type function attribute " + << expression.typeFn->attrName.getValue(); + } + typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName); + } std::string cppIdent = llvm::formatv("value{0}", ++localCounter); - stmts.push_back( - llvm::formatv("Value {0} = helper.typefn__{1}({2}, {3});", - cppIdent, expression.typeFn->fnName, - typeCppValue.getValue(), *operandCppValue)); + stmts.push_back(llvm::formatv( + "Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent, + typeFunc, typeCppValue.getValue(), *operandCppValue)); return cppIdent; } emitError(genContext.getLoc()) << "unknown ScalarExpression type"; @@ -1069,6 +1125,7 @@ << "mismatched number of assignments vs output arguments"; os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs, + interleaveToString(attrs, "\n "), interleaveToString(stmts, "\n ")); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6695,6 +6695,22 @@ ], "include/mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc", ), + ( + ["-gen-enum-decls"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc", + ), + ( + ["-gen-attrdef-decls"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc", + ), + ( + ["-gen-attrdef-defs"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc", + ), ], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/Linalg/IR/LinalgOps.td",
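Taken together, the OpDSL side of this patch lets an op definition expose its casting behavior as a `TypeFnAttrDef` and lets callers override the default per instantiation. The sketch below is modeled on the updated `emit_matmul.py` test above and is not part of the patch itself; it assumes the MLIR Python bindings and OpDSL are importable, and `matmul_poly` / `test_matmul_unsigned` are illustrative names:

```python
from mlir.dialects import builtin, linalg
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import (Context, InsertionPoint, IntegerType, Location, Module,
                     RankedTensorType)


@linalg_structured_op
def matmul_poly(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  # The cast attribute defaults to TypeFn.cast and can be overridden per call.
  domain(D.m, D.n, D.k)
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])


with Context(), Location.unknown():
  module = Module.create()
  i8 = IntegerType.get_signless(8)
  i32 = IntegerType.get_signless(32)
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(
        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
        RankedTensorType.get((4, 8), i32))
    def test_matmul_unsigned(lhs, rhs, init_result):
      # Overrides the default signed cast; the emitted body should use
      # arith.extui for the i8 -> i32 conversion.
      return matmul_poly(lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)

  print(module)
```

Because `linalg.matmul_poly` is not a registered named op, the call emits a `linalg.generic`; the named-op path with the same attribute is what `ops.py` and `opsrun.py` above exercise via `linalg.matmul(..., cast=TypeFn.cast_unsigned)`.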