diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -25,67 +25,9 @@ namespace mlir { +class ShapedTypeComponents; using ReifiedRankedShapedTypeDims = SmallVector>; -/// ShapedTypeComponents that represents the components of a ShapedType. -/// The components consist of -/// - A ranked or unranked shape with the dimension specification match those -/// of ShapeType's getShape() (e.g., dynamic dimension represented using -/// ShapedType::kDynamicSize) -/// - A element type, may be unset (nullptr) -/// - A attribute, may be unset (nullptr) -/// Used by ShapedType type inferences. -class ShapedTypeComponents { - /// Internal storage type for shape. - using ShapeStorageT = SmallVector; - -public: - /// Default construction is an unranked shape. - ShapedTypeComponents() : elementType(nullptr), attr(nullptr){}; - ShapedTypeComponents(Type elementType) - : elementType(elementType), attr(nullptr), ranked(false) {} - ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) { - ranked = shapedType.hasRank(); - elementType = shapedType.getElementType(); - if (ranked) - dims = llvm::to_vector<4>(shapedType.getShape()); - } - template ::value>> - ShapedTypeComponents(Arg &&arg, Type elementType = nullptr, - Attribute attr = nullptr) - : dims(std::forward(arg)), elementType(elementType), attr(attr), - ranked(true) {} - ShapedTypeComponents(ArrayRef vec, Type elementType = nullptr, - Attribute attr = nullptr) - : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr), - ranked(true) {} - - /// Return the dimensions of the shape. - /// Requires: shape is ranked. - ArrayRef getDims() const { - assert(ranked && "requires ranked shape"); - return dims; - } - - /// Return whether the shape has a rank. - bool hasRank() const { return ranked; }; - - /// Return the element type component. - Type getElementType() const { return elementType; }; - - /// Return the raw attribute component. - Attribute getAttribute() const { return attr; }; - -private: - friend class ShapeAdaptor; - - ShapeStorageT dims; - Type elementType; - Attribute attr; - bool ranked{false}; -}; - /// Adaptor class to abstract the differences between whether value is from /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute. class ShapeAdaptor { @@ -137,7 +79,7 @@ int64_t getNumElements() const; /// Returns whether valid (non-null) shape. - operator bool() const { return !val.isNull(); } + explicit operator bool() const { return !val.isNull(); } /// Dumps textual repesentation to stderr. void dump() const; @@ -148,6 +90,71 @@ PointerUnion val = nullptr; }; +/// ShapedTypeComponents that represents the components of a ShapedType. +/// The components consist of +/// - A ranked or unranked shape with the dimension specification match those +/// of ShapeType's getShape() (e.g., dynamic dimension represented using +/// ShapedType::kDynamicSize) +/// - A element type, may be unset (nullptr) +/// - A attribute, may be unset (nullptr) +/// Used by ShapedType type inferences. +class ShapedTypeComponents { + /// Internal storage type for shape. + using ShapeStorageT = SmallVector; + +public: + /// Default construction is an unranked shape. + ShapedTypeComponents() : elementType(nullptr), attr(nullptr){}; + ShapedTypeComponents(Type elementType) + : elementType(elementType), attr(nullptr), ranked(false) {} + ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) { + ranked = shapedType.hasRank(); + elementType = shapedType.getElementType(); + if (ranked) + dims = llvm::to_vector<4>(shapedType.getShape()); + } + ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) { + ranked = adaptor.hasRank(); + elementType = adaptor.getElementType(); + if (ranked) + adaptor.getDims(*this); + } + template ::value>> + ShapedTypeComponents(Arg &&arg, Type elementType = nullptr, + Attribute attr = nullptr) + : dims(std::forward(arg)), elementType(elementType), attr(attr), + ranked(true) {} + ShapedTypeComponents(ArrayRef vec, Type elementType = nullptr, + Attribute attr = nullptr) + : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr), + ranked(true) {} + + /// Return the dimensions of the shape. + /// Requires: shape is ranked. + ArrayRef getDims() const { + assert(ranked && "requires ranked shape"); + return dims; + } + + /// Return whether the shape has a rank. + bool hasRank() const { return ranked; }; + + /// Return the element type component. + Type getElementType() const { return elementType; }; + + /// Return the raw attribute component. + Attribute getAttribute() const { return attr; }; + +private: + friend class ShapeAdaptor; + + ShapeStorageT dims; + Type elementType; + Attribute attr; + bool ranked{false}; +}; + /// Range of values and shapes (corresponding effectively to Shapes dialect's /// ValueShape type concept). // Currently this exposes the Value (of operands) and Type of the Value. This is