diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -388,7 +388,7 @@ ```c++ // All result-types/operands/attributes have one aggregate parameter. static void build(OpBuilder &odsBuilder, OperationState &odsState, - ArrayRef<Type> resultTypes, + TypeRange resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes); @@ -410,7 +410,7 @@ // Each operand/attribute has a separate parameter but result type is aggregate. static void build(OpBuilder &odsBuilder, OperationState &odsState, - ArrayRef<Type> resultTypes, + TypeRange resultTypes, Value i32_operand, Value f32_operand, ..., IntegerAttr i32_attr, FloatAttr f32_attr, ...); diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -94,7 +94,9 @@ // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value val); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, int integer = 0); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, 
::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) // CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); // CHECK: void print(::mlir::OpAsmPrinter &p); @@ -286,6 +288,21 @@ // CHECK: class KWithTraitOp : public ::mlir::Op { + let arguments = (ins AnyType:$a, AnyType:$b, I32Attr:$attr1); + let results = (outs AnyType:$r); +} +// CHECK-LABEL: class LOp : +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, uint32_t attr1); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, uint32_t attr1); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, 
::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + + // Test that type defs have the proper namespaces when used as a constraint. // --- diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1651,31 +1651,44 @@ } void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { - SmallVector paramList; - SmallVector resultNames; - llvm::StringSet<> inferredAttributes; - buildParamList(paramList, inferredAttributes, resultNames, - TypeParamKind::None); + SmallVector attrBuilderType; + attrBuilderType.push_back(AttrParamKind::WrappedAttr); + // Generate additional builder(s) if attributes can be "unwrapped" + if (canGenerateUnwrappedBuilder(op)) + attrBuilderType.push_back(AttrParamKind::UnwrappedValue); - auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); - // If the builder is redundant, skip generating the method - if (!m) - return; - auto &body = m->body(); - genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes); + auto emit = [&](AttrParamKind attrType) { + SmallVector paramList; + SmallVector resultNames; + llvm::StringSet<> inferredAttributes; + buildParamList(paramList, inferredAttributes, resultNames, + TypeParamKind::None, attrType); - auto numResults = op.getNumResults(); - if (numResults == 0) - return; + auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + auto &body = m->body(); + genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, + /*isRawValueAttr=*/attrType == + AttrParamKind::UnwrappedValue); - // Push all result types to the operation state - const char *index = op.getOperand(0).isVariadic() ? 
".front()" : ""; - std::string resultType = - formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); - body << " " << builderOpState << ".addTypes({" << resultType; - for (int i = 1; i != numResults; ++i) - body << ", " << resultType; - body << "});\n\n"; + auto numResults = op.getNumResults(); + if (numResults == 0) + return; + + // Push all result types to the operation state + const char *index = op.getOperand(0).isVariadic() ? ".front()" : ""; + std::string resultType = + formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); + body << " " << builderOpState << ".addTypes({" << resultType; + for (int i = 1; i != numResults; ++i) + body << ", " << resultType; + body << "});\n\n"; + }; + + for (auto attrType : attrBuilderType) + emit(attrType); } void OpEmitter::genUseAttrAsResultTypeBuilder() {