diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -151,6 +151,17 @@
   // Returns the total number of arguments.
   int getNumArgs() const { return arguments.size(); }
 
+  // Returns true if the operation's only argument is a variadic operand.
+  bool hasSingleVariadicArg() const;
+
+  // Returns true if the operation has exactly one result and it is variadic.
+  bool hasSingleVariadicResult() const {
+    return getNumResults() == 1 && getResult(0).isVariadic();
+  }
+
+  // Returns true if the operation has no variadic regions.
+  bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
+
   using arg_iterator = const Argument *;
   using arg_range = llvm::iterator_range<arg_iterator>;
 
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -134,6 +134,11 @@
       });
 }
 
+bool tblgen::Operator::hasSingleVariadicArg() const {
+  return getNumArgs() == 1 && getArg(0).is<tblgen::NamedTypeConstraint *>() &&
+         getOperand(0).isVariadic();
+}
+
 tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
   return arguments.begin();
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1526,4 +1526,34 @@
   let results = (outs Variadic<AnyType>:$resultA, Variadic<AnyType>:$resultB);
 }
 
+// Single variadic arg, non-variadic result, with SameOperandsAndResultType.
+// Tests suppression of ambiguous build methods for operations with
+// SameOperandsAndResultType trait.
+def TableGenBuildOp4 : TEST_Op<"tblgen_build_4", [SameOperandsAndResultType]> {
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs AnyType:$result);
+}
+
+// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
+// Tests suppression of ambiguous build methods for operations with
+// SameOperandsAndResultType and InferTypeOpInterface.
+def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
+  [SameOperandsAndResultType, InferTypeOpInterface]> {
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs AnyType:$result);
+
+  let extraClassDeclaration = [{
+    static LogicalResult inferReturnTypes(MLIRContext *,
+      Optional<Location> location, ValueRange operands,
+      DictionaryAttr attributes, RegionRange regions,
+      SmallVectorImpl<Type> &inferredReturnTypes) {
+      // The operand list is variadic and may be empty; never index past it.
+      if (operands.empty())
+        return failure();
+      inferredReturnTypes.assign({operands[0].getType()});
+      return success();
+    }
+  }];
+}
+
 #endif // TEST_OPS
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -110,8 +110,8 @@
   let results = (outs AnyTensor:$result);
 }
 
-// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input)
-// CHECK: odsState.addTypes({input.front().getType()});
+// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes )
+// CHECK: odsState.addTypes({operands[0].getType()});
 
 // Test with inferred shapes and interleaved with operands/attributes.
 //
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -232,6 +232,10 @@
   // operand's type as all results' types.
   void genUseOperandAsResultTypeCollectiveParamBuilder();
 
+  // Returns true if the inferred collective param build method should be
+  // generated.
+  bool shouldGenerateInferredTypeCollectiveParamBuilder();
+
   // Generates the build() method that takes aggregate operands/attributes
   // parameters. This build() method uses inferred types as result types.
   // Requires: The type needs to be inferable via InferTypeOpInterface.
@@ -984,40 +988,37 @@
   //   result
   //
   // In that case, skip generating such ambiguous build methods here.
-  bool hasSingleVariadicResult =
-      op.getNumResults() == 1 && op.getResult(0).isVariadic();
-
-  bool hasSingleVariadicArg =
-      op.getNumArgs() == 1 &&
-      op.getArg(0).is<tblgen::NamedTypeConstraint *>() &&
-      op.getOperand(0).isVariadic();
-  bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0;
-
   for (auto attrType : attrBuilderType) {
     // Case 3b above.
-    if (!(hasNoVariadicRegions && hasSingleVariadicArg &&
-          hasSingleVariadicResult))
+    if (!(op.hasNoVariadicRegions() && op.hasSingleVariadicArg() &&
+          op.hasSingleVariadicResult()))
       emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
-    if (canInferType(op))
-      emit(attrType, TypeParamKind::None, /*inferType=*/true);
+    if (canInferType(op)) {
+      // When inferType = true, the generated build method does not have
+      // result types. If the op has a single variadic arg, then this build
+      // method will be ambiguous with the collective inferred build method
+      // generated in `genInferredTypeCollectiveParamBuilder`. If we are going
+      // to generate that collective inferred method, suppress generating the
+      // ambiguous build method here.
+      bool buildMethodAmbiguous =
+          op.hasSingleVariadicArg() &&
+          shouldGenerateInferredTypeCollectiveParamBuilder();
+      if (!buildMethodAmbiguous)
+        emit(attrType, TypeParamKind::None, /*inferType=*/true);
+    }
     // The separate arg + collective param kind method will be:
     // (a) Same as the separate arg + separate param kind method if there is
     //     only one variadic result.
     // (b) Ambiguous with the collective params method under conditions in (3a)
     //     above.
     // In either case, skip generating such build method.
-    if (!hasSingleVariadicResult &&
-        !(hasNoVariadicRegions && hasSingleVariadicArg))
+    if (!op.hasSingleVariadicResult() &&
+        !(op.hasNoVariadicRegions() && op.hasSingleVariadicArg()))
       emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
   }
 }
 
 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
-  // If this op has a variadic result, we cannot generate this builder because
-  // we don't know how many results to create.
-  if (op.getNumVariableLengthResults() != 0)
-    return;
-
   int numResults = op.getNumResults();
 
   // Signature
@@ -1055,6 +1056,10 @@
        << llvm::join(resultTypes, ", ") << "});\n\n";
 }
 
+bool OpEmitter::shouldGenerateInferredTypeCollectiveParamBuilder() {
+  return canInferType(op) && op.getNumSuccessors() == 0;
+}
+
 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   // TODO: Expand to support regions.
   std::string params =
@@ -1209,8 +1214,21 @@
   // to facilitate different call patterns.
   if (op.getNumVariableLengthResults() == 0) {
     if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
-      genUseOperandAsResultTypeSeparateParamBuilder();
-      genUseOperandAsResultTypeCollectiveParamBuilder();
+      // If the operation has a single variadic input, then the build method
+      // generated by `genUseOperandAsResultTypeSeparateParamBuilder` will be
+      // ambiguous with the one generated by
+      // `genUseOperandAsResultTypeCollectiveParamBuilder` (they both will have
+      // a single `ValueRange` argument for operands, and the collective one
+      // will have an `ArrayRef<NamedAttribute>` argument initialized to empty).
+      // Suppress that ambiguous build method.
+      if (!op.hasSingleVariadicArg())
+        genUseOperandAsResultTypeSeparateParamBuilder();
+
+      // The build method generated by the inferred type collective param
+      // builder and the one generated here have the same arguments and hence
+      // generating both would be ambiguous. Enable just one of them.
+      if (!shouldGenerateInferredTypeCollectiveParamBuilder())
+        genUseOperandAsResultTypeCollectiveParamBuilder();
     }
     if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
       genUseAttrAsResultTypeBuilder();
@@ -1269,7 +1287,7 @@
 
   // Generate builder that infers type too.
   // TODO: Expand to handle regions and successors.
-  if (canInferType(op) && op.getNumSuccessors() == 0)
+  if (shouldGenerateInferredTypeCollectiveParamBuilder())
     genInferredTypeCollectiveParamBuilder();
 }
 
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -63,6 +63,28 @@
     concreteOp.erase();
   }
 
+  // Helper method to test ops with inferred result types and single variadic
+  // input.
+  template <typename OpTy>
+  void testSingleVariadicInputInferredType() {
+    // Test separate arg, separate param build method.
+    auto op = builder.create<OpTy>(loc, i32Ty, ArrayRef<Value>{cstI32, cstI32});
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+    // Test collective params build method.
+    op = builder.create<OpTy>(loc, ArrayRef<Type>{i32Ty},
+                              ArrayRef<Value>{cstI32, cstI32});
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+    // Test build method with no result types, default value of attributes.
+    op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32});
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+    // Test build method with no result types and supplied attributes.
+    op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32}, attrs);
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, attrs);
+  }
+
 protected:
   MLIRContext ctx;
   OpBuilder builder;
@@ -178,4 +200,19 @@
   verifyOp(std::move(op), {i32Ty, f32Ty}, {cstI32}, attrs);
 }
 
+// The next two tests check suppression of ambiguous build methods for ops that
+// have a single variadic input and a single non-variadic result, and which
+// have the SameOperandsAndResultType trait and optionally the
+// InferTypeOpInterface interface. For such ops, the ODS framework generates
+// build methods with no result types as they are inferred from the input types.
+TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
+  testSingleVariadicInputInferredType<TableGenBuildOp4>();
+}
+
+TEST_F(
+    OpBuildGenTest,
+    BuildMethodsSameOperandsAndResultTypeAndInferTypeOpInterfaceSuppression) {
+  testSingleVariadicInputInferredType<TableGenBuildOp5>();
+}
+
 } // namespace mlir