diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -247,6 +247,23 @@
 /// Verifies that the inferred result types match the actual result types for
 /// the op. Precondition: op implements InferTypeOpInterface.
 LogicalResult verifyInferredResultTypes(Operation *op);
+
+/// Infers the result type of an op carrying the
+/// InferableSameElementCompatibleShapeOperandsAndResultType trait from the
+/// type of its first operand.
+LogicalResult inferSameElementCompatibleShapeOperandsAndResultType(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes);
+/// Verifies that all operands and results of `op` have the same element type
+/// and mutually compatible shapes.
+LogicalResult
+verifyInferableSameElementCompatibleShapeOperandsAndResultType(Operation *op);
+/// Checks whether the type ranges `lhs` and `rhs` are compatible; backs the
+/// trait's `isCompatibleReturnTypes` hook.
+bool verifyInferableSameElementCompatibleShapeOperandsAndResultType(
+    TypeRange lhs, TypeRange rhs);
+
 } // namespace detail
 
 namespace OpTrait {
@@ -288,6 +305,35 @@
   }
 };
 
+/// This trait verifies that all operands and results of an op have the same
+/// element type and compatible (not necessarily identical) shapes, while still
+/// providing result type inference.
+template <typename ConcreteType>
+class InferableSameElementCompatibleShapeOperandsAndResultType
+    : public TraitBase<
+          ConcreteType,
+          InferableSameElementCompatibleShapeOperandsAndResultType> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return ::mlir::detail::
+        verifyInferableSameElementCompatibleShapeOperandsAndResultType(op);
+  }
+  static LogicalResult
+  inferReturnTypes(MLIRContext *context, Optional<Location> location,
+                   ValueRange operands, DictionaryAttr attributes,
+                   RegionRange regions,
+                   SmallVectorImpl<Type> &inferredReturnTypes) {
+    return ::mlir::detail::inferSameElementCompatibleShapeOperandsAndResultType(
+        context, location, operands, attributes, regions, inferredReturnTypes);
+  }
+
+  static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
+    return ::mlir::detail::
+        verifyInferableSameElementCompatibleShapeOperandsAndResultType(lhs,
+                                                                        rhs);
+  }
+};
+
 } // namespace OpTrait
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -185,4 +185,12 @@
 // TODO: Change from hard coded to utilizing type inference trait.
 def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
 
+// Op has the same element type and compatible shapes for all operands and
+// results, with result type inference enabled.
+def InferableSameElementCompatibleShapeOperandsAndResultType : TraitList<[
+    NativeOpTrait<"InferableSameElementCompatibleShapeOperandsAndResultType">,
+    InferTypeOpInterface
+  ]>;
+
+
 #endif // MLIR_INFERTYPEOPINTERFACE
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -14,6 +14,8 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
@@ -218,3 +220,49 @@
            << op->getResultTypes();
   return success();
 }
+
+/// Infers the single result type of an op as the type of its first operand.
+/// Element-type and shape agreement across all operands and results is
+/// checked separately by the trait verifier.
+LogicalResult
+mlir::detail::inferSameElementCompatibleShapeOperandsAndResultType(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Without operands there is nothing to infer from; report an error rather
+  // than dereferencing an empty range.
+  if (operands.empty())
+    return emitOptionalError(
+        location, "expected at least one operand to infer the result type");
+  inferredReturnTypes.push_back(operands.front().getType());
+  return success();
+}
+
+/// Verifies that all operands and results of `op` share the same element type
+/// and have mutually compatible shapes.
+LogicalResult
+mlir::detail::verifyInferableSameElementCompatibleShapeOperandsAndResultType(
+    Operation *op) {
+  if (failed(OpTrait::impl::verifySameOperandsAndResultElementType(op)))
+    return failure();
+  SmallVector<Type, 8> types(op->getOperandTypes());
+  llvm::append_range(types, op->getResultTypes());
+  if (failed(verifyCompatibleShapes(types)))
+    return op->emitOpError()
+           << "requires the same shape for all operands and results";
+  return success();
+}
+
+/// Returns true if the type ranges `lhs` and `rhs` have the same length and
+/// each pair has equal element types and compatible shapes. This backs the
+/// trait's `isCompatibleReturnTypes` hook, which was previously undefined.
+bool mlir::detail::
+    verifyInferableSameElementCompatibleShapeOperandsAndResultType(
+        TypeRange lhs, TypeRange rhs) {
+  if (lhs.size() != rhs.size() || failed(verifyCompatibleShapes(lhs, rhs)))
+    return false;
+  return llvm::all_of(llvm::zip(lhs, rhs), [](auto typePair) {
+    return getElementTypeOrSelf(std::get<0>(typePair)) ==
+           getElementTypeOrSelf(std::get<1>(typePair));
+  });
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -642,6 +642,13 @@
   let results = (outs AnyTensor);
 }
 
+def OpWithInferableSameElementCompatibleShapeOperandsAndResultType :
+  TEST_Op<"op_with_infer_type_compatible_shapes",
+    InferableSameElementCompatibleShapeOperandsAndResultType.traits> {
+  let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
+  let results = (outs AnyTensor);
+}
+
 def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if",
       [InferTensorTypeWithReify]> {
   let arguments = (ins AnyTensor, AnyTensor);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -238,6 +238,8 @@
       invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
       invokeCreateWithInferredReturnType<
           OpWithShapedTypeInferTypeInterfaceOp>(op);
+      invokeCreateWithInferredReturnType<
+          OpWithInferableSameElementCompatibleShapeOperandsAndResultType>(op);
     };
     return;
   }