diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -15,17 +15,17 @@ name: A usage: InputOperand type_var: T1 - shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> + shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> - !LinalgOperandDefConfig name: B usage: InputOperand type_var: T2 - shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)> + shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> - !LinalgOperandDefConfig name: C usage: OutputOperand type_var: U - shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> + shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> @@ -77,17 +77,17 @@ name: A usage: InputOperand type_var: T1 - shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> - !LinalgOperandDefConfig name: B usage: InputOperand type_var: T2 - shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)> + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)> - !LinalgOperandDefConfig name: C usage: OutputOperand type_var: U - shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)> @@ -201,25 +201,25 @@ name: y usage: InputOperand type_var: T1 - shape_map: affine_map<()[s0, s1] -> (s1)> + shape_map: affine_map<()[s0, s1] -> (s0)> - !LinalgOperandDefConfig name: A usage: InputOperand type_var: T2 - shape_map: affine_map<()[s0, s1] -> (s1, s0)> + shape_map: affine_map<()[s0, s1] -> (s0, s1)> - !LinalgOperandDefConfig name: x usage: OutputOperand type_var: U - shape_map: affine_map<()[s0, s1] -> (s0)> + shape_map: affine_map<()[s0, s1] -> (s1)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - - affine_map<(d0, d1)[s0, s1] -> (d1)> - - affine_map<(d0, d1)[s0, s1] -> (d1, d0)> - affine_map<(d0, d1)[s0, s1] -> (d0)> + - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> + - affine_map<(d0, d1)[s0, s1] -> (d1)> iterator_types: - - parallel - reduction + - parallel assignments: - !ScalarAssign arg: x @@ -321,19 +321,19 @@ usage: InputOperand type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> - (s0, s4, s5, s3)> + (s0, s1, s2, s3)> - !LinalgOperandDefConfig name: K usage: InputOperand type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> - (s6, s7, s3)> + (s4, s5, s3)> - !LinalgOperandDefConfig name: O usage: OutputOperand type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> - (s0, s1, s2, s3)> + (s0, s6, s7, s3)> - !LinalgOperandDefConfig name: strides usage: IndexAttribute @@ -349,18 +349,18 @@ indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, - s10, s11] -> (d0, d1 * s8 + d4 * s10, d2 * s9 + d5 * s11, d3)> + s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)> - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, - s10, s11] -> (d4, d5, d3)> + s10, s11] -> (d3, d4, d5)> - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, - s10, s11] -> (d0, d1, d2, d3)> + s10, s11] -> (d0, d1, d2, d5)> iterator_types: - parallel - parallel - parallel - - parallel - reduction - reduction + - parallel assignments: - !ScalarAssign arg: O @@ -402,45 +402,45 @@ usage: InputOperand type_var: T1 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> - (s0, s4, s5, s3)> + (s0, s1, s2, s3)> - !LinalgOperandDefConfig name: K usage: InputOperand type_var: T2 shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> - (s10, s11)> + (s4, s5)> - !LinalgOperandDefConfig name: O usage: OutputOperand type_var: U shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> - (s0, s1, s2, s3)> + (s0, s6, s7, s3)> - !LinalgOperandDefConfig name: strides usage: IndexAttribute type_var: I64 attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] - -> (s6, s7)> + -> (s8, s9)> - !LinalgOperandDefConfig name: dilations usage: IndexAttribute type_var: I64 attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] - -> (s8, s9)> + -> (s10, s11)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, - s10, s11] -> (d2, d3 * s6 + d0 * s8, d4 * s7 + d1 * s9, d5)> + s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)> - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, - s10, s11] -> (d0, d1)> + s10, s11] -> (d3, d4)> - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, - s10, s11] -> (d2, d3, d4, d5)> + s10, s11] -> (d0, d1, d2, d5)> iterator_types: - - reduction - - reduction - parallel - parallel - parallel + - reduction + - reduction - parallel assignments: - !ScalarAssign diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -32,13 +32,13 @@ """Visits all tensor expression reachable by the expression.""" callback(self) - def _get_all_dim_defs(self) -> Set[DimDef]: - """Recursively gets all DimDef affine expressions that are referenced.""" + def collect_dim_uses(self, uses: Set["DimDef"]): + """Collects all DimDefs reachable through this expression.""" results = set() def visit_dim_def(dim_def): if isinstance(dim_def, DimDef): - results.add(dim_def) + uses.add(dim_def) def visit_affine_exprs(expr): if isinstance(expr, TensorUse): @@ -49,7 +49,6 @@ ind.visit_affine_exprs(visit_dim_def) self.visit_tensor_exprs(visit_affine_exprs) - return results def collect_tensor_uses(self, uses: Set["TensorUse"]): """Collects all TensorUses reachable through this expression.""" @@ -126,8 +125,10 @@ reduced into. Any indices referenced on the rhs and not in self are considered reduction dims and will be ordered as encountered on the rhs. """ - rhs_dims = rhs._get_all_dim_defs() - lhs_dims = self._get_all_dim_defs() + rhs_dims = set() + lhs_dims = set() + rhs.collect_dim_uses(rhs_dims) + self.collect_dim_uses(lhs_dims) return rhs_dims - lhs_dims def __repr__(self): @@ -516,6 +517,7 @@ self.metadata = OpMetadataDef( name=name, cpp_class_name=cpp_class_name, doc=doc) self.registered_operands = dict() # type: Dict[str, OperandDef] + self.domain = list() # type: List[DimDef] self.comprehensions = list() # type: List[Comprehension] self._affine_state = AffineBuildState() diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -115,6 +115,7 @@ def __init__(self, comprehension: Comprehension, + domain: Sequence[DimDef], registered_operands: Sequence[OperandDef], context: Optional[_ir.Context] = None): self.context = context if context is not None else _ir.Context() @@ -123,10 +124,11 @@ self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] - # Compute the ordered set of writes and collect the tensor, capture, and - # index uses. + # Compute the ordered set of writes and collect the tensor, capture, dims, + # and index uses. collected_tensor_uses = set() collected_scalar_uses = set() + collected_dim_uses = set() collected_indices = set() for write_use, read_use in zip(comprehension.definitions, comprehension.values): @@ -136,8 +138,27 @@ collected_tensor_uses.add(write_use) read_use.collect_tensor_uses(collected_tensor_uses) read_use.collect_scalar_uses(collected_scalar_uses) + read_use.collect_dim_uses(collected_dim_uses) + write_use.collect_dim_uses(collected_dim_uses) read_use.collect_indices(collected_indices) + # Verify the domain dimensions match the used dimensions. + if any(dim not in collected_dim_uses for dim in domain): + raise ValueError(f"Expected all domain dimensions {domain} to " + f"have uses") + if domain and len(domain) != len(collected_dim_uses): + raise ValueError(f"Expected the number of domain dimensions " + f"{len(domain)} to match the number of used dims " + f"{len(collected_dim_uses)}") + + # Instantiate the dimensions in the given order. + if domain: + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False) + for dim in domain: + dim.build(state=local_state) + # Collect all attribute definitions. collected_attr_defs = list() for operand in registered_operands: @@ -148,18 +169,32 @@ collected_index_defs = list() for operand in registered_operands: if operand.index_dims: + if any(dim not in collected_dim_uses for dim in operand.index_dims): + raise ValueError(f"Expected all index dims {operand.index_dims} of " + f"operand {operand.name} to have uses.") collected_index_defs.append(operand) - # Add all definitions before uses, so process twice. + # Collect the operand definitions of all tensor/scalar uses, attributes, and + # shape-only tensors. + all_operand_defs = list() for use in collected_tensor_uses: - self.add_operand(use.operand_def) + all_operand_defs.append(use.operand_def) for use in collected_scalar_uses: - self.add_operand(use.operand_def) + all_operand_defs.append(use.operand_def) for definition in collected_attr_defs: - self.add_operand(definition) + all_operand_defs.append(definition) + for definition in collected_index_defs: + all_operand_defs.append(definition) + + # Add all operands in registration order to ensure the symbols are + # registered in the order they appear. + all_operand_defs = sorted( + all_operand_defs, key=lambda operand_def: operand_def.registered_index) + for operand_def in all_operand_defs: + self.add_operand(operand_def) + + # Add all shape-only tensor index_dim annotations and all tensor uses. for definition in collected_index_defs: - if definition not in self.operands: - self.add_operand(definition) self.add_indexed_operand(definition) for use in collected_tensor_uses: self.add_tensor_use(use) @@ -396,7 +431,7 @@ LinalgOpConfig( tc_op_def.metadata, structured_op=LinalgStructuredOpConfig( - tc_op_def.comprehensions[0], + tc_op_def.comprehensions[0], tc_op_def.domain, tc_op_def.registered_operands.values(), context)), ] diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -132,3 +132,10 @@ def implements(*interfaces: OpInterfaceDef): current_op_def().metadata.implements.extend(interfaces) + + +def domain(*dimensions: DimDef): + if current_op_def().domain: + raise ValueError(f"Expected only one set of domain dimensions per operator") + current_op_def().domain.extend(dimensions) + return [dimensions] diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -17,7 +17,8 @@ them to the same data type as the accumulator/output. """ implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + for m, n, k in domain(D.m, D.n, D.k): + C[m, n] += cast(U, A[m, k]) * cast(U, B[k, n]) @linalg_structured_op @@ -31,7 +32,8 @@ them to the same data type as the accumulator/output. """ implements(ContractionOpInterface) - C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n]) + for b, m, n, k in domain(D.b, D.m, D.n, D.k): + C[b, m, n] += cast(U, A[b, m, k]) * cast(U, B[b, k, n]) @linalg_structured_op @@ -45,7 +47,8 @@ them to the same data type as the accumulator/output. """ implements(ContractionOpInterface) - x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n]) + for m, n in domain(D.m, D.n): + x[m] += cast(U, A[m, n]) * cast(U, y[n]) @linalg_structured_op @@ -59,7 +62,8 @@ them to the same data type as the accumulator/output. """ implements(ContractionOpInterface) - x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n]) + for m, n in domain(D.m, D.n): + x[n] += cast(U, y[m]) * cast(U, A[m, n]) @linalg_structured_op @@ -71,7 +75,8 @@ them to the same data type as the accumulator/output. """ implements(ContractionOpInterface) - C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) + for m in domain(D.m): + C[None] += cast(U, A[m]) * cast(U, B[m]) @linalg_structured_op @@ -86,9 +91,10 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ - O[D.n, D.oh, D.ow, D.c] += cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * cast(U, K[D.kh, D.kw, D.c]) + for n, oh, ow, kh, kw, c in domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c): + O[n, oh, ow, c] += \ + cast(U, I[n, oh * S.SH + kh * S.DH, ow * S.SW + kw * S.DW, c]) * \ + cast(U, K[kh, kw, c]) @linalg_structured_op @@ -103,8 +109,9 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ - O[D.n, D.oh, D.ow, D.c] += cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) + for n, oh, ow, kh, kw, c in domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c): + O[n, oh, ow, c] += \ + cast(U, I[n, oh * S.SH + kh * S.DH, ow * S.SW + kw * S.DW, c]) @linalg_structured_op @@ -123,11 +130,12 @@ element seed the random number generation. The min and max operands limit the range of the generated random numbers. """ - multiplier = cast(I32, const(1103515245)) - increment = cast(I32, const(12345)) - rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = cast(F64, const(2.3283064e-10)) - offset = cast(F64, const(2147483647)) - scaling = (max - min) * inv_range - O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) + for m, n in domain(D.m, D.n): + multiplier = cast(I32, const(1103515245)) + increment = cast(I32, const(12345)) + rand1 = (cast(I32, index(m)) + seed) * multiplier + increment + rand2 = (cast(I32, index(n)) + rand1) * multiplier + increment + inv_range = cast(F64, const(2.3283064e-10)) + offset = cast(F64, const(2147483647)) + scaling = (max - min) * inv_range + O[m, n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py --- a/mlir/test/python/dialects/linalg/opdsl/arguments.py +++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py @@ -9,21 +9,22 @@ # CHECK: name: A # CHECK: usage: InputOperand # CHECK: type_var: T -# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> +# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> # CHECK: name: B # CHECK: usage: InputOperand # CHECK: type_var: T -# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)> +# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> # CHECK: name: C # CHECK: usage: OutputOperand # CHECK: type_var: U -# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> +# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> @linalg_structured_op def matmul( A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + for m, n, k in domain(D.m, D.n, D.k): + C[m, n] += cast(U, A[m, k]) * cast(U, B[k, n]) # CHECK: --- @@ -35,7 +36,8 @@ # CHECK: type_var: T @linalg_structured_op def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)): - O[D.m, D.n] = value + for m, n in domain(D.m, D.n): + O[m, n] = value # CHECK: --- @@ -44,11 +46,11 @@ # CHECK: name: I # CHECK: usage: InputOperand # CHECK: type_var: T -# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)> +# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)> # CHECK: name: O # CHECK: usage: OutputOperand # CHECK: type_var: T -# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)> +# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)> # CHECK: name: strides # CHECK: usage: IndexAttribute # CHECK: type_var: I64 @@ -58,4 +60,5 @@ I=TensorDef(T, S.IH, S.IW), O=TensorDef(T, S.OH, S.OW, output=True), strides=AttributeDef(S.SH, S.SW)): - O[D.oh, D.ow] = I[D.h * S.SH, D.w * S.SW] + for oh, ow in domain(D.oh, D.ow): + O[oh, ow] = I[oh * S.SH, ow * S.SW] diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -28,7 +28,8 @@ A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + for m, n, k in domain(D.m, D.n, D.k): + C[m, n] += cast(U, A[m, k]) * cast(U, B[k, n]) # CHECK: --- @@ -56,10 +57,11 @@ # CHECK: scalar_const: '1.{{[0]*}}e+03 : f64' @linalg_structured_op def constants(O=TensorDef(T, S.M, S.K, output=True)): - pi = cast(T, const(3.1415926535897931)) - cst42 = cast(T, const(42)) - cst1000 = cast(T, const(1e+3)) - O[D.m, D.n] = pi + cst42 - cst1000 + for m, n in domain(D.m, D.n): + pi = cast(T, const(3.1415926535897931)) + cst42 = cast(T, const(42)) + cst1000 = cast(T, const(1e+3)) + O[m, n] = pi + cst42 - cst1000 # CHECK: --- @@ -74,7 +76,8 @@ # CHECK: scalar_index: 0 @linalg_structured_op def indices(O=TensorDef(T, S.M, S.K, output=True)): - O[D.m, D.n] = index(D.n) + index(D.m) + for m, n in domain(D.m, D.n): + O[m, n] = index(n) + index(m) # CHECK: --- @@ -85,4 +88,5 @@ # CHECK: scalar_arg: value @linalg_structured_op def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)): - O[D.m, D.n] = value + for m, n in domain(D.m, D.n): + O[m, n] = value diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py @@ -16,7 +16,8 @@ A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(T, S.M, S.N, output=True)): - C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n] + for m, n, k in domain(D.m, D.n, D.k): + C[m, n] += A[m, k] * B[k, n] @linalg_structured_op @@ -24,7 +25,8 @@ A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + for m, n, k in domain(D.m, D.n, D.k): + C[m, n] += cast(U, A[m, k]) * cast(U, B[k, n]) @linalg_structured_op @@ -34,9 +36,10 @@ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SH, S.SW), dilations=AttributeDef(S.DH, S.DW)): - O[D.n, D.oh, D.ow, D.c] += cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * cast(U, K[D.kh, D.kw, D.c]) + for n, oh, ow, kh, kw, c in domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c): + O[n, oh, ow, c] += \ + cast(U, I[n, oh * S.SH + kh * S.DH, ow * S.SW + kw * S.DW, c]) * \ + cast(U, K[kh, kw, c]) @linalg_structured_op @@ -46,8 +49,9 @@ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SH, S.SW), dilations=AttributeDef(S.DH, S.DW)): - O[D.n, D.oh, D.ow, D.c] += cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) + for n, oh, ow, kh, kw, c in domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c): + O[n, oh, ow, c] += \ + cast(U, I[n, oh * S.SH + kh * S.DH, ow * S.SW + kw * S.DW, c]) @linalg_structured_op @@ -56,14 +60,15 @@ max=ScalarDef(F64), seed=ScalarDef(I32), O=TensorDef(T, S.M, S.N, output=True)): - multiplier = cast(I32, const(1103515245)) - increment = cast(I32, const(12345)) - rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = cast(F64, const(2.3283064e-10)) - offset = cast(F64, const(2147483647)) - scaling = (max - min) * inv_range - O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) + for m, n in domain(D.m, D.n): + multiplier = cast(I32, const(1103515245)) + increment = cast(I32, const(12345)) + rand1 = (cast(I32, index(m)) + seed) * multiplier + increment + rand2 = (cast(I32, index(n)) + rand1) * multiplier + increment + inv_range = cast(F64, const(2.3283064e-10)) + offset = cast(F64, const(2147483647)) + scaling = (max - min) * inv_range + O[m, n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) with Context() as ctx, Location.unknown(): @@ -84,14 +89,12 @@ # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> # Convolution indexing maps. - # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)> - # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> - # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)> + # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> # Pooling indexing maps. - # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3 * 2 + d0, d4 * 4 + d1 * 2, d5)> - # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)> - # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> + # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)> # CHECK-LABEL: func @test_matmul_mono # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32> @@ -197,7 +200,7 @@ # CHECK-LABEL: @test_f32i32_conv # CHECK: linalg.generic # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]] - # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"] + # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32) # CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32 # CHECK-NEXT: %[[FILTER_CAST:.+]] = fptosi %[[FILTER:.+]] : f32 to i32 @@ -215,8 +218,8 @@ # CHECK-LABEL: @test_f32i32_pooling # CHECK: linalg.generic - # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]] - # CHECK-SAME: iterator_types = ["reduction", "reduction", "parallel", "parallel", "parallel", "parallel"] + # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]] + # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32) # CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32 # CHECK-NEXT: %[[SUM:.+]] = addi %[[OUT]], %[[IN_CAST]] : i32 diff --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/interfaces.py --- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py +++ b/mlir/test/python/dialects/linalg/opdsl/interfaces.py @@ -2,13 +2,16 @@ from mlir.dialects.linalg.opdsl.lang import * + # CHECK: --- # CHECK-LABEL: matmul # CHECK: implements: # CHECK-NEXT: - LinalgContractionOpInterface @linalg_structured_op -def matmul(A=TensorDef(T, S.M, S.K), - B=TensorDef(T, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): +def matmul( + A=TensorDef(T, S.M, S.K), + B=TensorDef(T, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + for m, n, k in domain(D.m, D.n, D.k): + C[m, n] += cast(U, A[m, k]) * cast(U, B[k, n]) diff --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py --- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py +++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py @@ -7,9 +7,9 @@ # dims auto discovered emits the right shape, indexing maps and iterator types. # CHECK: --- # CHECK-LABEL: matmul -# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> -# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)> # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> +# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> +# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> # CHECK: static_indexing_maps: # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)> @@ -23,7 +23,8 @@ A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + for m, n, k in domain(D.m, D.n, D.k): + C[m, n] += cast(U, A[m, k]) * cast(U, B[k, n]) # Verifies that assignment to a scalar (represented as [None]) is represented @@ -41,24 +42,28 @@ # CHECK-NEXT: - reduction @linalg_structured_op def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)): - C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) + for m in domain(D.m): + C[None] += cast(U, A[m]) * cast(U, B[m]) + # Verifies that the index_dims of shape-only operands translate to correct # indexing maps. # CHECK: --- # CHECK-LABEL: pool +# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)> # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1)> # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2)> -# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)> # CHECK: static_indexing_maps: -# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1 * 2 + d0)> -# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)> +# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0 * 2 + d1)> # CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1)> +# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)> # CHECK: iterator_types: -# CHECK-NEXT: - reduction # CHECK-NEXT: - parallel +# CHECK-NEXT: - reduction @linalg_structured_op -def pool(I=TensorDef(T, S.I), - K=TensorDef(T, S.K, index_dims=[D.k]), - O=TensorDef(U, S.O, output=True)): - O[D.o] += cast(U, I[D.o * 2 + D.k]) +def pool( + I=TensorDef(T, S.I), + K=TensorDef(T, S.K, index_dims=[D.k]), + O=TensorDef(U, S.O, output=True)): + for o, k in domain(D.o, D.k): + O[o] += cast(U, I[o * 2 + k])