diff --git a/mlir/include/mlir/Dialect/Math/IR/Math.h b/mlir/include/mlir/Dialect/Math/IR/Math.h --- a/mlir/include/mlir/Dialect/Math/IR/Math.h +++ b/mlir/include/mlir/Dialect/Math/IR/Math.h @@ -13,6 +13,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.h b/mlir/include/mlir/Dialect/Quant/QuantOps.h --- a/mlir/include/mlir/Dialect/Quant/QuantOps.h +++ b/mlir/include/mlir/Dialect/Quant/QuantOps.h @@ -15,6 +15,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/Support/MathExtras.h" diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -14,6 +14,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TensorEncoding.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #define GET_ATTRDEF_CLASSES diff --git a/mlir/lib/Dialect/Quant/IR/CMakeLists.txt b/mlir/lib/Dialect/Quant/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Quant/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/IR/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRIR + MLIRInferTypeOpInterface MLIRSideEffectInterfaces MLIRSupport ) diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -333,8 +333,25 @@ // Skip cases currently being custom generated. // TODO: Remove special cases. 
- if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) + if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { + // Check for a non-variable length operand to use as the type anchor. + auto *operandI = llvm::find_if(arguments, [](const Argument &arg) { + NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>(); + return operand && !operand->isVariableLength(); + }); + if (operandI == arguments.end()) + return; + + // Map each of the result types to the anchor operand. + int operandIdx = operandI - arguments.begin(); + resultTypeMapping.resize(getNumResults()); + for (int i = 0; i < getNumResults(); ++i) + resultTypeMapping[i].emplace_back(operandIdx); + + allResultsHaveKnownTypes = true; + traits.push_back(Trait::create(inferTrait->getDefInit())); return; + } // We create equivalence classes of argument/result types where arguments // and results are mapped into the same index space and indices corresponding diff --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir --- a/mlir/test/Analysis/test-shape-fn-report.mlir +++ b/mlir/test/Analysis/test-shape-fn-report.mlir @@ -5,9 +5,9 @@ // expected-remark@+1 {{associated shape function: same_result_shape}} func.func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32> attributes {shape.function = @shape_lib::@same_result_shape} { - // expected-remark@+1 {{no associated way}} + // expected-remark@+1 {{implements InferType op interface}} %0 = math.tanh %arg : tensor<10x20xf32> - // expected-remark@+1 {{associated shape function: same_result_shape}} + // expected-remark@+1 {{implements InferType op interface}} %1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32> return %1 : tensor<10x20xf32> } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2608,15 +2608,9 @@ }]; } -// Single variadic arg with
SameOperandsAndResultType and InferTypeOpInterface. -// Tests suppression of ambiguous build methods for operations with -// SameOperandsAndResultType and InferTypeOpInterface. -def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp< - "tblgen_build_5", [SameOperandsAndResultType]>; - // Op with InferTypeOpInterface and regions. -def TableGenBuildOp6 : TableGenBuildInferReturnTypeBaseOp< - "tblgen_build_6", [InferTypeOpInterface]> { +def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp< + "tblgen_build_5", [InferTypeOpInterface]> { let regions = (region AnyRegion:$body); } diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -199,7 +199,7 @@ let results = (outs AnyType:$b); } -// CHECK_LABEL: class NS_HCollectiveParamsOp : +// CHECK_LABEL: class HCollectiveParamsOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}) @@ -212,7 +212,7 @@ let results = (outs Variadic:$b); } -// CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op : +// CHECK_LABEL: class HCollectiveParamsSuppress0Op : // CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); @@ -224,7 +224,7 @@ let results = (outs I32:$b); } -// CHECK_LABEL: class 
NS_HCollectiveParamsSuppress1Op : +// CHECK_LABEL: class HCollectiveParamsSuppress1Op : // CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); @@ -237,7 +237,7 @@ let arguments = (ins Variadic:$a); let results = (outs Variadic:$b, Variadic:$c); } -// CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op : +// CHECK_LABEL: class HCollectiveParamsSuppress2Op : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a); // CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); @@ -247,11 +247,11 @@ let arguments = (ins AnyType:$a, AnyType:$b); let results = (outs AnyType:$r); } -// CHECK_LABEL: class NS_IOp : +// CHECK_LABEL: class IOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); -// CHECK: static void build(::mlir::OpBuilder 
&odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check default value of `attributes` for the `genInferredTypeCollectiveParamBuilder` builder @@ -259,7 +259,7 @@ let arguments = (ins AnyType:$a, AnyType:$b); let results = (outs AnyType:$r); } -// CHECK_LABEL: class NS_JOp : +// CHECK_LABEL: class JOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); @@ -292,14 +292,14 @@ let arguments = (ins AnyType:$a, AnyType:$b, I32Attr:$attr1); let results = (outs AnyType:$r); } -// CHECK_LABEL: class NS_LOp : +// CHECK_LABEL: class LOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, uint32_t attr1); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1); // CHECK: static void build(::mlir::OpBuilder 
&odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, uint32_t attr1); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -27,7 +27,12 @@ // CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type y, ::mlir::Value x) // CHECK: odsState.addTypes(y); // CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value x) -// CHECK: odsState.addTypes({x.getType()}); +// CHECK: ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; +// CHECK: if (::mlir::succeeded(OpB::inferReturnTypes(odsBuilder.getContext(), +// CHECK: odsState.location, odsState.operands, +// CHECK: odsState.attributes.getDictionary(odsState.getContext()), +// CHECK: /*regions=*/{}, inferredReturnTypes))) +// CHECK: odsState.addTypes(inferredReturnTypes); def OpC : NS_Op<"three_normal_result_op", []> { let results = (outs I32:$x, /*unnamed*/I32, I32:$z); diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp --- a/mlir/unittests/TableGen/OpBuildGen.cpp +++ b/mlir/unittests/TableGen/OpBuildGen.cpp @@ -204,7 +204,7 @@ verifyOp(op, 
{i32Ty, f32Ty}, {*cstI32}, attrs); } -// The next 2 tests test supression of ambiguous build methods for ops that +// The next test checks suppression of ambiguous build methods for ops that // have a single variadic input, and single non-variadic result, and which // support the SameOperandsAndResultType trait and and optionally the // InferOpTypeInterface interface. For such ops, the ODS framework generates @@ -213,14 +213,8 @@ testSingleVariadicInputInferredType<test::TableGenBuildOp4>(); } -TEST_F( - OpBuildGenTest, - BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) { - testSingleVariadicInputInferredType<test::TableGenBuildOp5>(); -} - TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) { - auto op = builder.create<test::TableGenBuildOp6>( + auto op = builder.create<test::TableGenBuildOp5>( loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs); ASSERT_EQ(op->getNumRegions(), 1u); verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs);