diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -102,7 +102,7 @@
 appear in the parameter list of the operation:

 ```python
-fill(val, in_tensor, outs=[out_tensor])
+copy_and_scale(val, in_tensor, outs=[out_tensor])
 ```

 ## Attributes
@@ -251,3 +251,43 @@

 Not all functions are applicable for all numeric types, and on mismatch, op
 verification will fail.
+
+## Pointwise Computations
+
+Pointwise computations are expressible in a rank polymorphic form that supports
+arbitrary ranked operands, as long as all of them have the same rank, with a
+single operation definition.
+
+An example of a rank polymorphic operation is `fill`:
+
+```python
+@linalg_structured_op
+def fill(value=ScalarDef(T1),
+         O=TensorDef(U, output=True)):
+  O[None] = TypeFn.cast(U, value)
+```
+
+The operation sets the elements of the output tensor `O` to `value`. All
+operands are either scalars or rank zero tensors that are accessed using the
+index `None`. The operation thus performs a scalar computation that trivially
+extends to a multi-dimensional pointwise computation. As a result, we may use
+`fill` with arbitrary ranked output tensors:
+
+```python
+tensor_2d = linalg.InitTensorOp([4, 8], f32)
+tensor_3d = linalg.InitTensorOp([4, 8, 16], f32)
+fill(value, outs=[tensor_2d])
+fill(value, outs=[tensor_3d])
+```
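+
+When a rank polymorphic operation is converted to `linalg.generic` form, the
+scalar operands get an empty indexing map, the tensor operands get an identity
+indexing map, and all iterators are parallel. A sketch of the maps generated
+for a 2-d output (cf. the generalization test in this change; the `#map` names
+are illustrative):
+
+```mlir
+#map0 = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+// indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]
+```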
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2522,6 +2522,42 @@
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: fill_tensor
+  cpp_class_name: FillTensorOp
+  doc: |-
+    Fills the output tensor with the given value.
+
+    Works for arbitrary ranked output tensors since the operation performs scalar
+    accesses only and is thus rank polymorphic. Numeric casting is performed on
+    the value operand, promoting it to the same data type as the output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: value
+    usage: InputOperand
+    type_var: T1
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      type_fn:
+        fn_name: cast
+        type_var: U
+        operands:
+        - !ScalarExpression
+          scalar_arg: value
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
   cpp_class_name: FillRng2DOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -14,6 +14,7 @@

 from .scalar_expr import *
 from .config import *
+from .comprehension import *
 import numpy as np

 __all__ = [
@@ -132,6 +133,25 @@
   indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
      prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)

+  # An operation that accesses only scalars and scalar/rank zero tensors is
+  # rank polymorphic. We implement rank polymorphism by generating different
+  # indexing maps and iterators that match the rank of the first output tensor.
+  # An operation is rank polymorphic if the iteration domain has rank zero.
+  if not iterator_types_attr:
+    rank = ShapedType(outs[0].type).rank
+    iterator_types_attr = ArrayAttr.get([StringAttr.get("parallel")] * rank)
+    scalar_map = AffineMap.get(rank, 0, [])
+    tensor_map = AffineMap.get_identity(rank)
+    indexing_maps = []
+    for arg_def in all_arg_defs:
+      if arg_def.operand_def.kind == OperandKind.Scalar:
+        indexing_maps.append(scalar_map)
+      if (arg_def.operand_def.kind == OperandKind.InputTensor or
+          arg_def.operand_def.kind == OperandKind.OutputTensor):
+        indexing_maps.append(tensor_map)
+    indexing_maps_attr = ArrayAttr.get(
+        [AffineMapAttr.get(am) for am in indexing_maps])
+
   generic_op = linalg.GenericOp(
       result_tensors=result_types,
       inputs=ins,
@@ -172,19 +192,13 @@
     raise NotImplementedError(
         f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")

+  # Set the index attributes used to compute the indexing maps.
   named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
-  linalg.fill_builtin_region(named_op.operation)
-  # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
-  # attribute that the non-yaml path does not. The non-yaml path hardcodes the
-  # indexing_maps in C++ directly.
-  named_op.operation.attributes[
-      "linalg.memoized_indexing_maps"] = indexing_maps_attr
-  # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
-
-  # Additionally set all named attributes.
   for name, value in index_attributes.items():
     named_op.operation.attributes[name] = value

+  linalg.fill_builtin_region(named_op.operation)
+
   if len(result_types) == 1:
     return named_op.result
   else:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -627,6 +627,17 @@
               D.ow * S.SW + D.kw * S.DW, D.c]))


+@linalg_structured_op
+def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+  """Fills the output tensor with the given value.
+
+  Works for arbitrary ranked output tensors since the operation performs scalar
+  accesses only and is thus rank polymorphic. Numeric casting is performed on
+  the value operand, promoting it to the same data type as the output.
+  """
+  O[None] = TypeFn.cast(U, value)
+
+
 @linalg_structured_op
 def fill_rng_2d(
     min=ScalarDef(F64),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -207,6 +207,35 @@

 // -----

+func @generalize_fill_0d(%value: f64, %O: tensor<f64>) -> tensor<f64> {
+  %0 = linalg.fill_tensor ins(%value: f64) outs(%O : tensor<f64>) -> tensor<f64>
+  return %0: tensor<f64>
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
+
+// CHECK-LABEL: @generalize_fill_0d
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
+  linalg.fill_tensor ins(%value: f64) outs(%O : memref<16x32xf32>)
+  return
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @generalize_fill_2d
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+
+// -----
+
 func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -175,3 +175,55 @@

 # IMPL-NEXT:  assert(2 > 0 && block.getNumArguments() == 2 &&
 # IMPL:  yields.push_back(block.getArgument(0));
+
+# @linalg_structured_op
+# def test3(value=ScalarDef(T1),
+#           O=TensorDef(U, output=True)):
+#   """Title.
+
+#   Detailed description.
+#   """
+#   O[None] = TypeFn.cast(U, value)
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: test3
+  cpp_class_name: Test3Op
+  doc: |-
+    Title.
+
+    Detailed description.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: value
+    usage: InputOperand
+    type_var: T1
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      type_fn:
+        fn_name: cast
+        type_var: U
+        operands:
+        - !ScalarExpression
+          scalar_arg: value
+
+# IMPL:  Test3Op::iterator_types() {
+# IMPL-NEXT:  int64_t rank = getRank(getOutputOperand(0));
+
+# IMPL:  Test3Op::indexing_maps() {
+# IMPL-NEXT:  MLIRContext *context = getContext();
+# IMPL-NEXT:  AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
+# IMPL-NEXT:  AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
@@ -0,0 +1,46 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import builtin
+from mlir.dialects import linalg
+from mlir.dialects import std
+
+from mlir.dialects.linalg.opdsl.lang import *
+
+T1 = TV.T1
+T2 = TV.T2
+
+
+@linalg_structured_op
+def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+  O[None] = TypeFn.cast(U, value)
+
+
+with Context() as ctx, Location.unknown():
+  module = Module.create()
+  f32 = F32Type.get()
+  with InsertionPoint(module.body):
+
+    # Fill indexing maps.
+    # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
+    # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+    # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+    # CHECK-LABEL: @test_fill_0d
+    # CHECK: linalg.generic
+    # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+    # CHECK-SAME: iterator_types = []
+    @builtin.FuncOp.from_py_func(f32, RankedTensorType.get([], f32))
+    def test_fill_0d(value, init_result):
+      return fill_poly(value, outs=[init_result])
+
+    # CHECK-LABEL: @test_fill_2d
+    # CHECK: linalg.generic
+    # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel"]
+    @builtin.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32))
+    def test_fill_2d(value, init_result):
+      return fill_poly(value, outs=[init_result])
+
+
+print(module)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -126,7 +126,7 @@
     # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
     # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
     # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
-    # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
+    # CHECK-NEXT: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
     # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
     return linalg.matmul(lhs, rhs, outs=[init_result.result])

diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -42,13 +42,42 @@
 """

 fill_boiler = """
+func @main() -> i32 attributes {llvm.emit_c_interface} {
+  %O0 = memref.alloc() : memref<i32>
+  %O1 = memref.alloc() : memref<16xi32>
+  %O2 = memref.alloc() : memref<4x16xi32>
+
+  %val0 = arith.constant 1.0 : f32
+  %val1 = arith.constant 2.0 : f32
+  %val2 = arith.constant 3.0 : f32
+
+  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
+  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
+  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
+
+  %c0 = arith.constant 0 : index
+  %res0 = memref.load %O0[] : memref<i32>
+  %c8 = arith.constant 8 : index
+  %res1 = memref.load %O1[%c8] : memref<16xi32>
+  %c2 = arith.constant 2 : index
+  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>
+
+  %0 = arith.addi %res0, %res1 : i32
+  %1 = arith.addi %0, %res2 : i32
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %1 : i32
+}
+"""
+
+fill_rng_boiler = """
 func @main() -> i32 attributes {llvm.emit_c_interface} {
   %O = memref.alloc() : memref<4x16xi32>
   %min = arith.constant -1000.0 : f64
   %max = arith.constant 1000.0 : f64
   %seed = arith.constant 42 : i32

-  call @fill_on_buffers(%min, %max, %seed, %O) :
+  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
     (f64, f64, i32, memref<4x16xi32>) -> ()

   %c0 = arith.constant 0 : index
@@ -123,9 +152,9 @@

   # TODO: Allow cloning functions from one module to another.
   # Atm we have to resort to string concatenation.
-  mod = Module.parse(
-      str(module.operation.regions[0].blocks[0].operations[0].operation) +
-      boilerplate)
+  ops = module.operation.regions[0].blocks[0].operations
+  mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
+
   pm = PassManager.parse(
       "builtin.func(convert-linalg-to-loops, lower-affine, " +
       "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm," +
@@ -192,6 +221,76 @@


 def test_fill_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+      def fill_0d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out])
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+      def fill_1d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out])
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+      def fill_2d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out])
+
+    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: 6
+
+
+test_fill_builtin()
+
+
+def test_fill_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+      def fill_0d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out], emit_generic=True)
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+      def fill_1d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out], emit_generic=True)
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+      def fill_2d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out], emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: 6
+
+
+test_fill_generic()
+
+
+def test_fill_rng_builtin():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f64 = F64Type.get()
@@ -199,10 +298,10 @@
     with InsertionPoint(module.body):

       @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
-      def fill_on_buffers(min, max, seed, out):
+      def fill_rng_on_buffers(min, max, seed, out):
         linalg.fill_rng_2d(min, max, seed, outs=[out])

-    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+    execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))

     # TODO: FFI-based solution to allow testing and printing with python code.
     # Prepare arguments: one result i32.
@@ -215,10 +314,10 @@

   # CHECK: RESULT: -480

-test_fill_builtin()
+test_fill_rng_builtin()


-def test_fill_generic():
+def test_fill_rng_generic():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f64 = F64Type.get()
@@ -226,10 +325,10 @@
     with InsertionPoint(module.body):

       @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
-      def fill_on_buffers(min, max, seed, out):
+      def fill_rng_on_buffers(min, max, seed, out):
         linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

-    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+    execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))

     # TODO: FFI-based solution to allow testing and printing with python code.
     # Prepare arguments: one result i32.
@@ -242,7 +341,7 @@

   # CHECK: RESULT: -480

-test_fill_generic()
+test_fill_rng_generic()


 def test_max_pooling_builtin():
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -558,16 +558,63 @@
 }]>
 )FMT";

-// The iterator_types() method implementation. Parameters:
+// The iterator_types() method for structured ops. Parameters:
 // {0}: Class name
 // {1}: Comma interleaved iterator type names.
 static const char structuredOpIteratorTypesFormat[] =
     R"FMT(
-ArrayAttr {0}::iterator_types() {
+ArrayAttr {0}::iterator_types() {{
   return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} }});
 }
 )FMT";

+// The iterator_types() method for rank polymorphic structured ops. Parameters:
+// {0}: Class name
+static const char rankPolyStructuredOpIteratorTypesFormat[] =
+    R"FMT(
+ArrayAttr {0}::iterator_types() {{
+  int64_t rank = getRank(getOutputOperand(0));
+  return Builder(getContext()).getStrArrayAttr(
+      SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+}
+)FMT";
+
+// The indexing_maps() method for structured ops. Parameters:
+// {0}: Class name
+// {1}: Comma-separated list of dimension variable names.
+// {2}: Statements
+static const char structuredOpIndexingMapsFormat[] = R"FMT(
+ArrayAttr {0}::indexing_maps() {{
+  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
+  if (cached)
+    return cached;
+
+  MLIRContext *context = getContext();
+  auto symbolBindings = getSymbolBindings(*this);
+  SmallVector<AffineMap> maps;
+  {2}
+  cached = Builder(context).getAffineMapArrayAttr(maps);
+  getOperation()->setAttr(memoizeAttr, cached);
+  return cached;
+}
+)FMT";
+
+// The indexing_maps() method for rank polymorphic structured ops. Parameters:
+// {0}: Class name
+static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
+ArrayAttr {0}::indexing_maps() {{
+  MLIRContext *context = getContext();
+  AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
+  AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
+      getNumParallelLoops(), context);
+  SmallVector<AffineMap> indexingMaps;
+  for (OpOperand *opOperand : getInputAndOutputOperands())
+    indexingMaps.push_back(isScalar(opOperand) ? scalarMap : tensorMap);
+  return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
+}
+)FMT";
+
 // Implementations of fold and getEffects.
 // Parameters:
 // {0}: Class name
@@ -681,8 +728,14 @@
         return arg.usage != LinalgOperandDefUsage::attribute;
       });

-  // Reference iterators.
-  {
+  // An operation that accesses only scalars and scalar/rank zero tensors is
+  // rank polymorphic. We implement rank polymorphism by generating different
+  // indexing maps and iterators that match the rank of the first output tensor.
+  // An operation is rank polymorphic if the iteration domain has rank zero.
+  bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty();
+
+  // Generate the iterator_types() method.
+  if (!isRankPolymorphic) {
     std::string iteratorsStr;
     llvm::raw_string_ostream ss(iteratorsStr);
     llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss,
@@ -699,22 +752,25 @@
     ss.flush();
     os << llvm::formatv(structuredOpIteratorTypesFormat, className,
                         iteratorsStr);
+  } else {
+    os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className);
   }

-  // Static indexing maps.
+  // Generate the indexing_maps() method.
   if (auto &staticMaps =
           opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
     if (staticMaps->empty())
       return emitError(genContext.getLoc()) << "op has no indexing maps";
-    AffineMap firstMap = staticMaps->front().affineMap();
-
-    // Symbol bindings.
-    {
-      // For each symbol, generate a declaration for it, either with an
-      // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
-      // an attribute).
-      // TODO: Possibly lift into a top-level method.
-      static const char structuredOpSymbolBindingsFormat[] = R"FMT(
+    if (!isRankPolymorphic) {
+      AffineMap firstMap = staticMaps->front().affineMap();
+
+      // Symbol bindings.
+      {
+        // For each symbol, generate a declaration for it, either with an
+        // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
+        // an attribute).
+        // TODO: Possibly lift into a top-level method.
+        static const char structuredOpSymbolBindingsFormat[] = R"FMT(
 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
   MLIRContext *context = self.getContext();
   SmallVector<AffineExpr> exprs;
{1}
   return exprs;
}
)FMT";
@@ -723,101 +779,83 @@
-      unsigned symbolCount = firstMap.getNumSymbols();
-      SmallVector<std::string> symbolBindings;
-      for (unsigned i = 0; i < symbolCount; ++i) {
-        symbolBindings.push_back(llvm::formatv(
-            "  exprs.push_back(getAffineSymbolExpr({0}, context));", i));
-      }
+        unsigned symbolCount = firstMap.getNumSymbols();
+        SmallVector<std::string> symbolBindings;
+        for (unsigned i = 0; i < symbolCount; ++i) {
+          symbolBindings.push_back(llvm::formatv(
+              "  exprs.push_back(getAffineSymbolExpr({0}, context));", i));
+        }

-      // Access an index attribute. Parameters:
-      // {0}: Attribute name
-      // {1}: Symbol position
-      // {2}: Attribute index
-      static const char structuredOpAccessAttrFormat[] = R"FMT(
+        // Access an index attribute. Parameters:
+        // {0}: Attribute name
+        // {1}: Symbol position
+        // {2}: Attribute index
+        static const char structuredOpAccessAttrFormat[] = R"FMT(
 int64_t cst{1} = self.{0}().getValues<int64_t>()[{2}];
 exprs.push_back(getAffineConstantExpr(cst{1}, context));
 )FMT";

-      // Update all symbol bindings mapped to an attribute.
-      for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
-        if (arg.usage != LinalgOperandDefUsage::attribute)
-          continue;
-        assert(arg.attributeMap.hasValue());
-        for (auto &en :
-             llvm::enumerate(arg.attributeMap->affineMap().getResults())) {
-          if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
-            symbolBindings[symbol.getPosition()] =
-                llvm::formatv(structuredOpAccessAttrFormat, arg.name,
-                              symbol.getPosition(), en.index());
+        // Update all symbol bindings mapped to an attribute.
+        for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
+          if (arg.usage != LinalgOperandDefUsage::attribute)
+            continue;
+          assert(arg.attributeMap.hasValue());
+          for (auto &en :
+               llvm::enumerate(arg.attributeMap->affineMap().getResults())) {
+            if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
+              symbolBindings[symbol.getPosition()] =
+                  llvm::formatv(structuredOpAccessAttrFormat, arg.name,
+                                symbol.getPosition(), en.index());
+            }
           }
         }
-      }

-      std::string symbolBindingsStr;
-      llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
-      llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
-      symbolBindingsSs.flush();
+        std::string symbolBindingsStr;
+        llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
+        llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
+        symbolBindingsSs.flush();

-      os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
-                          symbolBindingsStr);
-    }
-
-    // Indexing maps.
-    {
-      // Parameters:
-      // {0}: Class name
-      // {1}: Comma-separated list of dimension variable names.
-      // {2}: Statements
-      static const char structuredOpIndexingMapsFormat[] = R"FMT(
-ArrayAttr {0}::indexing_maps() {
-  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
-  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
-  if (cached)
-    return cached;
+        os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
+                            symbolBindingsStr);
+      }

-  MLIRContext *context = getContext();
-  auto symbolBindings = getSymbolBindings(*this);
-  SmallVector<AffineMap> maps;
-  {2}
-  cached = Builder(context).getAffineMapArrayAttr(maps);
-  getOperation()->setAttr(memoizeAttr, cached);
-  return cached;
-}
-)FMT";
+      // Indexing maps.
+      {
+        unsigned dimCount = firstMap.getNumDims();
+
+        // Generate a comma-separated list of dim identifiers to be passed to
+        // bindDims, ensuring that AffineExpr identifiers are bound in the right
+        // order to the proper AffineDimExpr.
+        // This results in vars in scope like: d0, d1, d2...
+        SmallVector<unsigned> dimIndices;
+        for (unsigned i = 0; i < dimCount; ++i)
+          dimIndices.push_back(i);
+        std::string dimIdentsStr;
+        llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
+        llvm::interleaveComma(dimIndices, dimIdentsSs,
+                              [&](unsigned i) { dimIdentsSs << "d" << i; });
+        dimIdentsSs.flush();
+
+        // Statements to add and simplify each affine map.
+        SmallVector<std::string> stmts;
+        for (auto &indexingMap : *staticMaps) {
+          // TODO: Assert that dim and symbol count match the first.
+          stmts.push_back(
+              llvm::formatv("maps.push_back({0});",
+                            generateCppExpression(indexingMap, "context")));
+          stmts.push_back(llvm::formatv(
+              "maps.back() = "
+              "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
+              "symbolBindings, {0}, 0));",
+              dimCount));
+        }

-      unsigned dimCount = firstMap.getNumDims();
-
-      // Generate a comma-separated list of dim identifiers to be passed to
-      // bindDims, ensuring tht AffineExpr identifiers are bound in the right
-      // order to the proper AffineDimExpr.
-      // This results in vars in scope like: d0, d1, d2...
-      SmallVector<unsigned> dimIndices;
-      for (unsigned i = 0; i < dimCount; ++i)
-        dimIndices.push_back(i);
-      std::string dimIdentsStr;
-      llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
-      llvm::interleaveComma(dimIndices, dimIdentsSs,
-                            [&](unsigned i) { dimIdentsSs << "d" << i; });
-      dimIdentsSs.flush();
-
-      // Statements to add and simplify each affine map.
-      SmallVector<std::string> stmts;
-      for (auto &indexingMap : *staticMaps) {
-        // TODO: Assert that dim and symbol count match the first.
-        stmts.push_back(
-            llvm::formatv("maps.push_back({0});",
-                          generateCppExpression(indexingMap, "context")));
-        stmts.push_back(llvm::formatv(
-            "maps.back() = "
-            "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
-            "symbolBindings, {0}, 0));",
-            dimCount));
+        // TODO: This needs to be memoized and/or converted to non-parser based
+        // C++ codegen prior to real use.
+        os << llvm::formatv(structuredOpIndexingMapsFormat, className,
+                            dimIdentsStr, interleaveToString(stmts, "\n  "));
       }
-
-      // TODO: This needs to be memoized and/or converted to non-parser based
-      // C++ codegen prior to real use.
-      os << llvm::formatv(structuredOpIndexingMapsFormat, className,
-                          dimIdentsStr, interleaveToString(stmts, "\n  "));
+    } else {
+      os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className);
     }
   } else {
     return emitError(genContext.getLoc())