diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -245,11 +245,25 @@ LogicalResult verifyInferredResultTypes(Operation *op);
 
 } // namespace detail
 
+namespace OpTrait {
+template <typename ConcreteType>
+class InferTensorType;
+} // namespace OpTrait
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
+
+namespace mlir {
 namespace OpTrait {
 
 /// Tensor type inference trait that constructs a tensor from the inferred
 /// shape and elemental types.
-/// Requires: Op implements functions of InferShapedTypeOpInterface.
+/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
+///   Less strict is possible (e.g., implements inferReturnTypeComponents and
+///   it always populates all element types and shapes or fails), but this
+///   trait is currently only used where the interfaces are, so keep it
+///   restricted for now.
 template <typename ConcreteType>
 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
 public:
@@ -258,6 +272,12 @@
                    ValueRange operands, DictionaryAttr attributes,
                    RegionRange regions,
                    SmallVectorImpl<Type> &inferredReturnTypes) {
+    static_assert(
+        ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
+        "requires InferShapedTypeOpInterface to ensure successful invocation");
+    static_assert(
+        ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
+        "requires InferTypeOpInterface to ensure successful invocation");
     return ::mlir::detail::inferReturnTensorTypes(
         ConcreteType::inferReturnTypeComponents, context, location, operands,
         attributes, regions, inferredReturnTypes);
@@ -267,7 +287,4 @@
 } // namespace OpTrait
 } // namespace mlir
 
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
-
 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_