diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -245,11 +245,25 @@ LogicalResult verifyInferredResultTypes(Operation *op); } // namespace detail +namespace OpTrait { +template +class InferTensorType; +} // namespace OpTrait +} // namespace mlir + +/// Include the generated interface declarations. +#include "mlir/Interfaces/InferTypeOpInterface.h.inc" + +namespace mlir { namespace OpTrait { /// Tensor type inference trait that constructs a tensor from the inferred /// shape and elemental types. -/// Requires: Op implements functions of InferShapedTypeOpInterface. +/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface. +/// Less strict is possible (e.g., implements inferReturnTypeComponents and +/// it always populates all element types and shapes or fails, but this +/// trait is currently only used where the interfaces are, so keep it +/// restricted for now). template class InferTensorType : public TraitBase { public: @@ -258,6 +272,9 @@ ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { + assert( + ConcreteType::template hasTrait() && + ConcreteType::template hasTrait()); return ::mlir::detail::inferReturnTensorTypes( ConcreteType::inferReturnTypeComponents, context, location, operands, attributes, regions, inferredReturnTypes); @@ -267,7 +284,4 @@ } // namespace OpTrait } // namespace mlir -/// Include the generated interface declarations. -#include "mlir/Interfaces/InferTypeOpInterface.h.inc" - #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_