diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -44,6 +44,18 @@ add_mlir_dialect(LinalgOps linalg) +set(LLVM_TARGET_DEFINITIONS LinalgOps.td) +mlir_tablegen(LinalgOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(LinalgOpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRLinalgOpsEnumsIncGen) +add_dependencies(mlir-headers MLIRLinalgOpsEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS LinalgOps.td) +mlir_tablegen(LinalgOpsAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(LinalgOpsAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRLinalgOpsAttributesIncGen) +add_dependencies(mlir-headers MLIRLinalgOpsAttributesIncGen) + add_mlir_doc(LinalgDoc LinalgOps Dialects/ -gen-op-doc) add_dependencies(LinalgOpsDocGen LinalgOdsGen) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -104,6 +104,19 @@ #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc" +//===----------------------------------------------------------------------===// +// Linalg Enums +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc" + +//===----------------------------------------------------------------------===// +// Linalg Attributes +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc" + //===----------------------------------------------------------------------===// // Linalg Interfaces //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -13,6 +13,7 @@ #ifndef LINALG_BASE #define LINALG_BASE +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" def Linalg_Dialect : Dialect { @@ -57,4 +58,17 @@ }]; } +// Define a TypeFn enum matching the OpDSL TypeFn class. +def TypeFn : I32EnumAttr<"TypeFn", "", [ + I32EnumAttrCase<"cast", 0>, + I32EnumAttrCase<"cast_unsigned", 1> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::linalg"; +} + +def TypeFnAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + #endif // LINALG_BASE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -15,19 +15,23 @@ args: - !LinalgOperandDefConfig name: A - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> - !LinalgOperandDefConfig name: B - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> - !LinalgOperandDefConfig name: C - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> + - !LinalgOperandDefConfig + name: cast + kind: type_fn_attr + default_fn: cast indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> @@ -52,18 +56,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + attr_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + attr_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matmul_unsigned @@ -79,17 +83,17 @@ args: - !LinalgOperandDefConfig name: A - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> - !LinalgOperandDefConfig name: B - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> - !LinalgOperandDefConfig name: C - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> indexing_maps: !LinalgIndexingMapsConfig @@ -116,18 +120,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast_unsigned - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast_unsigned --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_matmul @@ -143,25 +147,25 @@ args: - !LinalgOperandDefConfig name: A - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> - !LinalgOperandDefConfig name: B - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> - !LinalgOperandDefConfig name: AZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: BZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: C - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> indexing_maps: !LinalgIndexingMapsConfig @@ -194,36 +198,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: AZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: BZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: mmt4d @@ -244,17 +248,17 @@ args: - !LinalgOperandDefConfig name: lhs - usage: Input + kind: input_tensor type_var: LhsType shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)> - !LinalgOperandDefConfig name: rhs - usage: Input + kind: input_tensor type_var: RhsType shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)> - !LinalgOperandDefConfig name: accum - usage: Output + kind: output_tensor type_var: AccumType shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)> indexing_maps: !LinalgIndexingMapsConfig @@ -287,18 +291,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: AccumType operands: - !ScalarExpression scalar_arg: lhs + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: AccumType operands: - !ScalarExpression scalar_arg: rhs + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matmul @@ -314,17 +318,17 @@ args: - !LinalgOperandDefConfig name: A - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> - !LinalgOperandDefConfig name: B - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)> - !LinalgOperandDefConfig name: C - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> indexing_maps: !LinalgIndexingMapsConfig @@ -352,18 +356,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_batch_matmul @@ -379,25 +383,25 @@ args: - !LinalgOperandDefConfig name: A - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> - !LinalgOperandDefConfig name: B - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)> - !LinalgOperandDefConfig name: AZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: BZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: C - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> indexing_maps: !LinalgIndexingMapsConfig @@ -431,36 +435,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: AZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: BZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matvec @@ -476,17 +480,17 @@ args: - !LinalgOperandDefConfig name: A - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1] -> (s0, s1)> - !LinalgOperandDefConfig name: y - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1] -> (s1)> - !LinalgOperandDefConfig name: x - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1] -> (s0)> indexing_maps: !LinalgIndexingMapsConfig @@ -512,18 +516,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: y + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: vecmat @@ -539,17 +543,17 @@ args: - !LinalgOperandDefConfig name: y - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1] -> (s0)> - !LinalgOperandDefConfig name: A - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1] -> (s0, s1)> - !LinalgOperandDefConfig name: x - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1] -> (s1)> indexing_maps: !LinalgIndexingMapsConfig @@ -575,18 +579,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: y + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matvec @@ -602,17 +606,17 @@ args: - !LinalgOperandDefConfig name: A - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)> - !LinalgOperandDefConfig name: B - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> - !LinalgOperandDefConfig name: C - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> indexing_maps: !LinalgIndexingMapsConfig @@ -639,18 +643,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: dot @@ -666,17 +670,17 @@ args: - !LinalgOperandDefConfig name: A - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0] -> (s0)> - !LinalgOperandDefConfig name: B - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0] -> (s0)> - !LinalgOperandDefConfig name: C - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0] -> ()> indexing_maps: !LinalgIndexingMapsConfig @@ -701,18 +705,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d @@ -728,17 +732,17 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1] -> (s0 + s1)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1] -> (s1)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1] -> (s0)> indexing_maps: !LinalgIndexingMapsConfig @@ -764,18 +768,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d @@ -791,17 +795,17 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0 + s1, s2 + s3)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2)> indexing_maps: !LinalgIndexingMapsConfig @@ -829,18 +833,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d @@ -856,17 +860,17 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0 + s1, s2 + s3, s4 + s5)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s1, s3, s5)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s2, s4)> indexing_maps: !LinalgIndexingMapsConfig @@ -897,18 +901,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d_nwc_wcf @@ -924,31 +928,31 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0, s1 * s2 + s3 * s4, s5)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s3, s5, s6)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0, s1, s6)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s2)> - default_vals: + default_indices: - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s4)> - default_vals: + default_indices: - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: @@ -977,18 +981,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf @@ -1008,36 +1012,36 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, s7, s9, s10)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s5, s10)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -1071,18 +1075,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf_q @@ -1103,44 +1107,44 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, s7, s9, s10)> - !LinalgOperandDefConfig name: IZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: KZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s5, s10)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -1182,36 +1186,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nchw_fchw @@ -1231,36 +1235,36 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10, s1, s4, s8)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s10, s2, s6)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, s7)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s5, s9)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -1294,18 +1298,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d_ndhwc_dhwcf @@ -1321,38 +1325,38 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s3, s7, s11, s13, s14)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s0, s1, s5, s9, s14)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s2, s6, s10)> - default_vals: + default_indices: - 1 - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s4, s8, s12)> - default_vals: + default_indices: - 1 - 1 - 1 @@ -1390,18 +1394,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_1d_nwc_wc @@ -1418,30 +1422,30 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1 * s2 + s3 * s4, s5)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3, s5)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s5)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)> - default_vals: + default_indices: - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)> - default_vals: + default_indices: - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: @@ -1469,18 +1473,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwc @@ -1497,35 +1501,35 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7, s9)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -1558,18 +1562,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwc_q @@ -1585,43 +1589,43 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7, s9)> - !LinalgOperandDefConfig name: IZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: KZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -1662,36 +1666,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwcm @@ -1707,36 +1711,36 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, s7, s9, s10)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s5, s9, s10)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -1770,18 +1774,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwcm_q @@ -1797,44 +1801,44 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, s7, s9, s10)> - !LinalgOperandDefConfig name: IZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: KZp - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s5, s9, s10)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -1876,36 +1880,36 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp + fn_name: cast - !ScalarExpression arith_fn: fn_name: sub operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_sum @@ -1921,35 +1925,35 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -1978,11 +1982,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_max @@ -1998,35 +2002,35 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -2055,11 +2059,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_max_unsigned @@ -2075,35 +2079,35 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -2132,11 +2136,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast_unsigned --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nchw_max @@ -2152,35 +2156,35 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2, s6)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s5, s9)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -2209,11 +2213,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_min @@ -2229,35 +2233,35 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -2286,11 +2290,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_min_unsigned @@ -2306,35 +2310,35 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> - default_vals: + default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> - default_vals: + default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig @@ -2363,11 +2367,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast_unsigned --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_sum @@ -2383,37 +2387,37 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s3, s7, s11)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1, s5, s9, s13)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s2, s6, s10)> - default_vals: + default_indices: - 1 - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s4, s8, s12)> - default_vals: + default_indices: - 1 - 1 - 1 @@ -2446,11 +2450,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_max @@ -2466,37 +2470,37 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s3, s7, s11)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1, s5, s9, s13)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s2, s6, s10)> - default_vals: + default_indices: - 1 - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s4, s8, s12)> - default_vals: + default_indices: - 1 - 1 - 1 @@ -2529,11 +2533,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_min @@ -2549,37 +2553,37 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)> - !LinalgOperandDefConfig name: K - usage: Input + kind: input_tensor type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s3, s7, s11)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1, s5, s9, s13)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s2, s6, s10)> - default_vals: + default_indices: - 1 - 1 - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s4, s8, s12)> - default_vals: + default_indices: - 1 - 1 - 1 @@ -2612,11 +2616,11 @@ scalar_arg: O - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_tensor @@ -2631,11 +2635,11 @@ args: - !LinalgOperandDefConfig name: value - usage: Input + kind: scalar type_var: T1 - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<() -> ()> indexing_maps: !LinalgIndexingMapsConfig @@ -2648,11 +2652,11 @@ arg: O value: !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: value + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_rng_2d @@ -2671,19 +2675,19 @@ args: - !LinalgOperandDefConfig name: min - usage: Input + kind: scalar type_var: F64 - !LinalgOperandDefConfig name: max - usage: Input + kind: scalar type_var: F64 - !LinalgOperandDefConfig name: seed - usage: Input + kind: scalar type_var: I32 - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: T shape_map: affine_map<()[s0, s1] -> (s0, s1)> indexing_maps: !LinalgIndexingMapsConfig @@ -2700,7 +2704,6 @@ arg: O value: !ScalarExpression type_fn: - fn_name: cast type_var: T operands: - !ScalarExpression @@ -2717,14 +2720,13 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: F64 operands: - !ScalarExpression scalar_const: '2147483647 : i64' + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: F64 operands: - !ScalarExpression @@ -2741,11 +2743,11 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_index: 1 + fn_name: cast - !ScalarExpression arith_fn: fn_name: add @@ -2760,41 +2762,42 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_index: 0 + fn_name: cast - !ScalarExpression scalar_arg: seed - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '1103515245 : i64' + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '1103515245 : i64' + fn_name: cast - !ScalarExpression type_fn: - fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' + fn_name: cast + fn_name: cast - !ScalarExpression arith_fn: fn_name: mul @@ -2809,13 +2812,14 @@ scalar_arg: min - !ScalarExpression type_fn: - fn_name: cast type_var: F64 operands: - !ScalarExpression scalar_const: '2.3283063999999999E-10 : f64' + fn_name: cast - !ScalarExpression scalar_arg: min + fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: soft_plus_2d @@ -2829,12 +2833,12 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T shape_map: affine_map<()[s0, s1] -> (s0, s1)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<()[s0, s1] -> (s0, s1)> indexing_maps: !LinalgIndexingMapsConfig @@ -2857,19 +2861,19 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_const: '1.000000e+00 : f64' + fn_name: cast - !ScalarExpression arith_fn: fn_name: exp operands: - !ScalarExpression type_fn: - fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I + fn_name: cast diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -37,7 +37,7 @@ LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, + return llvm::cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } }]; diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -8,6 +8,8 @@ DEPENDS MLIRLinalgInterfacesIncGen + MLIRLinalgOpsAttributesIncGen + MLIRLinalgOpsEnumsIncGen MLIRLinalgOpsIncGen MLIRLinalgStructuredOpsIncGen diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -95,6 +96,10 @@ } void mlir::linalg::LinalgDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc" + >(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" @@ -144,3 +149,10 @@ return op->emitError() << "attribute '" << attr.getName() << "' not supported by the linalg dialect"; } + +#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc" + +#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc" diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -36,8 +36,6 @@ using namespace mlir; using namespace mlir::linalg; -#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc" - /// Forward declarations. /// Generic entry point to create the block for the region of a LinalgOp. @@ -232,14 +230,14 @@ return operand; } - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value typefn__cast(Type toType, Value operand) { - return cast(toType, operand, false); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value typefn__cast_unsigned(Type toType, Value operand) { - return cast(toType, operand, true); + Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { + switch (typeFn) { + case TypeFn::cast: + return cast(toType, operand, false); + case TypeFn::cast_unsigned: + return cast(toType, operand, true); + } + llvm_unreachable("unsupported type conversion function"); } // NOLINTNEXTLINE(*-identifier-naming): externally called. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -111,7 +111,7 @@ @property def tensor_name(self) -> str: name = self.operand_def.name - assert name is not None, "TensorDef not attached" + assert name is not None, "TensorDef not registered with an op" return name def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: @@ -129,7 +129,8 @@ return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) def __repr__(self): - return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" + return (f"{self.operand_def.name}" + f"[{', '.join([repr(i) for i in self.indices])}]") class TensorArithFn(TensorExpression): @@ -156,14 +157,22 @@ class TensorTypeFn(TensorExpression): """Application of a type conversion function.""" - def __init__(self, type_fn: "TypeFn", type_var: TypeVar, + def __init__(self, type_fn: Optional["TypeFn"], + operand_def: Optional["OperandDef"], type_var: TypeVar, arg: TensorExpression): + if bool(type_fn) + bool(operand_def) != 1: + raise ValueError("Either 'type_fn' or 'operand_def' must be specified") self.type_fn = type_fn + self.operand_def = operand_def self.type_var = type_var self.arg = arg def to_scalar_expression(self) -> ScalarExpression: - return ScalarTypeFn(self.type_fn.fn_name, self.type_var, + if self.operand_def: + assert self.operand_def.name, "TypeFnAttr not registered with an op" + fn_name = self.type_fn.fn_name if self.type_fn else None + attr_name = self.operand_def.name if self.operand_def else None + return ScalarTypeFn(fn_name, attr_name, self.type_var, self.arg.to_scalar_expression()).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): @@ -171,7 +180,8 @@ self.arg.visit_tensor_exprs(callback) def __repr__(self): - return f"{repr(self.type_fn)}({self.type_var}, {self.arg})" + return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]" + f"({self.type_var}, {self.arg})") class TensorReduceFn(TensorExpression): @@ -260,7 +270,7 @@ self.fn_name = fn_name def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType": - return TensorTypeFn(self, type_var, arg) + return TensorTypeFn(self, None, type_var, arg) def __repr__(self): return f"{self.fn_name}" @@ -370,10 +380,11 @@ class OperandKind(Enum): - InputTensor = 0 - Scalar = 1 - OutputTensor = 2 - IndexAttr = 3 + INPUT_TENSOR = 0 + SCALAR = 1 + OUTPUT_TENSOR = 2 + INDEX_ATTR = 3 + TYPE_FN_ATTR = 4 class OperandDef: @@ -388,7 +399,8 @@ type_var: Optional[TypeVar] = None, size_exprs: Optional[Sequence[AffineExprDef]] = None, index_dims: Optional[Sequence[DimDef]] = None, - default_vals: Optional[Sequence[int]] = None): + default_indices: Optional[Sequence[int]] = None, + default_fn: Optional[str] = None): if type_var and not isinstance(type_var, TypeVar): raise ValueError( f"OperandDef requires a TypeVar but got {repr(type_var)}") @@ -396,25 +408,40 @@ self.type_var = type_var self.size_exprs = size_exprs self.index_dims = index_dims - self.default_vals = default_vals + self.default_indices = default_indices + self.default_fn = default_fn self.kind = kind self.name = None # type: Optional[str] self.registered_index = -1 # type: int def attach(self, index: int, name: str, owner: "LinalgOpDef"): if self.owner: - raise ValueError(f"OperandDef already registered with op: {self}") + raise ValueError(f"OperandDef already registered with an op: {self}") self.registered_index = index self.name = name self.owner = owner + def is_input(self) -> bool: + return (self.kind == OperandKind.SCALAR or + self.kind == OperandKind.INPUT_TENSOR) + + def is_tensor(self) -> bool: + return (self.kind == OperandKind.INPUT_TENSOR or + self.kind == OperandKind.OUTPUT_TENSOR) + + def is_attribute(self) -> bool: + return (self.kind == OperandKind.INDEX_ATTR or + self.kind == OperandKind.TYPE_FN_ATTR) + def __hash__(self): return hash(id(self)) def __repr__(self): return (f"{self.name}:OperandDef(kind={self.kind.name}, " - f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), " - f"index_dims={self.index_dims}, default_vals={self.default_vals})") + f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, " + f"index_dims={self.index_dims}, " + f"default_indices={self.default_indices}, " + f"default_fn={self.default_fn})") class TensorDef: @@ -440,12 +467,12 @@ if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): raise ValueError(f"TensorDef requires index dims of type DimDef but " f"got {index_dims}") - kind = OperandKind.OutputTensor if output else OperandKind.InputTensor + kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR self.operand_def = OperandDef( kind, type_var=type_var, size_exprs=shape, index_dims=index_dims) def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: - assert self.operand_def.owner, "TensorDef is not attached to an op" + assert self.operand_def.owner, "TensorDef is not registered with an op" state = AffineBuildState( global_state=self.operand_def.owner._affine_state, allow_new_symbols=False) @@ -486,12 +513,12 @@ """ def __init__(self, type_var: TypeVar): - self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var) + self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var) @property def scalar_name(self) -> str: name = self.operand_def.name - assert name is not None, "ScalarDef not attached" + assert name is not None, "ScalarDef not registered with an op" return name def to_scalar_expression(self) -> ScalarExpression: @@ -517,7 +544,26 @@ raise ValueError(f"IndexAttrDef expects {len(sizes)} default values " f"but got {len(default)}") self.operand_def = OperandDef( - OperandKind.IndexAttr, size_exprs=sizes, default_vals=default) + OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default) + + +class TypeFnAttrDef: + """Type conversion function attribute definition. + + Type conversion function attributes provide a way to make type conversions + parameterizable. Every attribute specifies a default type conversion function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "TypeFnType"): + if not isinstance(default, TypeFnType): + raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType " + f"but got {default}") + self.operand_def = OperandDef( + OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name) + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn: + return TensorTypeFn(None, self.operand_def, type_var, arg) ############################################################################### @@ -615,17 +661,21 @@ if name in self.registered_operands: raise ValueError(f"The operand {name} is already registered " f"to {self.registered_operands['name']}") + structured_op_methods = [ + "inputs", "outputs", "result_tensors", "region", "iterator_types", + "indexing_maps", "getRegionBuilder", "getLibraryCallName" + ] + if operand.is_attribute() and name in structured_op_methods: + raise ValueError(f"The attribute name {name} conflicts with a structured " + f"op method name") # Ensure output tensors are registered after input tensors and scalars and # attributes are registered after all other operand types. - registered_kinds = [ - operand.kind.value for operand in self.registered_operands.values() - ] - if registered_kinds: - maximum = max(registered_kinds) - if maximum > operand.kind.value and maximum > OperandKind.Scalar.value: - raise ValueError( - f"The operand {name} of kind {operand.kind.name} is registered " - f"after an operand of kind {OperandKind(maximum).name}") + if operand.is_input() and any( + not op_def.is_input() for op_def in self.registered_operands.values()): + raise ValueError(f"Input {name} registered after an output or attribute") + if operand.kind == OperandKind.OUTPUT_TENSOR and any( + op_def.is_attribute() for op_def in self.registered_operands.values()): + raise ValueError(f"Output {name} registered after an attribute") operand.attach(len(self.registered_operands), name, self) self.registered_operands[name] = operand diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -56,27 +56,25 @@ return self.operand_def.name @property - def type_var(self) -> TypeVar: - return self.operand_def.type_var + def kind(self) -> OperandKind: + return self.operand_def.kind @property - def usage(self) -> str: - if self.operand_def.kind == OperandKind.IndexAttr: - return "IndexAttr" - if self.operand_def.kind == OperandKind.OutputTensor: - return "Output" - return "Input" + def type_var(self) -> TypeVar: + return self.operand_def.type_var def to_yaml_custom_dict(self): - self_dict = dict(name=self.name, usage=self.usage) + self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower()) if self.type_var: self_dict["type_var"] = self.type_var.name if self.shape_map: self_dict["shape_map"] = _serialize_affine_map(self.shape_map) if self.index_attr_map: self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map) - if self.operand_def.default_vals: - self_dict["default_vals"] = self.operand_def.default_vals + if self.operand_def.default_indices: + self_dict["default_indices"] = self.operand_def.default_indices + if self.operand_def.default_fn: + self_dict["default_fn"] = self.operand_def.default_fn return self_dict def __repr__(self): @@ -166,7 +164,7 @@ # Collect all attribute definitions. collected_attr_defs = list() for operand in registered_operands: - if operand.kind == OperandKind.IndexAttr: + if operand.is_attribute(): collected_attr_defs.append(operand) # Collect all tensors with manual indexing annotation. @@ -244,12 +242,12 @@ # Set the indexing map of all scalar uses to the empty map. for operand_config in self.operands.values(): - if operand_config.operand_def.kind == OperandKind.Scalar: + if operand_config.operand_def.kind == OperandKind.SCALAR: operand_config.indexing_map = self._get_scalar_map() # Check all registered tensor and scalar operands have an indexing map. for operand in registered_operands: - if operand.kind == OperandKind.IndexAttr: + if operand.is_attribute(): continue if not (operand in self.operands and self.operands[operand].indexing_map): raise ValueError(f"Failed to compute an indexing map for operand " @@ -311,7 +309,8 @@ def add_operand(self, operand_def: OperandDef): if operand_def in self.operands: return - if operand_def.kind == OperandKind.Scalar: + if (operand_def.kind == OperandKind.SCALAR or + operand_def.kind == OperandKind.TYPE_FN_ATTR): self.operands[operand_def] = OperandDefConfig(operand_def) return with self.context: @@ -323,7 +322,7 @@ assert local_state.local_dim_count == 0 affine_map = _ir.AffineMap.get( dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) - if operand_def.kind == OperandKind.IndexAttr: + if operand_def.kind == OperandKind.INDEX_ATTR: self.operands[operand_def] = OperandDefConfig( operand_def, index_attr_map=affine_map) else: @@ -429,8 +428,7 @@ context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]: """Expands a LinalgOpDef into corresponding Linalg configured ops.""" # TODO: Many LinalgOpDef patterns need to expand to multiple generics. - assert len( - op_def.comprehensions) == 1, "Only one comprehension supported" + assert len(op_def.comprehensions) == 1, "Only one comprehension supported" return [ LinalgOpConfig( op_def.metadata, diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -129,7 +129,8 @@ sig = inspect.signature(dsl_func) for param_name, param in sig.parameters.items(): param_default = param.default - if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)): + if isinstance(param_default, + (TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)): op_def.add_operand(param_name, param_default.operand_def) else: raise ValueError( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -37,11 +37,21 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, - **attrs: Sequence[int]): + **attrs: Union[Sequence[int], TypeFnType]): all_arg_defs = op_config.ordered_operands - in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"] - out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"] - index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"] + in_arg_defs = [ + d for d in all_arg_defs + if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR + ] + out_arg_defs = [ + d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR + ] + index_attr_arg_defs = [ + d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR + ] + type_fn_attr_arg_defs = [ + d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR + ] # Verify outs is a sequence or a list of results. if not isinstance(outs, (Sequence, OpResultList)): @@ -56,11 +66,11 @@ raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " f"{len(outs)} for {op_config}") - # Compute a replacement list for all attribute symbols. + # Compute a replacement list for all index attribute symbols. expressions = [] # type: Sequence[AffineExpr] replacements = [] # type: Sequence[AffineExpr] for index_attr in index_attr_arg_defs: - index_attr_vals = index_attr.operand_def.default_vals + index_attr_vals = index_attr.operand_def.default_indices if index_attr.name in attrs: index_attr_vals = attrs.get(index_attr.name) assert index_attr_vals, "Index attribute has no value" @@ -125,15 +135,29 @@ array = np.array(index_attr_vals, dtype=np.int64) index_attrs[index_attr.name] = DenseElementsAttr.get(array) + # Compute the type function attribute mapping. + type_fn_attr_mapping = {} + for type_fn_attr in type_fn_attr_arg_defs: + attr_val = type_fn_attr.operand_def.default_fn + if type_fn_attr.name in attrs: + type_fn = attrs.get(type_fn_attr.name) + if not isinstance(type_fn, TypeFnType): + raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type " + f"TypeFnType but got {type(attr_val)}") + attr_val = type_fn.fn_name + assert attr_val, "Type function attribute has no value" + type_fn_attr_mapping[type_fn_attr.name] = attr_val + return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, - type_mapping, indexing_maps_attr, iterator_types_attr, - index_attrs, block_arg_types) + type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs, + type_fn_attr_mapping, block_arg_types) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \ + indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # An operation that accesses only scalars and scalar/rank zero tensors is @@ -147,10 +171,9 @@ tensor_map = AffineMap.get_identity(rank) indexing_maps = [] for arg_def in all_arg_defs: - if arg_def.operand_def.kind == OperandKind.Scalar: + if arg_def.operand_def.kind == OperandKind.SCALAR: indexing_maps.append(scalar_map) - if (arg_def.operand_def.kind == OperandKind.InputTensor or - arg_def.operand_def.kind == OperandKind.OutputTensor): + if arg_def.operand_def.is_tensor(): indexing_maps.append(tensor_map) indexing_maps_attr = ArrayAttr.get( [AffineMapAttr.get(am) for am in indexing_maps]) @@ -169,7 +192,8 @@ block = generic_op.regions[0].blocks.append(*block_arg_types) block_arg_mapping = dict(zip(block_arg_names, block.arguments)) with InsertionPoint(block): - body_builder = _BodyBuilder(type_mapping, block_arg_mapping) + body_builder = _BodyBuilder(type_mapping, block_arg_mapping, + type_fn_attr_mapping) for assignment in op_config.assignments: body_builder.assign(assignment) body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) @@ -184,7 +208,8 @@ op_class_name: str, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \ + indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # If we get here, there must exist a builtin class `op_class_name`. @@ -200,6 +225,11 @@ for name, value in index_attrs.items(): named_op.operation.attributes[name] = value + # Set the type function attributes. + for name, value in type_fn_attr_mapping.items(): + named_op.operation.attributes[name] = Attribute.parse( + f"#linalg.type_fn<{value}>") + linalg.fill_builtin_region(named_op.operation) if len(result_types) == 1: @@ -212,9 +242,11 @@ """Constructs a structured op body by evaluating assignments.""" def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value]): + block_arg_mapping: Dict[str, Value], + type_fn_attr_mapping: Dict[str, str]): self.type_mapping = type_mapping self.block_arg_mapping = block_arg_mapping + self.type_fn_attr_mapping = type_fn_attr_mapping self.yield_mapping = dict() # type: Dict[str, Value] def assign(self, assignment: ScalarAssign): @@ -245,7 +277,10 @@ ] return fn(*operand_values) elif expr.type_fn: - fn = self._get_function(f"_typefn_{expr.type_fn.fn_name}") + fn_name = expr.type_fn.fn_name + if expr.type_fn.attr_name: + fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name] + fn = self._get_function(f"_typefn_{fn_name}") operand = self.expression(expr.type_fn.operand) return fn(expr.type_fn.type_var.name, operand) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -46,9 +46,10 @@ class ScalarTypeFn: """A type of ScalarExpression that applies a type conversion function.""" - def __init__(self, fn_name: str, type_var: TypeVar, - operand: "ScalarExpression"): + def __init__(self, fn_name: Optional[str], attr_name: Optional[str], + type_var: TypeVar, operand: "ScalarExpression"): self.fn_name = fn_name + self.attr_name = attr_name self.type_var = type_var self.operand = operand @@ -56,7 +57,8 @@ return ScalarExpression(type_fn=self) def __repr__(self): - return f"ScalarTypeFn<{self.fn_name}>({self.type_var}, {self.operand})" + return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>" + f"({self.type_var}, {self.operand})") class ScalarArg: @@ -138,12 +140,15 @@ # Note that even though operands must be arity 1, we write it the # same way as for apply because it allows handling code to be more # generic vs having a special form. - return dict( - type_fn=dict( - fn_name=self.type_fn.fn_name, - type_var=self.type_fn.type_var.name, - operands=[self.type_fn.operand], - )) + type_fn_dict = dict( + type_var=self.type_fn.type_var.name, + operands=[self.type_fn.operand], + ) + if self.type_fn.fn_name: + type_fn_dict["fn_name"] = self.type_fn.fn_name + if self.type_fn.attr_name: + type_fn_dict["attr_name"] = self.type_fn.attr_name + return dict(type_fn=type_fn_dict) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) elif self.scalar_const: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -10,7 +10,8 @@ def matmul( A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast)): """Performs a matrix multiplication of two 2D inputs. Numeric casting is performed on the operands to the inner multiply, promoting @@ -18,7 +19,7 @@ """ domain(D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) @linalg_structured_op diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -36,6 +36,20 @@ // CHECK-NEXT: linalg.yield %[[ADD]] : i32 // CHECK-NEXT: -> tensor<16x32xi32> + +// ----- + +// Verifies that cast attributes control the cast operations used. +func @generalize_matmul_tensor_i16i64i32_unsigned(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { + %0 = linalg.matmul {cast = #linalg.type_fn} + ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>) + outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> + return %0: tensor<16x32xi32> +} + +// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32_unsigned +// CHECK: = arith.extui + // ----- func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -1229,7 +1229,7 @@ value = self._emit_expression(expr_to_input_opnd, expr_to_info) # Emit the structured op representation for the destination tensor. dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name, - lang.OperandKind.OutputTensor) + lang.OperandKind.OUTPUT_TENSOR) dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices) dst_use = lang.TensorUse(dst_opnd, dst_dim_syms) @@ -1864,6 +1864,6 @@ name = expr.tensor.name dim_sym = _mlir_symbols_from_index_vars(indices) - opnd = lang.OperandDef(lang.OperandKind.InputTensor, lang.T, dim_sym) + opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym) op_def.add_operand(name, opnd) return opnd diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -2,7 +2,8 @@ # RUN: mlir-linalg-ods-yaml-gen %s --o-impl=- | FileCheck %s --check-prefix=IMPL # @linalg_structured_op -# def test1(O=TensorDef(T, S.M, S.N, output=True)): +# def test1(O=TensorDef(T, S.M, S.N, output=True), +# cast=TypeFnAttrDef(default=TypeFn.cast)): # """Title. # Detailed description. @@ -21,9 +22,13 @@ args: - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: T shape_map: affine_map<()[s0, s1] -> (s0, s1)> + - !LinalgOperandDefConfig + name: cast + kind: type_fn_attr + default_fn: cast indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> @@ -39,18 +44,18 @@ operands: - !ScalarExpression type_fn: - fn_name: cast type_var: T operands: - !ScalarExpression scalar_const: '42 : i64' + attr_name: cast - !ScalarExpression type_fn: - fn_name: cast_unsigned type_var: T operands: - !ScalarExpression scalar_index: 1 + attr_name: cast # ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1" @@ -61,16 +66,22 @@ # ODS: let arguments = # ODS-NEXT: Variadic:$inputs, -# ODS-NEXT: Variadic:$outputs +# ODS-NEXT: Variadic:$outputs, +# ODS-NEXT: DefaultValuedAttr:$cast # ODS: let builders = # ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, # ODS-NEXT: "ValueRange":$outputs, # ODS-NEXT: CArg<"ArrayRef", "{}">:$attributes), +# ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, +# ODS-NEXT: "ValueRange":$outputs, "Attribute":$cast, +# ODS-NEXT: CArg<"ArrayRef", "{}">:$attributes), + # ODS: $_state.addOperands(inputs); # ODS-NEXT: $_state.addOperands(outputs); # ODS-NEXT: $_state.addTypes(resultTensorTypes); +# ODS-NEXT: $_state.addAttribute("cast", cast) # ODS-NEXT: $_state.addAttributes(attributes); # ODS-NEXT: $_state.addAttribute( # ODS-NEXT: "operand_segment_sizes", @@ -85,10 +96,18 @@ # IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, # IMPL-NEXT: Block &block, ArrayRef attrs) +# IMPL: TypeFn castVal = TypeFn::cast; +# IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) { +# IMPL-NEXT: return attr.getName() == "cast"; }); +# IMPL-NEXT: if (castIter != attrs.end()) { +# IMPL-NEXT: if (auto attr = castIter->getValue().dyn_cast()) +# IMPL-NEXT: castVal = attr.getValue(); +# IMPL-NEXT: } + # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); -# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]); +# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); -# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]); +# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]); # IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]); @@ -114,19 +133,19 @@ args: - !LinalgOperandDefConfig name: I - usage: Input + kind: input_tensor type_var: T shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)> - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: T shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)> - !LinalgOperandDefConfig name: strides - usage: IndexAttr + kind: index_attr index_attr_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)> - default_vals: + default_indices: - 1 - 2 indexing_maps: !LinalgIndexingMapsConfig @@ -201,11 +220,11 @@ args: - !LinalgOperandDefConfig name: value - usage: Input + kind: scalar type_var: T1 - !LinalgOperandDefConfig name: O - usage: Output + kind: output_tensor type_var: U shape_map: affine_map<() -> ()> indexing_maps: !LinalgIndexingMapsConfig diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py --- a/mlir/test/python/dialects/linalg/opdsl/arguments.py +++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py @@ -7,30 +7,34 @@ # CHECK-LABEL: matmul # CHECK: args: # CHECK: name: A -# CHECK: usage: Input +# CHECK: kind: input_tensor # CHECK: type_var: T # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> # CHECK: name: B -# CHECK: usage: Input +# CHECK: kind: input_tensor # CHECK: type_var: T # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> # CHECK: name: C -# CHECK: usage: Output +# CHECK: kind: output_tensor # CHECK: type_var: U # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> +# CHECK: name: cast +# CHECK: kind: type_fn_attr +# CHECK: default_fn: cast @linalg_structured_op def matmul( A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast)): + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) # CHECK: --- # CHECK-LABEL: fill # CHECK: args: # CHECK: name: value -# CHECK: usage: Input +# CHECK: kind: scalar # CHECK-NOT: shape_map: # CHECK: type_var: T @linalg_structured_op @@ -42,17 +46,17 @@ # CHECK-LABEL: strided_copy # CHECK: args: # CHECK: name: I -# CHECK: usage: Input +# CHECK: kind: input_tensor # CHECK: type_var: T # CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)> # CHECK: name: O -# CHECK: usage: Output +# CHECK: kind: output_tensor # CHECK: type_var: T # CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)> # CHECK: name: strides -# CHECK: usage: IndexAttr +# CHECK: kind: index_attr # CHECK: index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)> -# CHECK: default_vals: +# CHECK: default_indices: # CHECK: - 1 # CHECK: - 2 @linalg_structured_op diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -19,16 +19,19 @@ # CHECK: type_var: U # CHECK: operands: # CHECK: scalar_arg: A +# CHECK: attr_name: cast # CHECK: type_fn: # CHECK: type_var: U # CHECK: operands: # CHECK: scalar_arg: B +# CHECK: attr_name: cast @linalg_structured_op def matmul( A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast)): + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) # CHECK: --- diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py @@ -24,19 +24,10 @@ def matmul_poly( A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast)): domain(D.m, D.n, D.k) - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) - - -@linalg_structured_op -def matmul_unsigned_poly( - A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - domain(D.m, D.n, D.k) - C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( - U, B[D.k, D.n]) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) with Context() as ctx, Location.unknown(): @@ -92,7 +83,8 @@ RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8), RankedTensorType.get((4, 8), i32)) def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result): - return matmul_unsigned_poly(lhs, rhs, outs=[init_result]) + return matmul_poly( + lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned) # CHECK-LABEL: @test_i8i16i32_matmul # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32) @@ -143,7 +135,8 @@ RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8), RankedTensorType.get((4, 8), f32)) def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result): - return matmul_unsigned_poly(lhs, rhs, outs=[init_result]) + return matmul_poly( + lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned) # CHECK-LABEL: @test_f16f16f32_matmul # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -6,6 +6,8 @@ from mlir.dialects import std from mlir.dialects import arith +from mlir.dialects.linalg.opdsl.lang import * + def run(f): print("\nTEST:", f.__name__) @@ -98,12 +100,14 @@ init_result = linalg.InitTensorOp([4, 8], f32) # First check the named form with custom format # CHECK: linalg.matmul + # CHECK: cast = #linalg.type_fn # CHECK-NOT: linalg.memoized_indexing_maps # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>) # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>) # CHECK-SAME: -> tensor<4x8xf32> # CHECK-NEXT: return - return linalg.matmul(lhs, rhs, outs=[init_result.result]) + return linalg.matmul( + lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned) print(module) diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -9,6 +9,8 @@ from mlir.passmanager import * from mlir.execution_engine import * +from mlir.dialects.linalg.opdsl.lang import * + # Log everything to stderr and flush so that we have a unified stream to match # errors/info emitted by MLIR to stderr. @@ -20,21 +22,28 @@ matmul_boiler = """ func @main() -> f32 attributes {llvm.emit_c_interface} { %v0 = arith.constant 0.0 : f32 - %v1 = arith.constant 1.0 : f32 + %v1 = arith.constant -1 : i8 %v2 = arith.constant 2.0 : f32 - %A = memref.alloc() : memref<4x16xf32> + %A = memref.alloc() : memref<4x16xi8> %B = memref.alloc() : memref<16x8xf32> - %C = memref.alloc() : memref<4x8xf32> - linalg.fill(%v1, %A) : f32, memref<4x16xf32> + %C0 = memref.alloc() : memref<4x8xf32> + %C1 = memref.alloc() : memref<4x8xf32> + linalg.fill(%v1, %A) : i8, memref<4x16xi8> linalg.fill(%v2, %B) : f32, memref<16x8xf32> - linalg.fill(%v0, %C) : f32, memref<4x8xf32> + linalg.fill(%v0, %C0) : f32, memref<4x8xf32> + linalg.fill(%v0, %C1) : f32, memref<4x8xf32> - call @matmul_on_buffers(%A, %B, %C) : - (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> () + call @matmul_signed_on_buffers(%A, %B, %C0) : + (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> () + call @matmul_unsigned_on_buffers(%A, %B, %C1) : + (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> () %c0 = arith.constant 0 : index - %0 = memref.load %C[%c0, %c0] : memref<4x8xf32> + %res0 = memref.load %C0[%c0, %c0] : memref<4x8xf32> + %res1 = memref.load %C1[%c0, %c0] : memref<4x8xf32> + + %0 = arith.addf %res0, %res1 : f32 // TODO: FFI-based solution to allow testing and printing with python code. return %0 : f32 @@ -157,8 +166,8 @@ pm = PassManager.parse( "builtin.func(convert-linalg-to-loops, lower-affine, " + - "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm," + - "convert-memref-to-llvm, convert-std-to-llvm," + + "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm," + + "convert-memref-to-llvm, convert-std-to-llvm," + "reconcile-unrealized-casts") pm.run(mod) return mod @@ -168,14 +177,21 @@ with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() + i8 = IntegerType.get_signless(8) with InsertionPoint(module.body): @builtin.FuncOp.from_py_func( - MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32), + MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), MemRefType.get((4, 8), f32)) - def matmul_on_buffers(lhs, rhs, out): + def matmul_signed_on_buffers(lhs, rhs, out): linalg.matmul(lhs, rhs, outs=[out]) + @builtin.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32)) + def matmul_unsigned_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned) + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) # TODO: FFI-based solution to allow testing and printing with python code. @@ -186,7 +202,9 @@ execution_engine.invoke("main", res) log("RESULT: ", res[0]) - # CHECK: RESULT: 32.0 + # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32 + # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160 + # CHECK: RESULT: 8128 test_matmul_builtin() @@ -196,14 +214,22 @@ with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() + i8 = IntegerType.get_signless(8) with InsertionPoint(module.body): @builtin.FuncOp.from_py_func( - MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32), + MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), MemRefType.get((4, 8), f32)) - def matmul_on_buffers(lhs, rhs, out): + def matmul_signed_on_buffers(lhs, rhs, out): linalg.matmul(lhs, rhs, outs=[out], emit_generic=True) + @builtin.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32)) + def matmul_unsigned_on_buffers(lhs, rhs, out): + linalg.matmul( + lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True) + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) # TODO: FFI-based solution to allow testing and printing with python code. @@ -214,7 +240,9 @@ execution_engine.invoke("main", res) log("RESULT: ", res[0]) - # CHECK: RESULT: 32.0 + # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32 + # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160 + # CHECK: RESULT: 8128 test_matmul_generic() @@ -423,11 +451,7 @@ MemRefType.get((1, 2, 4, 1), i32)) # Set the strides and use the default dilations. def pooling_on_buffers(input, shape, output): - linalg.pooling_nhwc_min( - input, - shape, - outs=[output], - strides=[2, 4]) + linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4]) execution_engine = ExecutionEngine(transform(module, pooling_boiler)) @@ -458,11 +482,7 @@ # Set the strides and use the default dilations. def pooling_on_buffers(input, shape, output): linalg.pooling_nhwc_min( - input, - shape, - outs=[output], - strides=[2, 4], - emit_generic=True) + input, shape, outs=[output], strides=[2, 4], emit_generic=True) execution_engine = ExecutionEngine(transform(module, pooling_boiler)) diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -61,15 +61,22 @@ AffineMap affineMap() { return affineMapAttr.getValue(); } }; -enum class LinalgOperandDefUsage { Input, Output, IndexAttr }; +enum class LinalgOperandDefKind { + InputTensor, + Scalar, + OutputTensor, + IndexAttr, + TypeFnAttr +}; struct LinalgOperandDef { std::string name; - LinalgOperandDefUsage usage; + LinalgOperandDefKind kind; Optional typeVar; Optional shapeMap; Optional indexAttrMap; - Optional> defaultVals; + Optional> defaultIndices; + Optional defaultFn; }; enum class LinalgIteratorTypeDef { @@ -91,11 +98,12 @@ }; struct ScalarTypeFn { - std::string fnName; std::string typeVar; // NOTE: This must be of arity 1, but to break the self-referential cycle, // we use a heap allocated vector. std::vector operands; + Optional fnName; + Optional attrName; }; struct ScalarExpression { @@ -180,27 +188,32 @@ /// index attribute symbols. During op creation these symbols are replaced /// by the corresponding `name` index attribue values. Only index attribute /// arguments have an `index_attr_map`. -/// - `default_vals`: An optional default initialization for index attribute +/// - `default_indices`: An optional default initialization for index +/// attribute arguments. +/// - `default_fn`: An optional default initialization for function attribute /// arguments. template <> struct MappingTraits { static void mapping(IO &io, LinalgOperandDef &info) { io.mapRequired("name", info.name); - io.mapRequired("usage", info.usage); + io.mapRequired("kind", info.kind); io.mapOptional("type_var", info.typeVar); io.mapOptional("shape_map", info.shapeMap); io.mapOptional("index_attr_map", info.indexAttrMap); - io.mapOptional("default_vals", info.defaultVals); + io.mapOptional("default_indices", info.defaultIndices); + io.mapOptional("default_fn", info.defaultFn); } }; /// Usage enum for a named argument. template <> -struct ScalarEnumerationTraits { - static void enumeration(IO &io, LinalgOperandDefUsage &value) { - io.enumCase(value, "Input", LinalgOperandDefUsage::Input); - io.enumCase(value, "Output", LinalgOperandDefUsage::Output); - io.enumCase(value, "IndexAttr", LinalgOperandDefUsage::IndexAttr); +struct ScalarEnumerationTraits { + static void enumeration(IO &io, LinalgOperandDefKind &value) { + io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor); + io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar); + io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor); + io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr); + io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr); } }; @@ -281,9 +294,10 @@ template <> struct MappingTraits { static void mapping(IO &io, ScalarTypeFn &info) { - io.mapRequired("fn_name", info.fnName); io.mapRequired("type_var", info.typeVar); io.mapRequired("operands", info.operands); + io.mapOptional("fn_name", info.fnName); + io.mapOptional("attr_name", info.attrName); } }; @@ -399,8 +413,9 @@ // Search all argument types. for (const auto &it : llvm::enumerate(args)) { - if (it.value().usage != LinalgOperandDefUsage::Input && - it.value().usage != LinalgOperandDefUsage::Output) + if (it.value().kind != LinalgOperandDefKind::InputTensor && + it.value().kind != LinalgOperandDefKind::Scalar && + it.value().kind != LinalgOperandDefKind::OutputTensor) continue; if (it.value().typeVar.getValue() == typeVar) return llvm::formatv("block.getArgument({0}).getType()", it.index()) @@ -552,6 +567,8 @@ $_state.addOperands(inputs); $_state.addOperands(outputs); $_state.addTypes(resultTensorTypes); + {2} + $_state.addAttributes(attributes); $_state.addAttribute( "operand_segment_sizes", $_builder.getI32VectorAttr({{ @@ -562,8 +579,6 @@ $_state, TypeRange(inputs), TypeRange(outputs)); - {2} - $_state.addAttributes(attributes); }]> )FMT"; @@ -681,42 +696,56 @@ interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); - // Assemble the attribute specific logic required for the op definition. if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { - return arg.usage == LinalgOperandDefUsage::IndexAttr; + return arg.kind == LinalgOperandDefKind::IndexAttr || + arg.kind == LinalgOperandDefKind::TypeFnAttr; })) { SmallVector attrDefs; SmallVector attrParams; SmallVector attrStmts; for (LinalgOperandDef &arg : opConfig.structuredOp->args) { - if (arg.usage != LinalgOperandDefUsage::IndexAttr) - continue; - assert(arg.indexAttrMap.hasValue()); - assert(arg.defaultVals.hasValue()); - size_t size = arg.indexAttrMap->affineMap().getNumResults(); - assert(arg.defaultVals.getValue().size() == size); - static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>"; - static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}"; static const char paramFmt[] = "\"Attribute\":${0}"; static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; - std::string defaultVals; - llvm::raw_string_ostream ss(defaultVals); - ss << "{ "; - llvm::interleave( - arg.defaultVals.getValue(), ss, - [&](int64_t val) { ss << "static_cast(" << val << ")"; }, - ", "); - ss << " }"; - attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size), - ss.str(), arg.name)); - attrParams.push_back(llvm::formatv(paramFmt, arg.name)); - attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); + // Add the type conversion attributes to the op definition and builders. + if (arg.kind == LinalgOperandDefKind::TypeFnAttr) { + assert(arg.defaultFn.hasValue()); + static const char typeFmt[] = "TypeFn::{0}"; + static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}"; + attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr", + llvm::formatv(typeFmt, arg.defaultFn), + arg.name)); + attrParams.push_back(llvm::formatv(paramFmt, arg.name)); + attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); + } + // Add the index attributes to the op definition and builders. + if (arg.kind == LinalgOperandDefKind::IndexAttr) { + assert(arg.indexAttrMap.hasValue()); + assert(arg.defaultIndices.hasValue()); + size_t size = arg.indexAttrMap->affineMap().getNumResults(); + assert(arg.defaultIndices.getValue().size() == size); + static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>"; + static const char defFmt[] = "DefaultValuedAttr<{0}, \"{ {1} }\">:${2}"; + std::string defaultVals; + llvm::raw_string_ostream ss(defaultVals); + llvm::interleave( + arg.defaultIndices.getValue(), ss, + [&](int64_t val) { ss << "static_cast(" << val << ")"; }, + ", "); + attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size), + ss.str(), arg.name)); + attrParams.push_back(llvm::formatv(paramFmt, arg.name)); + attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); + } + } + if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { + return arg.kind == LinalgOperandDefKind::IndexAttr; + })) { + attrMethods = R"( + bool hasDynamicIndexingMaps(); + LogicalResult verifyIndexingMapRequiredAttributes(); + )"; } attrList = ",\n" + llvm::join(attrDefs, ",\n"); - attrMethods = R"( - bool hasDynamicIndexingMaps(); - LogicalResult verifyIndexingMapRequiredAttributes(); - )"; attrBuilder = llvm::formatv( structuredOpBuilderFormat, opConfig.metadata->cppClassName, llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n")); @@ -746,7 +775,9 @@ // Compute the number of scalar and tensor arguments. int64_t numOfArgs = llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { - return arg.usage != LinalgOperandDefUsage::IndexAttr; + return arg.kind == LinalgOperandDefKind::InputTensor || + arg.kind == LinalgOperandDefKind::Scalar || + arg.kind == LinalgOperandDefKind::OutputTensor; }); // An operation that accesses only scalars and scalar/rank zero tensors is @@ -817,7 +848,7 @@ )FMT"; // Update all symbol bindings mapped to an attribute. for (LinalgOperandDef &arg : opConfig.structuredOp->args) { - if (arg.usage != LinalgOperandDefUsage::IndexAttr) + if (arg.kind != LinalgOperandDefKind::IndexAttr) continue; assert(arg.indexAttrMap.hasValue()); for (auto &en : @@ -910,11 +941,11 @@ // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes() if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { - return arg.usage == LinalgOperandDefUsage::IndexAttr; + return arg.kind == LinalgOperandDefKind::IndexAttr; })) { std::vector attrVerifications; for (LinalgOperandDef &arg : opConfig.structuredOp->args) { - if (arg.usage != LinalgOperandDefUsage::IndexAttr) + if (arg.kind != LinalgOperandDefKind::IndexAttr) continue; assert(arg.indexAttrMap.hasValue()); // Verify index attribute. Paramters: @@ -952,7 +983,8 @@ // Generates a regionBuilder method. Parameters. // {0}: Class name // {1}: Number of args - // {2}: Statements + // {2}: Attributes + // {3}: Statements static const char structuredOpRegionBuilderFormat[] = R"FMT( void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs) {{ @@ -961,6 +993,7 @@ RegionBuilderHelper helper(block.getArgument(0).getContext(), block); SmallVector yields; {2} + {3} helper.yieldOutputs(yields); } )FMT"; @@ -968,9 +1001,27 @@ auto &assignments = opConfig.structuredOp->assignments; size_t generatedAssignmentCount = 0; int localCounter = 0; + SmallVector attrs; SmallVector stmts; for (LinalgOperandDef &arg : args) { - if (arg.usage != LinalgOperandDefUsage::Output) + if (arg.kind != LinalgOperandDefKind::TypeFnAttr) + continue; + // Obtain the type function attribute values. Parameters. + // {0}: attribute name + // {1}: default type function name + static const char attrDef[] = R"FMT( +TypeFn {0}Val = TypeFn::{1}; +auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ + return attr.getName() == "{0}"; }); +if ({0}Iter != attrs.end()) {{ + if (auto attr = {0}Iter->getValue().dyn_cast()) + {0}Val = attr.getValue(); +} +)FMT"; + attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn)); + } + for (LinalgOperandDef &arg : args) { + if (arg.kind != LinalgOperandDefKind::OutputTensor) continue; // Find the assignment that correlates with the argument. @@ -1048,11 +1099,25 @@ << "an argument type but it does not"; return None; } + + // Use the function name or the attribute to build the type function. + std::string typeFunc = llvm::formatv( + "TypeFn::{0}", expression.typeFn->fnName.getValueOr("")); + if (expression.typeFn->attrName) { + if (llvm::none_of(args, [&](LinalgOperandDef &arg) { + return arg.kind == LinalgOperandDefKind::TypeFnAttr && + arg.name == expression.typeFn->attrName.getValue(); + })) { + emitError(genContext.getLoc()) + << "missing type function attribute " + << expression.typeFn->attrName.getValue(); + } + typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName); + } std::string cppIdent = llvm::formatv("value{0}", ++localCounter); - stmts.push_back( - llvm::formatv("Value {0} = helper.typefn__{1}({2}, {3});", - cppIdent, expression.typeFn->fnName, - typeCppValue.getValue(), *operandCppValue)); + stmts.push_back(llvm::formatv( + "Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent, + typeFunc, typeCppValue.getValue(), *operandCppValue)); return cppIdent; } emitError(genContext.getLoc()) << "unknown ScalarExpression type"; @@ -1069,6 +1134,7 @@ << "mismatched number of assignments vs output arguments"; os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs, + interleaveToString(attrs, "\n "), interleaveToString(stmts, "\n ")); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6720,6 +6720,22 @@ ], "include/mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc", ), + ( + ["-gen-enum-decls"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc", + ), + ( + ["-gen-attrdef-decls"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc", + ), + ( + ["-gen-attrdef-defs"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc", + ), ], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/Linalg/IR/LinalgOps.td",