diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -105,7 +105,7 @@ copy_and_scale(val, in_tensor, outs=[out_tensor]) ``` -## Attributes +## Index Attributes Attributes are compile-time constant parameters only accessible in index expressions. They can be used to parameterize the access pattern of a structured @@ -118,7 +118,7 @@ @linalg_structured_op def strided_copy(I=TensorDef(T, S.IH, S.IW), O=TensorDef(T, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1])): """Copy a subset of the input tensor elements to the output tensor""" O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW] ``` @@ -129,11 +129,12 @@ When instantiating the operation, the attribute is set using a named argument: ```python -strided_copy(in_tensor, outs=[out_tensor], strides=[1,2]) +strided_copy(in_tensor, outs=[out_tensor], strides=[1, 2]) ``` The `strides` vector elements substitute the symbols `S.SH` and `S.SW` in the -index expressions of the operation instance. +index expressions of the operation instance. If no strides are provided, the +`default` vector elements are used instead. Attributes are currently limited to integer vectors and only accessible in index expressions. 
An operation may have multiple attributes all of them placed at the @@ -157,8 +158,8 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) ``` diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -15,17 +15,17 @@ args: - !LinalgOperandDefConfig name: A - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> - !LinalgOperandDefConfig name: B - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> - !LinalgOperandDefConfig name: C - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> indexing_maps: !LinalgIndexingMapsConfig @@ -79,17 +79,17 @@ args: - !LinalgOperandDefConfig name: A - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> - !LinalgOperandDefConfig name: B - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> - !LinalgOperandDefConfig name: C - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> indexing_maps: !LinalgIndexingMapsConfig @@ -143,25 +143,25 @@ args: - !LinalgOperandDefConfig name: A - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> - !LinalgOperandDefConfig name: B - usage: InputOperand + usage: Input type_var: 
T2 shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> - !LinalgOperandDefConfig name: AZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: BZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: C - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> indexing_maps: !LinalgIndexingMapsConfig @@ -244,17 +244,17 @@ args: - !LinalgOperandDefConfig name: lhs - usage: InputOperand + usage: Input type_var: LhsType shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)> - !LinalgOperandDefConfig name: rhs - usage: InputOperand + usage: Input type_var: RhsType shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)> - !LinalgOperandDefConfig name: accum - usage: OutputOperand + usage: Output type_var: AccumType shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)> indexing_maps: !LinalgIndexingMapsConfig @@ -314,17 +314,17 @@ args: - !LinalgOperandDefConfig name: A - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> - !LinalgOperandDefConfig name: B - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)> - !LinalgOperandDefConfig name: C - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> indexing_maps: !LinalgIndexingMapsConfig @@ -379,25 +379,25 @@ args: - !LinalgOperandDefConfig name: A - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> - !LinalgOperandDefConfig name: B - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)> - !LinalgOperandDefConfig name: AZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: BZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: C - usage: OutputOperand 
+ usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> indexing_maps: !LinalgIndexingMapsConfig @@ -476,17 +476,17 @@ args: - !LinalgOperandDefConfig name: A - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1] -> (s0, s1)> - !LinalgOperandDefConfig name: y - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1] -> (s1)> - !LinalgOperandDefConfig name: x - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1] -> (s0)> indexing_maps: !LinalgIndexingMapsConfig @@ -539,17 +539,17 @@ args: - !LinalgOperandDefConfig name: y - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1] -> (s0)> - !LinalgOperandDefConfig name: A - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1] -> (s0, s1)> - !LinalgOperandDefConfig name: x - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1] -> (s1)> indexing_maps: !LinalgIndexingMapsConfig @@ -602,17 +602,17 @@ args: - !LinalgOperandDefConfig name: A - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)> - !LinalgOperandDefConfig name: B - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> - !LinalgOperandDefConfig name: C - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> indexing_maps: !LinalgIndexingMapsConfig @@ -666,17 +666,17 @@ args: - !LinalgOperandDefConfig name: A - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0] -> (s0)> - !LinalgOperandDefConfig name: B - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0] -> (s0)> - !LinalgOperandDefConfig name: C - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0] -> ()> indexing_maps: !LinalgIndexingMapsConfig @@ -728,17 +728,17 @@ args: - !LinalgOperandDefConfig 
name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1] -> (s0 + s1)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1] -> (s1)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1] -> (s0)> indexing_maps: !LinalgIndexingMapsConfig @@ -791,17 +791,17 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3] -> (s0 + s1, s2 + s3)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2)> indexing_maps: !LinalgIndexingMapsConfig @@ -856,17 +856,17 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0 + s1, s2 + s3, s4 + s5)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s1, s3, s5)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s2, s4)> indexing_maps: !LinalgIndexingMapsConfig @@ -924,30 +924,32 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0, s1 * s2 + s3 * s4, s5)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s3, s5, s6)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0, s1, s6)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - 
type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s2)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s2)> + default_vals: + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s4)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s4)> + default_vals: + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4, s5, s6] -> (d0, d1 * s2 @@ -1006,34 +1008,38 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, s7, s9, s10)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s5, s10)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2, - s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s2, s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4, - s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s4, s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, @@ -1097,42 +1103,46 @@ args: - !LinalgOperandDefConfig 
name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, s7, s9, s10)> - !LinalgOperandDefConfig name: IZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: KZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s5, s10)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2, - s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s2, s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4, - s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s4, s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, @@ -1221,34 +1231,38 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10, s1, s4, s8)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, 
s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s10, s2, s6)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, - s7)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s3, s7)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s5, - s9)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s5, s9)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, @@ -1307,35 +1321,41 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s3, s7, s11, s13, s14)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s0, s1, s5, s9, s14)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s2, s6, s10)> + default_vals: + - 1 + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + 
usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14] -> (s4, s8, s12)> + default_vals: + - 1 + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6, @@ -1398,29 +1418,31 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1 * s2 + s3 * s4, s5)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3, s5)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s5)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)> + default_vals: + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)> + default_vals: + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1 * s2 + d3 * s4, @@ -1475,31 +1497,37 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7, s9)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, 
s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, + s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, + s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] @@ -1557,39 +1585,45 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7, s9)> - !LinalgOperandDefConfig name: IZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: KZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, + s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> + usage: IndexAttr + index_attr_map: 
affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, + s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] @@ -1673,34 +1707,38 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, s7, s9, s10)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s5, s9, s10)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2, - s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s2, s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4, - s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s4, s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, @@ -1759,42 +1797,46 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3, s7, s9, 
s10)> - !LinalgOperandDefConfig name: IZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: KZp - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, s1, s5, s9, s10)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2, - s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s2, s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4, - s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s4, s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, @@ -1879,31 +1921,37 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, + s6)> + default_vals: + - 1 + - 1 - 
!LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, + s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] @@ -1950,31 +1998,37 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, + s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, + s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] @@ -2021,31 +2075,37 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig 
name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, + s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, + s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] @@ -2092,31 +2152,37 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2, s6)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, + s7)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - 
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s5, s9)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s5, + s9)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] @@ -2163,31 +2229,37 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, + s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, + s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] @@ -2234,31 +2306,37 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: 
affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, s9)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, + s6)> + default_vals: + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, + s8)> + default_vals: + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] @@ -2305,34 +2383,40 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s3, s7, s11)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1, s5, s9, s13)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s2, s6, s10)> + default_vals: + - 1 + - 1 + - 1 - !LinalgOperandDefConfig name: dilations 
- usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s4, s8, s12)> + default_vals: + - 1 + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7, @@ -2382,34 +2466,40 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s3, s7, s11)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1, s5, s9, s13)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s2, s6, s10)> + default_vals: + - 1 + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s4, s8, s12)> + default_vals: + - 1 + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7, @@ -2459,34 +2549,40 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T1 shape_map: 
affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)> - !LinalgOperandDefConfig name: K - usage: InputOperand + usage: Input type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s3, s7, s11)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s0, s1, s5, s9, s13)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s2, s6, s10)> + default_vals: + - 1 + - 1 + - 1 - !LinalgOperandDefConfig name: dilations - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13] -> (s4, s8, s12)> + default_vals: + - 1 + - 1 + - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7, @@ -2535,11 +2631,11 @@ args: - !LinalgOperandDefConfig name: value - usage: InputOperand + usage: Input type_var: T1 - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<() -> ()> indexing_maps: !LinalgIndexingMapsConfig @@ -2575,19 +2671,19 @@ args: - !LinalgOperandDefConfig name: min - usage: InputOperand + usage: Input type_var: F64 - !LinalgOperandDefConfig name: max - usage: InputOperand + usage: Input type_var: F64 - !LinalgOperandDefConfig name: seed - usage: InputOperand + usage: Input type_var: I32 - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: T shape_map: 
affine_map<()[s0, s1] -> (s0, s1)> indexing_maps: !LinalgIndexingMapsConfig @@ -2733,12 +2829,12 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T shape_map: affine_map<()[s0, s1] -> (s0, s1)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U shape_map: affine_map<()[s0, s1] -> (s0, s1)> indexing_maps: !LinalgIndexingMapsConfig diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -135,7 +135,7 @@ InputTensor = 0 Scalar = 1 OutputTensor = 2 - Attribute = 3 + IndexAttr = 3 class OperandDef: @@ -147,16 +147,18 @@ def __init__(self, kind: OperandKind, - type_var: TypeVar, + type_var: Optional[TypeVar] = None, size_exprs: Optional[Sequence[AffineExprDef]] = None, - index_dims: Optional[Sequence[DimDef]] = None): - if not isinstance(type_var, TypeVar): + index_dims: Optional[Sequence[DimDef]] = None, + default_vals : Optional[Sequence[int]] = None): + if type_var and not isinstance(type_var, TypeVar): raise ValueError( f"OperandDef requires a TypeVar but got {repr(type_var)}") self.owner = None # type: Optional["LinalgOpDef"] self.type_var = type_var self.size_exprs = size_exprs self.index_dims = index_dims + self.default_vals = default_vals self.kind = kind self.name = None # type: Optional[str] self.registered_index = -1 # type: int @@ -174,7 +176,7 @@ def __repr__(self): return (f"{self.name}:OperandDef(kind={self.kind.name}, " f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), " - f"index_dims={self.index_dims})") + f"index_dims={self.index_dims}, default_vals={self.default_vals})") class TensorDef: @@ -202,7 +204,7 @@ f"got {index_dims}") kind = OperandKind.OutputTensor if output else OperandKind.InputTensor self.operand_def = OperandDef( - kind, type_var, 
size_exprs=shape, index_dims=index_dims) + kind, type_var=type_var, size_exprs=shape, index_dims=index_dims) def __getitem__(self, dims) -> TensorUse: assert self.operand_def.owner, "TensorDef is not attached to an op" @@ -246,7 +248,7 @@ """ def __init__(self, type_var: TypeVar): - self.operand_def = OperandDef(OperandKind.Scalar, type_var) + self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var) @property def scalar_name(self) -> str: @@ -259,18 +261,25 @@ class IndexAttrDef: - """Index Attribute definition. + """Index attribute definition. Index attributes provide a way to define and set symbols that can be used in indexing expressions. Every attribute specifies a tuple of symbols that at - compile-time are replaced by integer values. + compile-time are replaced by integer values as well as their default values. """ - def __init__(self, *sizes: SymbolDef): + def __init__(self, *sizes: SymbolDef, default: Sequence[int]): if any(not isinstance(size, SymbolDef) for size in sizes): - raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef but got " - f"{sizes}") - self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes) + raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef " + f"but got {sizes}") + if any(not isinstance(default_val, int) for default_val in default): + raise ValueError(f"IndexAttrDef requires default values of type int " + f"but got {default}") + if len(sizes) != len(default): + raise ValueError(f"IndexAttrDef expects {len(sizes)} default values " + f"but got {len(default)}") + self.operand_def = OperandDef( + OperandKind.IndexAttr, size_exprs=sizes, default_vals=default) class Comprehension: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -45,10 +45,10 @@ def __init__(self, operand_def: OperandDef, 
shape_map: Optional[_ir.AffineMap] = None, - attribute_map: Optional[_ir.AffineMap] = None): + index_attr_map: Optional[_ir.AffineMap] = None): self.operand_def = operand_def self.shape_map = shape_map # type: Optional[_ir.AffineMap] - self.attribute_map = attribute_map # type: Optional[_ir.AffineMap] + self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap] self.indexing_map = None # type: Optional[_ir.AffineMap] @property @@ -61,24 +61,28 @@ @property def usage(self) -> str: - if self.operand_def.kind == OperandKind.Attribute: - return "IndexAttribute" + if self.operand_def.kind == OperandKind.IndexAttr: + return "IndexAttr" if self.operand_def.kind == OperandKind.OutputTensor: - return "OutputOperand" - return "InputOperand" + return "Output" + return "Input" def to_yaml_custom_dict(self): - self_dict = dict( - name=self.name, usage=self.usage, type_var=self.type_var.name) + self_dict = dict(name=self.name, usage=self.usage) + if self.type_var: + self_dict["type_var"] = self.type_var.name if self.shape_map: self_dict["shape_map"] = _serialize_affine_map(self.shape_map) - if self.attribute_map: - self_dict["attribute_map"] = _serialize_affine_map(self.attribute_map) + if self.index_attr_map: + self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map) + if self.operand_def.default_vals: + self_dict["default_vals"] = self.operand_def.default_vals return self_dict def __repr__(self): return (f"OperandDefConfig({self.operand_def}, " - f"shape_map={self.shape_map}, attribute_map={self.attribute_map}, " + f"shape_map={self.shape_map}, " + f"index_attr_map={self.index_attr_map}, " f"indexing_map={self.indexing_map})") @@ -162,7 +166,7 @@ # Collect all attribute definitions. collected_attr_defs = list() for operand in registered_operands: - if operand.kind == OperandKind.Attribute: + if operand.kind == OperandKind.IndexAttr: collected_attr_defs.append(operand) # Collect all tensors with manual indexing annotation. 
@@ -210,9 +214,9 @@ if operand_config.shape_map: operand_config.shape_map = self._normalize_affine_map( operand_config.shape_map, with_dims=False) - if operand_config.attribute_map: - operand_config.attribute_map = self._normalize_affine_map( - operand_config.attribute_map, with_dims=False) + if operand_config.index_attr_map: + operand_config.index_attr_map = self._normalize_affine_map( + operand_config.index_attr_map, with_dims=False) # Now for each write use, propagate the indexing maps from the use to the # tensor, ensuring that there are not conflicts. @@ -245,7 +249,7 @@ # Check all registered tensor and scalar operands have an indexing map. for operand in registered_operands: - if operand.kind == OperandKind.Attribute: + if operand.kind == OperandKind.IndexAttr: continue if not (operand in self.operands and self.operands[operand].indexing_map): raise ValueError(f"Failed to compute an indexing map for operand " @@ -319,9 +323,9 @@ assert local_state.local_dim_count == 0 affine_map = _ir.AffineMap.get( dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) - if operand_def.kind == OperandKind.Attribute: + if operand_def.kind == OperandKind.IndexAttr: self.operands[operand_def] = OperandDefConfig( - operand_def, attribute_map=affine_map) + operand_def, index_attr_map=affine_map) else: self.operands[operand_def] = OperandDefConfig( operand_def, shape_map=affine_map) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -39,15 +39,14 @@ *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs = op_config.ordered_operands - in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"] - out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"] - attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == 
"IndexAttribute"] + in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"] + out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"] + index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"] # Verify outs is a sequence or a list of results. if not isinstance(outs, (Sequence, OpResultList)): - raise ValueError( - f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}" - ) + raise ValueError(f"Expected named argument outs to have type Sequence or " + f"OpResultLis but got {type(outs)}") # Arity validation. if len(ins) != len(in_arg_defs): @@ -60,18 +59,19 @@ # Compute a replacement list for all attribute symbols. expressions = [] # type: Sequence[AffineExpr] replacements = [] # type: Sequence[AffineExpr] - for attr in attr_arg_defs: - if attr.name not in attrs: - raise ValueError(f"Expected named argument for the attribute {attr.name}") - attribute_values = attrs.get(attr.name) - if not all(isinstance(value, int) for value in attribute_values): - raise ValueError(f"Attribute {attr.name} needs to be of type " - f"Sequence[int] but got {type(attribute_values)}") - results = attr.attribute_map.results # type: AffineExprList - if len(attribute_values) != len(results): - raise ValueError(f"Attribute {attr.name} has length {len(results)} " - f"but got {len(attribute_values)} values") - for expr, value in zip(results, attribute_values): + for index_attr in index_attr_arg_defs: + index_attr_vals = index_attr.operand_def.default_vals + if index_attr.name in attrs: + index_attr_vals = attrs.get(index_attr.name) + assert index_attr_vals, "Index attribute has no value" + if not all(isinstance(value, int) for value in index_attr_vals): + raise ValueError(f"Attribute {index_attr.name} needs to be of type " + f"Sequence[int] but got {type(index_attr_vals)}") + results = index_attr.index_attr_map.results # type: AffineExprList + if len(index_attr_vals) != len(results): + raise ValueError(f"Attribute {index_attr.name} 
has length {len(results)} " + f"but got {len(index_attr_vals)} values") + for expr, value in zip(results, index_attr_vals): expressions.append(expr) replacements.append(AffineConstantExpr.get(value)) @@ -116,22 +116,24 @@ iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) - # Compute a dictionary storing all index attributes. - index_attributes = {} # type: Dict[str, DenseElementAttr] - for attr in attr_arg_defs: - attribute_values = attrs.get(attr.name) - array = np.array(attribute_values, dtype=np.int64) - index_attributes[attr.name] = DenseElementsAttr.get(array) + # Compute the index attributes used when emitting a named structured op. + index_attrs = {} # type: Dict[str, DenseElementAttr] + for index_attr in index_attr_arg_defs: + index_attr_vals = attrs.get(index_attr.name) + # Only forward attributes set to a non-default value. + if index_attr_vals: + array = np.array(index_attr_vals, dtype=np.int64) + index_attrs[index_attr.name] = DenseElementsAttr.get(array) return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, indexing_maps_attr, iterator_types_attr, - index_attributes, block_arg_types) + index_attrs, block_arg_types) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ + indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # An operation that accesses only scalars and scalar/rank zero tensors is @@ -182,7 +184,7 @@ op_class_name: str, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ + indexing_maps_attr, 
iterator_types_attr, index_attrs, block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # If we get here, there must exist a builtin class `op_class_name`. @@ -195,7 +197,7 @@ # Set the index attributes used to compute the indexing maps. named_op = getattr(linalg, op_class_name)(ins, outs, result_types) - for name, value in index_attributes.items(): + for name, value in index_attrs.items(): named_op.operation.attributes[name] = value linalg.fill_builtin_region(named_op.operation) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -224,8 +224,8 @@ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SW), - dilations=IndexAttrDef(S.DW)): + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): """Performs 1-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -244,8 +244,8 @@ S.C), K=TensorDef(T2, S.KH, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D convolution. Layout: @@ -270,8 +270,8 @@ IZp=ScalarDef(I32), KZp=ScalarDef(I32), O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D convolution with zero point offsets. 
Layout: @@ -297,8 +297,8 @@ S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.F, S.C, S.KH, S.KW), O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D convolution. Layout: @@ -321,8 +321,8 @@ S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW), - dilations=IndexAttrDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): """Performs 3-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -341,8 +341,8 @@ I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KW, S.IC), O=TensorDef(U, S.N, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SW), - dilations=IndexAttrDef(S.DW)): + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): """Performs depth-wise 1-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -362,8 +362,8 @@ S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs depth-wise 2-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -385,8 +385,8 @@ IZp=ScalarDef(I32), KZp=ScalarDef(I32), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -407,8 +407,8 @@ S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -429,8 +429,8 @@ IZp=ScalarDef(I32), KZp=ScalarDef(I32), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -451,8 +451,8 @@ S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs sum pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -470,8 +470,8 @@ S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -490,8 +490,8 @@ S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs unsigned max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -510,8 +510,8 @@ S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -531,8 +531,8 @@ S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs min pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -551,8 +551,8 @@ S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs unsigned min pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -571,8 +571,8 @@ S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW), - dilations=IndexAttrDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): """Performs 3D sum pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -591,8 +591,8 @@ S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW), - dilations=IndexAttrDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): """Performs 3D max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -612,8 +612,8 @@ S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW), - dilations=IndexAttrDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): """Performs 3D min pooling. 
Numeric casting is performed on the input operand, promoting it to the same diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -97,19 +97,12 @@ // ----- -func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{missing indexing map required attribute 'strides'}} - linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>} - ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) - outs(%output: memref<1x56x56x96xf32>) - return -} - -// ----- - -func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{missing indexing map required attribute 'dilations'}} - linalg.depthwise_conv_2d_nhwc_hwc {strides = dense<1> : vector<2xi64>} +// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_default_attributes +func @depthwise_conv_2d_input_nhwc_filter_default_attributes(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { + // CHECK: linalg.depthwise_conv_2d_nhwc_hwc + // CHECK-NOT: strides = + // CHECK-NOT: dilations = + linalg.depthwise_conv_2d_nhwc_hwc ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%output: memref<1x56x56x96xf32>) return @@ -118,7 +111,7 @@ // ----- func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}} + // expected-error @+1 {{incorrect element type for index attribute 'strides'}} linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = 
dense<2.0> : vector<2xf32>} ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%output: memref<1x56x56x96xf32>) @@ -128,7 +121,7 @@ // ----- func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}} + // expected-error @+1 {{incorrect shape for index attribute 'strides'}} linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> } ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%output: memref<1x56x56x96xf32>) @@ -566,7 +559,7 @@ %arg0 : tensor, %arg2 : tensor, %arg1 : tensor) -> tensor { // expected-error @+1 {{unexpected input index map for convolutions}} %0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32): + ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32): %1 = "arith.mulf"(%arg3, %arg4) : (f32, f32) -> f32 %2 = "arith.addf"(%arg5, %1) : (f32, f32) -> f32 "linalg.yield"(%2) : (f32) -> () @@ -583,7 +576,7 @@ %arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { // expected-error @+1 {{expected output/filter indexing maps to be projected permutations}} %0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32): + ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32): %1 = "arith.mulf"(%arg3, %arg4) : (f32, f32) -> f32 %2 = "arith.addf"(%arg5, %1) : (f32, f32) -> f32 "linalg.yield"(%2) : (f32) -> () diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -21,7 +21,7 @@ args: - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: T shape_map: affine_map<()[s0, s1] -> (s0, 
s1)> indexing_maps: !LinalgIndexingMapsConfig @@ -95,7 +95,7 @@ # @linalg_structured_op # def test2(I=TensorDef(T, S.M, S.N), # O=TensorDef(T, S.M, S.N, output=True), -# strides=IndexAttrDef(S.SM, S.SN)): +# strides=IndexAttrDef(S.SM, S.SN, default=[1, 2])): # """Title. # Detailed description. @@ -114,19 +114,21 @@ args: - !LinalgOperandDefConfig name: I - usage: InputOperand + usage: Input type_var: T shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)> - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: T shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)> - !LinalgOperandDefConfig name: strides - usage: IndexAttribute - type_var: I64 - attribute_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)> + usage: IndexAttr + index_attr_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)> + default_vals: + - 1 + - 2 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)> @@ -145,7 +147,8 @@ # ODS: let arguments = # ODS-NEXT: Variadic:$inputs, # ODS-NEXT: Variadic:$outputs, -# ODS-NEXT: RankedI64ElementsAttr<[2]>:$strides +# ODS-NEXT: DefaultValuedAttr +# ODS-SAME: "{ static_cast(1), static_cast(2) }">:$strides # ODS: "Attribute":$strides # ODS: $_state.addAttribute("strides", strides); @@ -169,8 +172,8 @@ # IMPL: Test2Op::hasDynamicIndexingMaps() { return true; } # IMPL: Test2Op::verifyIndexingMapRequiredAttributes() # IMPL: auto attr = op->getAttrOfType("strides") -# IMPL: "missing indexing map required attribute 'strides'" - +# IMPL: "incorrect element type for index attribute 'strides'" +# IMPL: "incorrect shape for index attribute 'strides'" # IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block) # IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 && @@ -197,11 +200,11 @@ args: - !LinalgOperandDefConfig name: value - usage: InputOperand + usage: Input type_var: T1 - !LinalgOperandDefConfig name: O - usage: OutputOperand + usage: Output type_var: U 
shape_map: affine_map<() -> ()> indexing_maps: !LinalgIndexingMapsConfig diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py --- a/mlir/test/python/dialects/linalg/opdsl/arguments.py +++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py @@ -7,15 +7,15 @@ # CHECK-LABEL: matmul # CHECK: args: # CHECK: name: A -# CHECK: usage: InputOperand +# CHECK: usage: Input # CHECK: type_var: T # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> # CHECK: name: B -# CHECK: usage: InputOperand +# CHECK: usage: Input # CHECK: type_var: T # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> # CHECK: name: C -# CHECK: usage: OutputOperand +# CHECK: usage: Output # CHECK: type_var: U # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> @linalg_structured_op @@ -30,7 +30,7 @@ # CHECK-LABEL: fill # CHECK: args: # CHECK: name: value -# CHECK: usage: InputOperand +# CHECK: usage: Input # CHECK-NOT: shape_map: # CHECK: type_var: T @linalg_structured_op @@ -42,20 +42,22 @@ # CHECK-LABEL: strided_copy # CHECK: args: # CHECK: name: I -# CHECK: usage: InputOperand +# CHECK: usage: Input # CHECK: type_var: T # CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)> # CHECK: name: O -# CHECK: usage: OutputOperand +# CHECK: usage: Output # CHECK: type_var: T # CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)> # CHECK: name: strides -# CHECK: usage: IndexAttribute -# CHECK: type_var: I64 -# CHECK: attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)> +# CHECK: usage: IndexAttr +# CHECK: index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)> +# CHECK: default_vals: +# CHECK: - 1 +# CHECK: - 2 @linalg_structured_op def strided_copy( I=TensorDef(T, S.IH, S.IW), O=TensorDef(T, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 2])): O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW] diff --git 
a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py @@ -16,8 +16,8 @@ I=TensorDef(T1, S.N, S.IH, S.IW, S.C), K=TensorDef(T2, S.KH, S.KW, S.C), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -51,8 +51,9 @@ RankedTensorType.get((2, 2, 1), f32), RankedTensorType.get((1, 2, 4, 1), i32)) def test_f32i32_conv(input, filter, init_result): + # Use default dilations and set non-default strides. return conv_poly( - input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + input, filter, outs=[init_result], strides=[2, 4]) print(module) diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py @@ -16,8 +16,8 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw]( TypeFn.cast( @@ -29,8 +29,8 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - 
dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( TypeFn.cast_unsigned( @@ -42,8 +42,8 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw]( TypeFn.cast( @@ -55,8 +55,8 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( TypeFn.cast_unsigned( diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -132,7 +132,7 @@ %c2 = arith.constant 2 : index memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64> memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64> - memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64> + memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64> call @pooling_on_buffers(%input, %shape, %output) : (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> () @@ -421,9 +421,13 @@ @builtin.FuncOp.from_py_func( MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64), 
MemRefType.get((1, 2, 4, 1), i32)) + # Set the strides and use the default dilations. def pooling_on_buffers(input, shape, output): linalg.pooling_nhwc_min( - input, shape, outs=[output], strides=[2, 4], dilations=[1, 2]) + input, + shape, + outs=[output], + strides=[2, 4]) execution_engine = ExecutionEngine(transform(module, pooling_boiler)) @@ -451,13 +455,13 @@ @builtin.FuncOp.from_py_func( MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64), MemRefType.get((1, 2, 4, 1), i32)) + # Set the strides and use the default dilations. def pooling_on_buffers(input, shape, output): linalg.pooling_nhwc_min( input, shape, outs=[output], strides=[2, 4], - dilations=[1, 2], emit_generic=True) execution_engine = ExecutionEngine(transform(module, pooling_boiler)) diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -61,14 +61,15 @@ AffineMap affineMap() { return affineMapAttr.getValue(); } }; -enum class LinalgOperandDefUsage { input, output, attribute }; +enum class LinalgOperandDefUsage { Input, Output, IndexAttr }; struct LinalgOperandDef { std::string name; LinalgOperandDefUsage usage; - std::string typeVar; + Optional typeVar; Optional shapeMap; - Optional attributeMap; + Optional indexAttrMap; + Optional> defaultVals; }; enum class LinalgIteratorTypeDef { @@ -175,18 +176,21 @@ /// the argument. Only tensor arguments have a `shape_map`. Each shape must /// be normalized over the same list of symbols and have no dimension /// inputs. -/// - `attribute_map`: An optional AffineMap from all op symbols to the -/// attribute symbols. During op creation these symbols are replaced by the -/// corresponding `name` attribute values. Only attribute arguments have -/// an `attribute_map`. 
+/// - `index_attr_map`: An optional AffineMap from all op symbols to the +/// index attribute symbols. During op creation these symbols are replaced +/// by the corresponding `name` index attribute values. Only index attribute +/// arguments have an `index_attr_map`. +/// - `default_vals`: An optional default initialization for index attribute +/// arguments. template <> struct MappingTraits { static void mapping(IO &io, LinalgOperandDef &info) { io.mapRequired("name", info.name); io.mapRequired("usage", info.usage); - io.mapRequired("type_var", info.typeVar); + io.mapOptional("type_var", info.typeVar); io.mapOptional("shape_map", info.shapeMap); - io.mapOptional("attribute_map", info.attributeMap); + io.mapOptional("index_attr_map", info.indexAttrMap); + io.mapOptional("default_vals", info.defaultVals); } }; @@ -194,9 +198,9 @@ template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgOperandDefUsage &value) { - io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input); - io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output); - io.enumCase(value, "IndexAttribute", LinalgOperandDefUsage::attribute); + io.enumCase(value, "Input", LinalgOperandDefUsage::Input); + io.enumCase(value, "Output", LinalgOperandDefUsage::Output); + io.enumCase(value, "IndexAttr", LinalgOperandDefUsage::IndexAttr); } }; @@ -395,7 +399,10 @@ // Search all argument types. for (const auto &it : llvm::enumerate(args)) { - if (it.value().typeVar == typeVar) + if (it.value().usage != LinalgOperandDefUsage::Input && + it.value().usage != LinalgOperandDefUsage::Output) + continue; + if (it.value().typeVar.getValue() == typeVar) return llvm::formatv("block.getArgument({0}).getType()", it.index()) .str(); } @@ -674,20 +681,32 @@ // Assemble the attribute specific logic required for the op definition. 
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { - return arg.usage == LinalgOperandDefUsage::attribute; + return arg.usage == LinalgOperandDefUsage::IndexAttr; })) { SmallVector attrDefs; SmallVector attrParams; SmallVector attrStmts; for (LinalgOperandDef &arg : opConfig.structuredOp->args) { - if (arg.usage != LinalgOperandDefUsage::attribute) + if (arg.usage != LinalgOperandDefUsage::IndexAttr) continue; - assert(arg.attributeMap.hasValue() && arg.typeVar == "I64"); - static const char defFmt[] = "RankedI64ElementsAttr<[{0}]>:${1}"; + assert(arg.indexAttrMap.hasValue()); + assert(arg.defaultVals.hasValue()); + size_t size = arg.indexAttrMap->affineMap().getNumResults(); + assert(arg.defaultVals.getValue().size() == size); + static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>"; + static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}"; static const char paramFmt[] = "\"Attribute\":${0}"; static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; - attrDefs.push_back(llvm::formatv( - defFmt, arg.attributeMap->affineMap().getNumResults(), arg.name)); + std::string defaultVals; + llvm::raw_string_ostream ss(defaultVals); + ss << "{ "; + llvm::interleave( + arg.defaultVals.getValue(), ss, + [&](int64_t val) { ss << "static_cast(" << val << ")"; }, + ", "); + ss << " }"; + attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size), + ss.str(), arg.name)); attrParams.push_back(llvm::formatv(paramFmt, arg.name)); attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); } @@ -725,7 +744,7 @@ // Compute the number of scalar and tensor arguments. int64_t numOfArgs = llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { - return arg.usage != LinalgOperandDefUsage::attribute; + return arg.usage != LinalgOperandDefUsage::IndexAttr; }); // An operation that accesses only scalars and scalar/rank zero tensors is @@ -796,11 +815,11 @@ )FMT"; // Update all symbol bindings mapped to an attribute. 
for (LinalgOperandDef &arg : opConfig.structuredOp->args) { - if (arg.usage != LinalgOperandDefUsage::attribute) + if (arg.usage != LinalgOperandDefUsage::IndexAttr) continue; - assert(arg.attributeMap.hasValue()); + assert(arg.indexAttrMap.hasValue()); for (auto &en : - llvm::enumerate(arg.attributeMap->affineMap().getResults())) { + llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) { if (auto symbol = en.value().dyn_cast()) { symbolBindings[symbol.getPosition()] = llvm::formatv(structuredOpAccessAttrFormat, arg.name, @@ -889,31 +908,26 @@ // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes() if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { - return arg.usage == LinalgOperandDefUsage::attribute; + return arg.usage == LinalgOperandDefUsage::IndexAttr; })) { std::vector attrVerifications; for (LinalgOperandDef &arg : opConfig.structuredOp->args) { - if (arg.usage != LinalgOperandDefUsage::attribute) + if (arg.usage != LinalgOperandDefUsage::IndexAttr) continue; - assert(arg.attributeMap.hasValue() && arg.typeVar == "I64"); + assert(arg.indexAttrMap.hasValue()); // Verify index attribute. 
Paramters: // {0}: Attribute name // {1}: Attribute size static const char attrFmt[] = R"FMT( if (auto attr = op->getAttrOfType("{0}")) {{ if (!attr.getType().getElementType().isInteger(64)) - return op->emitError( - "incorrect element type for indexing map required attribute '{0}'"); + return op->emitError("incorrect element type for index attribute '{0}'"); if (attr.getType().getShape() != ArrayRef{{ {1} }) - return op->emitError( - "incorrect shape for indexing map required attribute '{0}'"); -} else { - return op->emitError( - "missing indexing map required attribute '{0}'"); + return op->emitError("incorrect shape for index attribute '{0}'"); } )FMT"; attrVerifications.push_back(llvm::formatv( - attrFmt, arg.name, arg.attributeMap->affineMap().getNumResults())); + attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults())); } // Generates the verifyIndexingMapRequiredAttributes method. Parameters: @@ -953,7 +967,7 @@ int localCounter = 0; SmallVector stmts; for (LinalgOperandDef &arg : args) { - if (arg.usage != LinalgOperandDefUsage::output) + if (arg.usage != LinalgOperandDefUsage::Output) continue; // Find the assignment that correlates with the argument.