diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,3 +1,4 @@
+
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
@@ -15,17 +16,17 @@
     name: A
     usage: InputOperand
     type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
   - !LinalgOperandDefConfig
     name: B
     usage: InputOperand
     type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
   - !LinalgOperandDefConfig
     name: C
     usage: OutputOperand
     type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@@ -77,17 +78,17 @@
     name: A
     usage: InputOperand
     type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
   - !LinalgOperandDefConfig
     name: B
     usage: InputOperand
    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
   - !LinalgOperandDefConfig
     name: C
     usage: OutputOperand
     type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
@@ -201,17 +202,17 @@
     name: y
     usage: InputOperand
     type_var: T1
-    shape_map: affine_map<()[s0, s1] -> (s1)>
+    shape_map: affine_map<()[s0, s1] -> (s0)>
   - !LinalgOperandDefConfig
     name: A
     usage: InputOperand
     type_var: T2
-    shape_map: affine_map<()[s0, s1] -> (s1, s0)>
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
   - !LinalgOperandDefConfig
     name: x
     usage: OutputOperand
     type_var: U
-    shape_map: affine_map<()[s0, s1] -> (s0)>
+    shape_map: affine_map<()[s0, s1] -> (s1)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d1)>
@@ -321,19 +322,19 @@
     usage: InputOperand
     type_var: T1
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s4, s5, s3)>
+      (s0, s1, s2, s3)>
   - !LinalgOperandDefConfig
     name: K
     usage: InputOperand
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s6, s7, s3)>
+      (s4, s5, s3)>
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s1, s2, s3)>
+      (s0, s6, s7, s3)>
   - !LinalgOperandDefConfig
     name: strides
     usage: IndexAttribute
@@ -349,18 +350,18 @@
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d0, d1 * s8 + d4 * s10, d2 * s9 + d5 * s11, d3)>
+      s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d4, d5, d3)>
+      s10, s11] -> (d3, d4, d5)>
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d0, d1, d2, d3)>
+      s10, s11] -> (d0, d1, d2, d5)>
   iterator_types:
   - parallel
   - parallel
   - parallel
-  - parallel
   - reduction
   - reduction
+  - parallel
   assignments:
   - !ScalarAssign
     arg: O
@@ -402,45 +403,45 @@
     usage: InputOperand
     type_var: T1
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s4, s5, s3)>
+      (s0, s1, s2, s3)>
   - !LinalgOperandDefConfig
     name: K
     usage: InputOperand
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s10, s11)>
+      (s4, s5)>
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s1, s2, s3)>
+      (s0, s6, s7, s3)>
   - !LinalgOperandDefConfig
     name: strides
     usage: IndexAttribute
     type_var: I64
     attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
-      -> (s6, s7)>
+      -> (s8, s9)>
   - !LinalgOperandDefConfig
     name: dilations
     usage: IndexAttribute
     type_var: I64
     attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
-      -> (s8, s9)>
+      -> (s10, s11)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d2, d3 * s6 + d0 * s8, d4 * s7 + d1 * s9, d5)>
+      s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d0, d1)>
+      s10, s11] -> (d3, d4)>
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d2, d3, d4, d5)>
+      s10, s11] -> (d0, d1, d2, d5)>
   iterator_types:
-  - reduction
-  - reduction
   - parallel
   - parallel
   - parallel
+  - reduction
+  - reduction
   - parallel
   assignments:
   - !ScalarAssign
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -32,13 +32,13 @@
     """Visits all tensor expression reachable by the expression."""
     callback(self)

-  def _get_all_dim_defs(self) -> Set[DimDef]:
-    """Recursively gets all DimDef affine expressions that are referenced."""
+  def collect_dim_uses(self, uses: Set["DimDef"]):
+    """Collects all DimDefs reachable through this expression."""
     results = set()

     def visit_dim_def(dim_def):
       if isinstance(dim_def, DimDef):
-        results.add(dim_def)
+        uses.add(dim_def)

     def visit_affine_exprs(expr):
       if isinstance(expr, TensorUse):
@@ -49,7 +49,6 @@
           ind.visit_affine_exprs(visit_dim_def)

     self.visit_tensor_exprs(visit_affine_exprs)
-    return results

   def collect_tensor_uses(self, uses: Set["TensorUse"]):
     """Collects all TensorUses reachable through this expression."""
@@ -126,8 +125,10 @@
     reduced into. Any indices referenced on the rhs and not in self are
     considered reduction dims and will be ordered as encountered on the rhs.
     """
-    rhs_dims = rhs._get_all_dim_defs()
-    lhs_dims = self._get_all_dim_defs()
+    rhs_dims = set()
+    lhs_dims = set()
+    rhs.collect_dim_uses(rhs_dims)
+    self.collect_dim_uses(lhs_dims)
     return rhs_dims - lhs_dims

   def __repr__(self):
@@ -202,7 +203,7 @@
                        f"number of index_dims {len(index_dims)}")
     if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
       raise ValueError(f"TensorDef requires index dims of type DimDef but "
-                       f"got {type(index_dims)}")
+                       f"got {index_dims}")
     kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
     self.operand_def = OperandDef(
         kind, type_var, size_exprs=shape, index_dims=index_dims)
@@ -273,7 +274,7 @@
   def __init__(self, *sizes: SymbolDef):
     if any(not isinstance(size, SymbolDef) for size in sizes):
       raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got "
-                       f"{type(sizes)}")
+                       f"{sizes}")
     self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)


@@ -516,6 +517,7 @@
     self.metadata = OpMetadataDef(
         name=name, cpp_class_name=cpp_class_name, doc=doc)
     self.registered_operands = dict()  # type: Dict[str, OperandDef]
+    self.domain = list()  # type: List[DimDef]
     self.comprehensions = list()  # type: List[Comprehension]
     self._affine_state = AffineBuildState()

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -115,6 +115,7 @@

   def __init__(self,
                comprehension: Comprehension,
+               domain: Sequence[DimDef],
                registered_operands: Sequence[OperandDef],
               context: Optional[_ir.Context] = None):
     self.context = context if context is not None else _ir.Context()
@@ -123,10 +124,11 @@
     self.operands = dict()  # type: Dict[OperandDef, OperandDefConfig]
     self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]

-    # Compute the ordered set of writes and collect the tensor, capture, and
-    # index uses.
+    # Compute the ordered set of writes and collect the tensor, capture, dims,
+    # and index uses.
     collected_tensor_uses = set()
     collected_scalar_uses = set()
+    collected_dim_uses = set()
     collected_indices = set()
     for write_use, read_use in zip(comprehension.definitions,
                                    comprehension.values):
@@ -136,8 +138,28 @@
       collected_tensor_uses.add(write_use)
       read_use.collect_tensor_uses(collected_tensor_uses)
       read_use.collect_scalar_uses(collected_scalar_uses)
+      read_use.collect_dim_uses(collected_dim_uses)
+      write_use.collect_dim_uses(collected_dim_uses)
       read_use.collect_indices(collected_indices)

+    # Set domain to the sorted list of uses if no domain annotation is given.
+    if not domain:
+      domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
+
+    # Verify the domain dimensions match the used dimensions.
+    if (len(domain) != len(collected_dim_uses) or
+        any(dim not in collected_dim_uses for dim in domain)):
+      raise ValueError(f"Expected the annotated domain dimensions {domain} to "
+                       f"match the set of dimension used by the tensor "
+                       f"comprehension {collected_dim_uses}")
+
+    # Instantiate the dimensions in the given order.
+    with self.context:
+      local_state = AffineBuildState(
+          global_state=self.affine_state, allow_new_symbols=False)
+      for dim in domain:
+        dim.build(state=local_state)
+
     # Collect all attribute definitions.
     collected_attr_defs = list()
     for operand in registered_operands:
@@ -148,18 +170,32 @@
     collected_index_defs = list()
     for operand in registered_operands:
       if operand.index_dims:
+        if any(dim not in collected_dim_uses for dim in operand.index_dims):
+          raise ValueError(f"Expected all index dims {operand.index_dims} of "
+                           f"operand {operand.name} to have uses.")
         collected_index_defs.append(operand)

-    # Add all definitions before uses, so process twice.
+    # Collect the operand definitions of all tensor/scalar uses, attributes, and
+    # shape-only tensors.
+    all_operand_defs = list()
     for use in collected_tensor_uses:
-      self.add_operand(use.operand_def)
+      all_operand_defs.append(use.operand_def)
     for use in collected_scalar_uses:
-      self.add_operand(use.operand_def)
+      all_operand_defs.append(use.operand_def)
     for definition in collected_attr_defs:
-      self.add_operand(definition)
+      all_operand_defs.append(definition)
+    for definition in collected_index_defs:
+      all_operand_defs.append(definition)
+
+    # Add all operands in registration order to ensure the symbols are
+    # registered in the order they appear.
+    all_operand_defs = sorted(
+        all_operand_defs, key=lambda operand_def: operand_def.registered_index)
+    for operand_def in all_operand_defs:
+      self.add_operand(operand_def)
+
+    # Add all shape-only tensor index_dim annotations and all tensor uses.
     for definition in collected_index_defs:
-      if definition not in self.operands:
-        self.add_operand(definition)
       self.add_indexed_operand(definition)
     for use in collected_tensor_uses:
       self.add_tensor_use(use)
@@ -396,7 +432,7 @@
       LinalgOpConfig(
           tc_op_def.metadata,
           structured_op=LinalgStructuredOpConfig(
-              tc_op_def.comprehensions[0],
+              tc_op_def.comprehensions[0], tc_op_def.domain,
              tc_op_def.registered_operands.values(), context)),
   ]

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -132,3 +132,11 @@

 def implements(*interfaces: OpInterfaceDef):
   current_op_def().metadata.implements.extend(interfaces)
+
+
+def domain(*dimensions: DimDef):
+  if current_op_def().domain:
+    raise ValueError(f"Expected only one set of domain dimensions per operator")
+  if any(not isinstance(dim, DimDef) for dim in dimensions):
+    raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
+  current_op_def().domain.extend(dimensions)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -16,6 +16,7 @@
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.m, D.n, D.k)
   implements(ContractionOpInterface)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])

@@ -30,6 +31,7 @@
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.b, D.m, D.n, D.k)
   implements(ContractionOpInterface)
   C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n])

@@ -44,6 +46,7 @@
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.m, D.n)
   implements(ContractionOpInterface)
   x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n])

@@ -58,6 +61,7 @@
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.n, D.m)
   implements(ContractionOpInterface)
   x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n])

@@ -86,6 +90,7 @@
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
            D.c]) * cast(U, K[D.kh, D.kw, D.c])
@@ -103,6 +108,7 @@
   Numeric casting is performed on the input operand, promoting it to the same
   data type as the accumulator/output.
   """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])

@@ -123,6 +129,7 @@
   element seed the random number generation. The min and max operands limit
   the range of the generated random numbers.
   """
+  domain(D.m, D.n)
   multiplier = cast(I32, const(1103515245))
   increment = cast(I32, const(12345))
   rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -9,15 +9,15 @@
 # CHECK: name: A
 # CHECK: usage: InputOperand
 # CHECK: type_var: T
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
 # CHECK: name: B
 # CHECK: usage: InputOperand
 # CHECK: type_var: T
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
 # CHECK: name: C
 # CHECK: usage: OutputOperand
 # CHECK: type_var: U
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
@@ -44,11 +44,11 @@
 # CHECK: name: I
 # CHECK: usage: InputOperand
 # CHECK: type_var: T
-# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
 # CHECK: name: O
 # CHECK: usage: OutputOperand
 # CHECK: type_var: T
-# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
 # CHECK: name: strides
 # CHECK: usage: IndexAttribute
 # CHECK: type_var: I64
@@ -58,4 +58,4 @@
     I=TensorDef(T, S.IH, S.IW),
     O=TensorDef(T, S.OH, S.OW, output=True),
     strides=AttributeDef(S.SH, S.SW)):
-  O[D.oh, D.ow] = I[D.h * S.SH, D.w * S.SW]
+  O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -16,6 +16,7 @@
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(T, S.M, S.N, output=True)):
+  domain(D.m, D.n, D.k)
   C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]


@@ -24,6 +25,7 @@
     A=TensorDef(T1, S.M, S.K),
     B=TensorDef(T2, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
+  domain(D.m, D.n, D.k)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])


@@ -34,6 +36,7 @@
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
     strides=AttributeDef(S.SH, S.SW),
     dilations=AttributeDef(S.DH, S.DW)):
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
            D.c]) * cast(U, K[D.kh, D.kw, D.c])
@@ -46,6 +49,7 @@
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
     strides=AttributeDef(S.SH, S.SW),
     dilations=AttributeDef(S.DH, S.DW)):
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])

@@ -84,14 +88,12 @@
     # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

     # Convolution indexing maps.
-    # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
-    # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
-    # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+    # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
+    # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+    # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>

     # Pooling indexing maps.
-    # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3 * 2 + d0, d4 * 4 + d1 * 2, d5)>
-    # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
-    # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+    # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>

     # CHECK-LABEL: func @test_matmul_mono
     # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
@@ -197,7 +199,7 @@
     # CHECK-LABEL: @test_f32i32_conv
     # CHECK: linalg.generic
     # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
     # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
     # CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
     # CHECK-NEXT: %[[FILTER_CAST:.+]] = fptosi %[[FILTER:.+]] : f32 to i32
@@ -215,8 +217,8 @@

     # CHECK-LABEL: @test_f32i32_pooling
     # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["reduction", "reduction", "parallel", "parallel", "parallel", "parallel"]
+    # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
     # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
     # CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
     # CHECK-NEXT: %[[SUM:.+]] = addi %[[OUT]], %[[IN_CAST]] : i32
diff --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/interfaces.py
--- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py
+++ b/mlir/test/python/dialects/linalg/opdsl/interfaces.py
@@ -2,13 +2,15 @@

 from mlir.dialects.linalg.opdsl.lang import *

+
 # CHECK: ---
 # CHECK-LABEL: matmul
 # CHECK: implements:
 # CHECK-NEXT: - LinalgContractionOpInterface
 @linalg_structured_op
-def matmul(A=TensorDef(T, S.M, S.K),
-           B=TensorDef(T, S.K, S.N),
-           C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+    A=TensorDef(T, S.M, S.K),
+    B=TensorDef(T, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
   implements(ContractionOpInterface)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
diff --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -7,9 +7,9 @@
 # dims auto discovered emits the right shape, indexing maps and iterator types.
 # CHECK: ---
 # CHECK-LABEL: matmul
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
 # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
 # CHECK: static_indexing_maps:
 # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
 # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
@@ -23,6 +23,7 @@
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
+  domain(D.m, D.n, D.k)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])


@@ -43,22 +44,25 @@
 def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
   C[None] += cast(U, A[D.m]) * cast(U, B[D.m])

+
 # Verifies that the index_dims of shape-only operands translate to correct
 # indexing maps.
 # CHECK: ---
 # CHECK-LABEL: pool
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)>
 # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1)>
 # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2)>
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)>
 # CHECK: static_indexing_maps:
-# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1 * 2 + d0)>
-# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0 * 2 + d1)>
 # CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)>
 # CHECK: iterator_types:
-# CHECK-NEXT: - reduction
 # CHECK-NEXT: - parallel
+# CHECK-NEXT: - reduction
 @linalg_structured_op
-def pool(I=TensorDef(T, S.I),
-         K=TensorDef(T, S.K, index_dims=[D.k]),
-         O=TensorDef(U, S.O, output=True)):
+def pool(
+    I=TensorDef(T, S.I),
+    K=TensorDef(T, S.K, index_dims=[D.k]),
+    O=TensorDef(U, S.O, output=True)):
+  domain(D.o, D.k)
   O[D.o] += cast(U, I[D.o * 2 + D.k])
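Usage sketch (not part of the patch): the new domain annotation added to dsl.py is called inside an op definition, before the comprehension, to pin the iteration-space order explicitly; if it is omitted, config.py falls back to the dimensions collected from the comprehension sorted by name. The example mirrors the updated matmul in core_named_ops.py; only the op name matmul_example is illustrative.

from mlir.dialects.linalg.opdsl.lang import *


@linalg_structured_op
def matmul_example(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True)):
  # Fix the loop order to (m, n, k); without the annotation the dimensions
  # would be instantiated in name-sorted order.
  domain(D.m, D.n, D.k)
  implements(ContractionOpInterface)
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])

Per the new checks, calling domain twice in one definition, passing anything other than DimDef values, or annotating a dimension set that does not match the dimensions used by the comprehension raises a ValueError.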