diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -13,19 +13,19 @@ args: - !LinalgOperandDefConfig name: A - usage: input - shape: affine_map<()[s0, s1, s2] -> (s0, s2)> + usage: InputOperand type_var: T1 + shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> - !LinalgOperandDefConfig name: B - usage: input - shape: affine_map<()[s0, s1, s2] -> (s2, s1)> + usage: InputOperand type_var: T2 + shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)> - !LinalgOperandDefConfig name: C - usage: output - shape: affine_map<()[s0, s1, s2] -> (s0, s1)> + usage: OutputOperand type_var: U + shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> @@ -75,19 +75,19 @@ args: - !LinalgOperandDefConfig name: A - usage: input - shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> + usage: InputOperand type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> - !LinalgOperandDefConfig name: B - usage: input - shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)> + usage: InputOperand type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)> - !LinalgOperandDefConfig name: C - usage: output - shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> + usage: OutputOperand type_var: U + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)> @@ -138,19 +138,19 @@ args: - !LinalgOperandDefConfig name: A - usage: input - shape: affine_map<()[s0, s1] -> (s0, s1)> + usage: InputOperand type_var: T1 + shape_map: affine_map<()[s0, s1] -> (s0, s1)> - !LinalgOperandDefConfig name: y - usage: input - shape: 
affine_map<()[s0, s1] -> (s1)> + usage: InputOperand type_var: T2 + shape_map: affine_map<()[s0, s1] -> (s1)> - !LinalgOperandDefConfig name: x - usage: output - shape: affine_map<()[s0, s1] -> (s0)> + usage: OutputOperand type_var: U + shape_map: affine_map<()[s0, s1] -> (s0)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> @@ -199,19 +199,19 @@ args: - !LinalgOperandDefConfig name: y - usage: input - shape: affine_map<()[s0, s1] -> (s1)> + usage: InputOperand type_var: T1 + shape_map: affine_map<()[s0, s1] -> (s1)> - !LinalgOperandDefConfig name: A - usage: input - shape: affine_map<()[s0, s1] -> (s1, s0)> + usage: InputOperand type_var: T2 + shape_map: affine_map<()[s0, s1] -> (s1, s0)> - !LinalgOperandDefConfig name: x - usage: output - shape: affine_map<()[s0, s1] -> (s0)> + usage: OutputOperand type_var: U + shape_map: affine_map<()[s0, s1] -> (s0)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d1)> @@ -260,19 +260,19 @@ args: - !LinalgOperandDefConfig name: A - usage: input - shape: affine_map<()[s0] -> (s0)> + usage: InputOperand type_var: T1 + shape_map: affine_map<()[s0] -> (s0)> - !LinalgOperandDefConfig name: B - usage: input - shape: affine_map<()[s0] -> (s0)> + usage: InputOperand type_var: T2 + shape_map: affine_map<()[s0] -> (s0)> - !LinalgOperandDefConfig name: C - usage: output - shape: affine_map<()[s0] -> ()> + usage: OutputOperand type_var: U + shape_map: affine_map<()[s0] -> ()> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0)[s0] -> (d0)> @@ -306,6 +306,83 @@ - !ScalarExpression scalar_arg: B --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: depthwise_conv_2d_input_nhwc_filter_hwc_poly + cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp + doc: A depth-wise 2-D convolution operation. 
+structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + usage: InputOperand + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> + (s0, s6, s7, s3)> + - !LinalgOperandDefConfig + name: K + usage: InputOperand + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> + (s4, s5, s3)> + - !LinalgOperandDefConfig + name: O + usage: OutputOperand + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> + (s0, s1, s2, s3)> + - !LinalgOperandDefConfig + name: strides + usage: IndexAttribute + type_var: I64 + attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] + -> (s8, s9)> + - !LinalgOperandDefConfig + name: dilations + usage: IndexAttribute + type_var: I64 + attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] + -> (s10, s11)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, + s10, s11] -> (d0, d1 * s8 + d4 * s10, d2 * s9 + d5 * s11, d3)> + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, + s10, s11] -> (d4, d5, d3)> + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, + s10, s11] -> (d0, d1, d2, d3)> + iterator_types: + - parallel + - parallel + - parallel + - parallel + - reduction + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: I + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: K +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_rng_2d cpp_class_name: FillRng2DOp @@ -323,21 
+400,21 @@ args: - !LinalgOperandDefConfig name: min - usage: input + usage: InputOperand type_var: F64 - !LinalgOperandDefConfig name: max - usage: input + usage: InputOperand type_var: F64 - !LinalgOperandDefConfig name: seed - usage: input + usage: InputOperand type_var: I32 - !LinalgOperandDefConfig name: O - usage: output - shape: affine_map<()[s0, s1] -> (s0, s1)> + usage: OutputOperand type_var: T + shape_map: affine_map<()[s0, s1] -> (s0, s1)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> ()> diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -30,6 +30,36 @@ // ----- +func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32(%input : tensor<1x4x16x1xf32>, %filter: tensor<2x2x1xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> { + %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} + ins(%input, %filter : tensor<1x4x16x1xf32>, tensor<2x2x1xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> + return %0: tensor<1x2x4x1xf32> +} + +// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32 +// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[FILTER_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32) +// CHECK-NEXT: %[[MUL:.+]] = mulf %[[IN_ARG]], %[[FILTER_ARG]] : f32 +// CHECK-NEXT: %[[ADD:.+]] = addf %[[OUT_ARG]], %[[MUL]] : f32 +// CHECK-NEXT: linalg.yield %[[ADD]] : f32 +// CHECK-NEXT: -> tensor<1x2x4x1xf32> + +// ----- + +func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tensor<1x4x16x1xi32>, %filter: tensor<2x2x1xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> { + %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = 
dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} + ins(%input, %filter : tensor<1x4x16x1xi32>, tensor<2x2x1xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> + return %0: tensor<1x2x4x1xi32> +} + +// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32 +// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[FILTER_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32) +// CHECK-NEXT: %[[MUL:.+]] = muli %[[IN_ARG]], %[[FILTER_ARG]] : i32 +// CHECK-NEXT: %[[ADD:.+]] = addi %[[OUT_ARG]], %[[MUL]] : i32 +// CHECK-NEXT: linalg.yield %[[ADD]] : i32 +// CHECK-NEXT: -> tensor<1x2x4x1xi32> + +// ----- + func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> { %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32> return %0: tensor<16x32xf32> diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -21,9 +21,9 @@ args: - !LinalgOperandDefConfig name: O - usage: output - shape: affine_map<()[s0, s1] -> (s0, s1)> + usage: OutputOperand type_var: T + shape_map: affine_map<()[s0, s1] -> (s0, s1)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> @@ -86,12 +86,13 @@ # @linalg_structured_op # def test2(I=TensorDef(T, S.M, S.N), -# O=TensorDef(T, S.M, S.N, output=True)): +# O=TensorDef(T, S.M, S.N, output=True), +# strides=AttributeDef(S.S0, S.S1)): # """Title. # Detailed description. # """ -# O[D.m, D.n] = I[D.n, D.m] +# O[D.m, D.n] = I[D.n * S.S0, D.m * S.S1] --- !LinalgOpConfig metadata: !LinalgOpMetadata @@ -103,25 +104,25 @@ Detailed description. 
structured_op: !LinalgStructuredOpConfig args: - - !LinalgOperandDefConfig - name: value - usage: input - type_var: T - !LinalgOperandDefConfig name: I - usage: input - shape: affine_map<()[s0, s1] -> (s1, s0)> + usage: InputOperand type_var: T + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)> - !LinalgOperandDefConfig name: O - usage: output - shape: affine_map<()[s0, s1] -> (s0, s1)> + usage: OutputOperand type_var: T + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)> + - !LinalgOperandDefConfig + name: strides + usage: IndexAttribute + type_var: I64 + attribute_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - - affine_map<(d0, d1)[s0, s1] -> ()> - - affine_map<(d0, d1)[s0, s1] -> (d1, d0)> - - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> + - affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)> + - affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)> iterator_types: - parallel - parallel @@ -129,23 +130,41 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: - fn_name: add - operands: - - !ScalarExpression - scalar_arg: value - - !ScalarExpression - scalar_arg: I + scalar_arg: I -# IMPL-LABEL: Test2Op::iterator_types() -# IMPL-NEXT: { getParallelIteratorTypeName(), getParallelIteratorTypeName() } +# ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2" + +# ODS: let arguments = +# ODS-NEXT: Variadic:$inputs, +# ODS-NEXT: Variadic:$outputs, +# ODS-NEXT: RankedI64ElementsAttr<[2]>:$strides + +# ODS: "Attribute":$strides +# ODS: $_state.addAttribute("strides", strides); + +# ODS: bool hasDynamicIndexingMaps(); +# ODS-NEXT: LogicalResult verifyIndexingMapRequiredAttributes(); + +# IMPL: getSymbolBindings(Test2Op self) +# IMPL: cst2 = self.strides().getValue({ 0 }); +# IMPL-NEXT: getAffineConstantExpr(cst2, context) +# IMPL: cst3 = self.strides().getValue({ 1 }); +# IMPL-NEXT: getAffineConstantExpr(cst3, context) # IMPL: Test2Op::indexing_maps() -# IMPL: "affine_map<(d0, d1)[s0, 
s1] -> ()>" -# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>" -# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>" +# IMPL: = getSymbolBindings(*this); +# IMPL: "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>" +# IMPL: "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)>" + +# IMPL: Test2Op::getNumRegionArgs() { return 2; } + +# IMPL: Test2Op::hasDynamicIndexingMaps() { return true; } +# IMPL: Test2Op::verifyIndexingMapRequiredAttributes() +# IMPL: auto attr = op->getAttrOfType("strides") +# IMPL: "missing indexing map required attribute 'strides'" # IMPL: void Test2Op::regionBuilder( -# IMPL: ImplicitLocOpBuilder &b, Block &block, ValueRange captures) +# IMPL-NEXT: ImplicitLocOpBuilder &b, Block &block, ValueRange captures) +# IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 && -# IMPL: = helper.applyfn__add(block.getArgument(0), block.getArgument(1)); +# IMPL: yields.push_back(block.getArgument(0)); diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py --- a/mlir/test/python/dialects/linalg/opsrun.py +++ b/mlir/test/python/dialects/linalg/opsrun.py @@ -210,6 +210,36 @@ test_fill_generic() +def test_conv_builtin(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + @builtin.FuncOp.from_py_func( + MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2, 1), f64), + MemRefType.get((1, 2, 4, 1), i32)) + def conv_on_buffers(input, filter, output): + linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly( + input, filter, outs=[output], strides=[2, 4], dilations=[1, 2]) + + execution_engine = ExecutionEngine(transform(module, conv_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. 
+ c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # CHECK: RESULT: 8 + + +test_conv_builtin() + + def test_conv_generic(): with Context() as ctx, Location.unknown(): module = Module.create() diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -62,13 +62,14 @@ AffineMap affineMap() { return affineMapAttr.getValue(); } }; -enum class LinalgOperandDefUsage { input, output }; +enum class LinalgOperandDefUsage { input, output, attribute }; struct LinalgOperandDef { std::string name; LinalgOperandDefUsage usage; - Optional shape; std::string typeVar; + Optional shapeMap; + Optional attributeMap; }; enum class LinalgIteratorTypeDef { @@ -149,8 +150,8 @@ }; /// A structured op models (at most) a single contraction by modeling -/// - A list of named arguments (`LinalgOperandDef`), which can be inputs or -/// outputs. +/// - A list of named arguments (`LinalgOperandDef`), which can be inputs, +/// outputs, or index attributes. /// - List of indexing maps (see `LinalgIndexingMaps`). /// - Iterator types (see `LinalgIteratorTypeDef`). /// - List of scalar level assignment (see `ScalarAssign`). @@ -164,21 +165,28 @@ } }; -/// Maps a named tensor- or scalar-argument to an operation, consisting of: +/// Maps a named tensor, scalar or attribute argument to an operation, +/// consisting of: /// - `name`: Must be unique within the operation. -/// - `usage`: How the argument is used (input, output, etc). -/// - `shape`: An optional AffineMap from all op symbols to the shape of the -/// argument. Only tensor-arguments have a shape. Each shape must be -/// normalized over the same list of symbols and have no dimension inputs. +/// - `usage`: How the argument is used (input, output, attribute, etc). 
/// - `type_var`: The symbolic type variable that binds to the element or self -/// type of the tensor- or scalar-argument, respectively. +/// type of the tensor or scalar argument, respectively. +/// - `shape_map`: An optional AffineMap from all op symbols to the shape of +/// the argument. Only tensor arguments have a `shape_map`. Each shape must +/// be normalized over the same list of symbols and have no dimension +/// inputs. +/// - `attribute_map`: An optional AffineMap from all op symbols to the +/// attribute symbols. During op creation these symbols are replaced by the +/// corresponding `name` attribute values. Only attribute arguments have +/// an `attribute_map`. template <> struct MappingTraits { static void mapping(IO &io, LinalgOperandDef &info) { io.mapRequired("name", info.name); io.mapRequired("usage", info.usage); - io.mapOptional("shape", info.shape); io.mapRequired("type_var", info.typeVar); + io.mapOptional("shape_map", info.shapeMap); + io.mapOptional("attribute_map", info.attributeMap); } }; @@ -186,8 +194,9 @@ template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgOperandDefUsage &value) { - io.enumCase(value, "input", LinalgOperandDefUsage::input); - io.enumCase(value, "output", LinalgOperandDefUsage::output); + io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input); + io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output); + io.enumCase(value, "IndexAttribute", LinalgOperandDefUsage::attribute); } }; @@ -425,9 +434,8 @@ // {2}: op interface list // {3}: documentation (summary + description) // {4}: op attribute list -// {5}: the number of arguments for the op region -// {6}: builder methods taking standalone attribute parameters -// {7}: additional methods for attributes used by indexing maps +// {5}: builder methods taking standalone attribute parameters +// {6}: additional methods for attributes used by indexing maps static const char structuredOpOdsHeaderFormat[] = R"FMT( 
//===----------------------------------------------------------------------===// // Op definition for {0} @@ -491,7 +499,7 @@ $_state.addTypes(resultTensorTypes); (void)$_state.addRegion(); }]> - {6} + {5} ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; let parser = [{{ @@ -514,11 +522,37 @@ // Generic methods. static unsigned getNumRegionArgs(); std::string getLibraryCallName(); - {7} + {6} }]; } )FMT"; +// Builder method taking attribute parameters. Parameters: +// {0}: Class name +// {1}: Comma interleaved attribute parameters +// {2}: Attribute initialization +static const char structuredOpBuilderFormat[] = R"FMT( + , OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs, {1}), + [{{ + $_state.addOperands(inputs); + $_state.addOperands(outputs); + $_state.addTypes(resultTensorTypes); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({{ + static_cast(inputs.size()), + static_cast(outputs.size())})); + createAndFillStructuredOpRegion<{0}>( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)); + {2} + }]> +)FMT"; + // The iterator_types() method implementation. Parameters: // {0}: Class name // {1}: Comma interleaved iterator type names. 
@@ -560,24 +594,53 @@ std::string doc; if (opConfig.metadata->doc) { - const char *docFmt = R"FMT( - let summary = [{ {0} }]; - let description = [{ - {1} - }]; - )FMT"; + static const char structuredOpDocFmt[] = R"FMT( + let summary = [{ {0} }]; + let description = [{ + {1} + }]; +)FMT"; StringRef summary, description; std::tie(summary, description) = StringRef(*opConfig.metadata->doc).trim().split('\n'); - doc = llvm::formatv(docFmt, summary.trim(), description.trim()); + doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim()); } interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); - os << llvm::formatv( - structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName, - opConfig.metadata->name, interfaceNameList, doc, attrList, - opConfig.structuredOp->args.size(), attrBuilder, attrMethods); + // Assemble the attribute specific logic required for the op definition. + if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { + return arg.usage == LinalgOperandDefUsage::attribute; + })) { + SmallVector attrDefs; + SmallVector attrParams; + SmallVector attrStmts; + for (LinalgOperandDef &arg : opConfig.structuredOp->args) { + if (arg.usage != LinalgOperandDefUsage::attribute) + continue; + assert(arg.attributeMap.hasValue() && arg.typeVar == "I64"); + static const char defFmt[] = "RankedI64ElementsAttr<[{0}]>:${1}"; + static const char paramFmt[] = "\"Attribute\":${0}"; + static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; + attrDefs.push_back(llvm::formatv( + defFmt, arg.attributeMap->affineMap().getNumResults(), arg.name)); + attrParams.push_back(llvm::formatv(paramFmt, arg.name)); + attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); + } + attrList = ",\n" + llvm::join(attrDefs, ",\n"); + attrMethods = R"( + bool hasDynamicIndexingMaps(); + LogicalResult verifyIndexingMapRequiredAttributes(); + )"; + attrBuilder = llvm::formatv( + structuredOpBuilderFormat, 
opConfig.metadata->cppClassName, + llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n")); + } + + os << llvm::formatv(structuredOpOdsHeaderFormat, + opConfig.metadata->cppClassName, opConfig.metadata->name, + interfaceNameList, doc, attrList, attrBuilder, + attrMethods); return success(); } @@ -595,6 +658,12 @@ std::string bannerComment = llvm::formatv("Implementation of {0}", className); os << llvm::formatv(bannerFormat, bannerComment); + // Compute the number of scalar and tensor arguments. + int64_t numOfArgs = + llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { + return arg.usage != LinalgOperandDefUsage::attribute; + }); + // Reference iterators. { std::string iteratorsStr; @@ -627,7 +696,6 @@ // For each symbol, generate a declaration for it, either with an // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from // an attribute). - // TODO: Implement attribute constants. // TODO: Possibly lift into a top-level method. static const char structuredOpSymbolBindingsFormat[] = R"FMT( static SmallVector getSymbolBindings({0} self) { @@ -641,10 +709,33 @@ unsigned symbolCount = firstMap.getNumSymbols(); SmallVector symbolBindings; for (unsigned i = 0; i < symbolCount; ++i) { - // TODO: Switch and emit constants for attribute bound symbols. symbolBindings.push_back(llvm::formatv( " exprs.push_back(getAffineSymbolExpr({0}, context));", i)); } + + // Access an index attribute. Parameters: + // {0}: Attribute name + // {1}: Symbol position + // {2}: Attribute index + static const char structuredOpAccessAttrFormat[] = R"FMT( +int64_t cst{1} = self.{0}().getValue({ {2} }); +exprs.push_back(getAffineConstantExpr(cst{1}, context)); +)FMT"; + // Update all symbol bindings mapped to an attribute. 
+ for (LinalgOperandDef &arg : opConfig.structuredOp->args) { + if (arg.usage != LinalgOperandDefUsage::attribute) + continue; + assert(arg.attributeMap.hasValue()); + for (auto &en : + llvm::enumerate(arg.attributeMap->affineMap().getResults())) { + if (auto symbol = en.value().dyn_cast()) { + symbolBindings[symbol.getPosition()] = + llvm::formatv(structuredOpAccessAttrFormat, arg.name, + symbol.getPosition(), en.index()); + } + } + } + std::string symbolBindingsStr; llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr); llvm::interleave(symbolBindings, symbolBindingsSs, "\n"); @@ -726,7 +817,7 @@ unsigned {0}::getNumRegionArgs() {{ return {1}; } )FMT"; os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className, - opConfig.structuredOp->args.size()); + numOfArgs); } // getLibraryCallName() @@ -741,6 +832,50 @@ os << llvm::formatv(structuredOpGetLibraryCallFormat, className); } + // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes() + if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { + return arg.usage == LinalgOperandDefUsage::attribute; + })) { + std::vector attrVerifications; + for (LinalgOperandDef &arg : opConfig.structuredOp->args) { + if (arg.usage != LinalgOperandDefUsage::attribute) + continue; + assert(arg.attributeMap.hasValue() && arg.typeVar == "I64"); + // Verify index attribute. 
Parameters: + // {0}: Attribute name + // {1}: Attribute size + static const char attrFmt[] = R"FMT( +if (auto attr = op->getAttrOfType("{0}")) {{ + if (!attr.getType().getElementType().isInteger(64)) + return op->emitError( + "incorrect element type for indexing map required attribute '{0}'"); + if (attr.getType().getShape() != ArrayRef{{ {1} }) + return op->emitError( + "incorrect shape for indexing map required attribute '{0}'"); +} else { + return op->emitError( + "missing indexing map required attribute '{0}'"); +} +)FMT"; + attrVerifications.push_back(llvm::formatv( + attrFmt, arg.name, arg.attributeMap->affineMap().getNumResults())); + } + + // Generates the verifyIndexingMapRequiredAttributes method. Parameters: + // {0}: Class name + // {1}: Attribute verification + static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT( +bool {0}::hasDynamicIndexingMaps() {{ return true; } +LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ + Operation *op = getOperation(); + {1} + return success(); +} +)FMT"; + os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes, + className, llvm::join(attrVerifications, "\n")); + } + // regionBuilder() { // Generates a regionBuilder method. Parameters. @@ -861,7 +996,6 @@ return emitError(genContext.getLoc()) << "mismatched number of assignments vs output arguments"; - int64_t numOfArgs = args.size(); os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs, interleaveToString(stmts, "\n ")); }