diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -88,6 +88,7 @@ // ODS: F32:$f32_attr, // ODS: RankedF32ElementsAttr<[4]>:$fvec_attr, // ODS: I32:$i32_attr, +// ODS: I64:$i64_attr, // ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr, // ODS: OptionalAttr:$optional_attr // @@ -96,6 +97,7 @@ attr( f32_attr: f32, i32_attr: i32, + i64_attr: i64, fvec_attr: 4xf32, ivec_attr: 5x6xi32, array_attr : f32[], @@ -126,6 +128,7 @@ I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c))); } +// Test documentation // ODS-LABEL: def Test6Op // ODS: let summary = [{ My magic op. }]; // ODS-NEXT: let description = [{ @@ -144,3 +147,18 @@ { C(m) = std_addf(std_mulf(A(m, k), B(k))); } + +// Test attribute builder +// ODS-LABEL: def Test7Op +// ODS: OpBuilderDAG< +// ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, +// ODS: "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b) +// ODS: $_state.addAttribute("attr_a", attr_a); +// ODS: $_state.addAttribute("attr_b", attr_b); +// +ods_def: +def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M)) + attr(attr_a: f32, attr_b: 4xi32) +{ + C(m) = std_addf(std_mulf(A(m, k), B(k))); +} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1768,6 +1768,7 @@ std::string odsType = llvm::StringSwitch(elementType) .Case("f32", "F32") .Case("i32", "I32") + .Case("i64", "I64") .Default(""); if (odsType.empty()) { parser.emitError("unimplemented support for attribute element type: " + @@ -1811,7 +1812,8 @@ let regions = (region AnyRegion:$region); let skipDefaultBuilders = 1; - let builders = [ OpBuilderDAG< + let builders = [ + OpBuilderDAG< (ins "ValueRange":$inputs, "ValueRange":$outputs), [{{ $_state.addOperands(inputs); @@ -1826,7 +1828,8 @@ $_state, TypeRange(inputs), TypeRange(outputs)); - }]>, OpBuilderDAG< + }]>, + OpBuilderDAG< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs), [{{ @@ -1843,7 +1846,8 @@ $_state, TypeRange(inputs), TypeRange(outputs)); - }]>, OpBuilderDAG< + }]>, + OpBuilderDAG< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, CArg<"ArrayRef", "{{}">:$attributes), [{{ @@ -1852,6 +1856,7 @@ $_state.addTypes(resultTensorTypes); (void)$_state.addRegion(); }]> + {5} ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }]; @@ -1873,8 +1878,8 @@ }]; })FMT"; + // Generate documentation. std::string doc; - if (!docString.empty()) { const char *docFmt = R"FMT( let summary = [{ {0} }]; @@ -1888,8 +1893,47 @@ doc = llvm::formatv(docFmt, summary.trim(), description.trim()); } + // Generate an additional builder that has parameters for attributes. + std::string attrBuilder; + if (!registeredAttrs.empty()) { + SmallVector attrParams, attrStmts; + for (const auto &attr : registeredAttrs) { + llvm::StringRef name = attr.first; + attrParams.push_back(llvm::formatv("\"Attribute\":${0}", name)); + attrStmts.push_back( + llvm::formatv("$_state.addAttribute(\"{0}\", {0});", name)); + } + std::string attrParamsList = llvm::join(attrParams, ", "); + std::string attrStmtsList = llvm::join(attrStmts, "\n"); + + const char *builderFmt = R"FMT( + , OpBuilderDAG< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs, {1}), + [{{ + $_state.addOperands(inputs); + $_state.addOperands(outputs); + $_state.addTypes(resultTensorTypes); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({{ + static_cast(inputs.size()), + static_cast(outputs.size())})); + buildNamedStructuredOpRegionAndAttributes<{0}>( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)); + {2} + }]> + )FMT"; + attrBuilder = + llvm::formatv(builderFmt, cppOpName, attrParamsList, attrStmtsList); + } + + // Finally put everything together. os << llvm::formatv(header, cppOpName, linalgOpName, doc, attrList, - state.orderedTensorArgs.size()); + state.orderedTensorArgs.size(), attrBuilder); } /// Print the C++ StructuredOpsInterface impl of `iterator_types`. @@ -2146,13 +2190,15 @@ return llvm::formatv("getValue({ {0} })", indexList); if (elementType == "i32") return llvm::formatv("getValue({ {0} })", indexList); + if (elementType == "i64") + return llvm::formatv("getValue({ {0} })", indexList); return ""; } if (elementType == "f32") return "getValue().convertToFloat()"; - if (elementType == "i32") + if (elementType == "i32" || elementType == "i64") return "getInt()"; return ""; }