diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -16,9 +16,11 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
@@ -39,19 +41,25 @@
 public:
   /// Default construction is an unranked shape.
-  ShapedTypeComponents() : ranked(false), elementType(nullptr), attr(nullptr){};
+  ShapedTypeComponents() : elementType(nullptr), attr(nullptr), ranked(false){};
   ShapedTypeComponents(Type elementType)
-      : ranked(false), elementType(elementType), attr(nullptr) {}
+      : elementType(elementType), attr(nullptr), ranked(false) {}
+  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
+    ranked = shapedType.hasRank();
+    elementType = shapedType.getElementType();
+    if (ranked)
+      dims = llvm::to_vector<4>(shapedType.getShape());
+  }
   template <typename Arg, typename = typename std::enable_if_t<
                               std::is_constructible<ShapeStorageT, Arg>::value>>
   ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
                        Attribute attr = nullptr)
-      : dims(std::forward<Arg>(arg)), ranked(true), elementType(elementType),
-        attr(attr) {}
+      : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
+        ranked(true) {}
   ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
                        Attribute attr = nullptr)
-      : dims(vec.begin(), vec.end()), ranked(true), elementType(elementType),
-        attr(attr) {}
+      : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
+        ranked(true) {}
 
   /// Return the dimensions of the shape.
   /// Requires: shape is ranked.
@@ -70,24 +78,115 @@
   Attribute getAttribute() const { return attr; };
 
 private:
+  friend class ShapeAdaptor;
+
   ShapeStorageT dims;
-  bool ranked;
   Type elementType;
   Attribute attr;
+  bool ranked;
+};
+
+/// Adaptor class to abstract away whether the shape comes from a ShapedType,
+/// a ShapedTypeComponents, or a DenseIntElementsAttribute.
+class ShapeAdaptor {
+public:
+  ShapeAdaptor(Type t) {
+    if (auto st = t.dyn_cast<ShapedType>())
+      val = st;
+  }
+  ShapeAdaptor(Attribute t) {
+    if (auto da = t.dyn_cast<DenseIntElementsAttr>())
+      val = da;
+  }
+  ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
+  ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
+
+  /// Returns whether the shape has a rank.
+  bool hasRank() const;
+
+  /// Returns the element type.
+  Type getElementType() const;
+
+  /// Populates the dimensions from the shape referenced.
+  /// Requires: shape is ranked.
+  void getDims(SmallVectorImpl<int64_t> &res) const;
+
+  /// Populates the dimensions of the ShapedTypeComponents.
+  /// Requires: shape is ranked.
+  void getDims(ShapedTypeComponents &res) const;
+
+  /// Returns the size of the index'th dimension.
+  /// Requires: shape is ranked.
+  int64_t getDimSize(int index) const;
+
+  /// Returns whether the index'th dimension is dynamic.
+  /// Requires: shape is ranked.
+  bool isDynamicDim(int index) const {
+    return ShapedType::isDynamic(getDimSize(index));
+  }
+
+  /// Returns whether the shape is fully static.
+  bool hasStaticShape() const;
+
+  /// Returns the rank of the shape.
+  /// Requires: shape is ranked.
+  int64_t getRank() const;
+
+  /// Returns the number of elements in the shape.
+  /// Requires: hasStaticShape
+  int64_t getNumElements() const;
+
+  /// Returns whether the shape is valid (non-null).
+  operator bool() const { return !val.isNull(); }
+
+  /// Dumps a textual representation to stderr.
+  void dump() const;
+
+private:
+  // Union storing either ShapedTypeComponents, ShapedType (stored as Type and
+  // cast when queried), or DenseIntElementsAttribute (stored as Attribute).
+  PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
 };
 
 /// Range of values and shapes (corresponding effectively to Shapes dialect's
 /// ValueShape type concept).
+// Currently this exposes the Value (of operands) and the Type of the Value.
+// This is not ideal, as one can then accidentally reference an out-of-date
+// shape. It is done both to enable a gradual switch and because OpAdaptor
+// doesn't currently allow returning anything other than Value.
 class ValueShapeRange : public ValueRange::RangeBaseT {
 public:
-  ValueShapeRange(ValueRange values) : RangeBaseT(values) {}
-  template <typename Arg, typename = typename std::enable_if_t<
-                              std::is_constructible<ValueRange, Arg>::value>>
-  ValueShapeRange(Arg &&arg)
-      : ValueShapeRange(ValueRange(std::forward<Arg>(arg))) {}
+  using ValueShapeMapFn = function_ref<ShapeAdaptor(Value)>;
+
+  ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
+                  ValueShapeMapFn valueToShape = nullptr)
+      : RangeBaseT(values), operandShape(operandShape),
+        valueToShape(valueToShape) {}
   ValueShapeRange(const std::initializer_list<Value> &values)
       : ValueShapeRange(ValueRange(values)) {}
+  ValueShapeRange(const ValueShapeRange &other) : RangeBaseT(other) {
+    operandShape = other.operandShape;
+    valueToShape = other.valueToShape;
+  }
+
+  /// Sets the Value to ShapeAdaptor mapping function and returns this.
+  ValueShapeRange &setValueToShapeMapping(ValueShapeMapFn fn) {
+    valueToShape = fn;
+    return *this;
+  }
+
+  ValueShapeRange &setOperandShapeMapping(ValueShapeMapFn fn) {
+    operandShape = fn;
+    return *this;
+  }
+
+  /// Returns the set Value to ShapeAdaptor mapping function.
+  ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
+  ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
+
+  // Accessors.
+
   /// Returns the types of the values within this range.
   /// Note: This returns only the types of Values in the ValueRange and not a
   /// more refined type.
@@ -97,7 +196,32 @@
   auto getType() const { return getTypes(); }
 
   /// Returns the Values in the ValueRange.
+  /// To query the most up-to-date shape of a Value, query the shape
+  /// using getShape below rather than using the type of the Value.
   ValueRange getValues() const { return ValueRange(begin(), end()); };
+
+  /// Returns an argument as shape. If the argument is not constant or not a
+  /// shape, then the function returns a null ShapeAdaptor.
+  /// This will first query the valueToShape mapping (if set), before querying
+  /// the ValueRange.
+  ShapeAdaptor getValueAsShape(int index);
+
+  /// Returns the shape of the index'th operand.
+  // TODO: Update so that operator[] references these instead to avoid
+  // accidentally referring to a less refined shape.
+  ShapeAdaptor getShape(int index) const;
+
+  /// Returns the shape of the given Value.
+  ShapeAdaptor getShape(Value val) const;
+
+private:
+  // Mapping from Value to ShapedTypeComponents corresponding to the shape of
+  // the type of the Value.
+  ValueShapeMapFn operandShape;
+
+  // Mapping from Value to ShapedTypeComponents corresponding to the constant
+  // Value if interpreted as shape.
+ ValueShapeMapFn valueToShape; }; namespace detail { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -356,21 +356,21 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ShapedType inputTy = operands[0].getType().cast(); + ShapeAdaptor inputShape = operands.getShape(0); IntegerAttr axis = attributes.get("axis").cast(); int32_t axisVal = axis.getValue().getSExtValue(); - if (!inputTy.hasRank()) { + if (!inputShape.hasRank()) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } SmallVector outShape; - outShape.reserve(inputTy.getRank() - 1); - for (int i = 0, s = inputTy.getRank(); i < s; i++) { + outShape.reserve(inputShape.getRank() - 1); + for (int i = 0, s = inputShape.getRank(); i < s; i++) { if (i == axisVal) continue; - outShape.push_back(inputTy.getDimSize(i)); + outShape.push_back(inputShape.getDimSize(i)); } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); @@ -387,21 +387,21 @@ llvm::SmallVector outputShape; bool hasRankedInput = false; for (auto operand : operands) { - ShapedType operandTy = operand.getType().cast(); - if (!operandTy.hasRank()) + ShapeAdaptor operandShape = operands.getShape(operand); + if (!operandShape.hasRank()) continue; // Copy the Operand's rank. if (!hasRankedInput) - outputShape.resize(operandTy.getRank(), ShapedType::kDynamicSize); + outputShape.resize(operandShape.getRank(), ShapedType::kDynamicSize); // Copy shapes until the dim is non-dynamic. - for (int i = 0, s = operandTy.getRank(); i < s; i++) { - if (i == axis || operandTy.isDynamicDim(i)) + for (int i = 0, s = operandShape.getRank(); i < s; i++) { + if (i == axis || operandShape.isDynamicDim(i)) continue; if (outputShape[i] == ShapedType::kDynamicSize) - outputShape[i] = operandTy.getDimSize(i); - if (outputShape[i] != operandTy.getDimSize(i)) + outputShape[i] = operandShape.getDimSize(i); + if (outputShape[i] != operandShape.getDimSize(i)) return failure(); } @@ -416,16 +416,16 @@ // Determine the dimension size along the concatenation axis. int concatDimSize = 0; for (auto operand : operands) { - ShapedType operandTy = operand.getType().cast(); + ShapeAdaptor operandShape = operands.getShape(operand); // We need to know the length of the concatenation axis of all inputs to // determine the dimension size of the output shape. - if (!operandTy.hasRank() || operandTy.isDynamicDim(axis)) { + if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) { concatDimSize = ShapedType::kDynamicSize; break; } - concatDimSize += operandTy.getDimSize(axis); + concatDimSize += operandShape.getDimSize(axis); } outputShape[axis] = concatDimSize; @@ -438,25 +438,26 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ShapedType inputTy = operands[0].getType().cast(); - ShapedType weightTy = operands[1].getType().cast(); - ShapedType biasTy = operands[2].getType().cast(); + ShapeAdaptor inputShape = operands.getShape(0); + ShapeAdaptor weightShape = operands.getShape(1); + ShapeAdaptor biasShape = operands.getShape(2); // All shapes are dynamic. 
SmallVector outShape; outShape.resize(2, ShapedType::kDynamicSize); - if (inputTy.hasRank()) { - outShape[0] = inputTy.getDimSize(0); + if (inputShape.hasRank()) { + outShape[0] = inputShape.getDimSize(0); } - if (weightTy.hasRank()) { - outShape[1] = weightTy.getDimSize(0); + if (weightShape.hasRank()) { + outShape[1] = weightShape.getDimSize(0); } - if (biasTy.hasRank()) { - outShape[1] = outShape[1] == ShapedType::kDynamicSize ? biasTy.getDimSize(0) - : outShape[1]; + if (biasShape.hasRank()) { + outShape[1] = outShape[1] == ShapedType::kDynamicSize + ? biasShape.getDimSize(0) + : outShape[1]; } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); @@ -467,22 +468,23 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ShapedType lhsTy = operands[0].getType().cast(); - ShapedType rhsTy = operands[1].getType().cast(); + ShapeAdaptor lhsShape = operands.getShape(0); + ShapeAdaptor rhsShape = operands.getShape(1); // All shapes are dynamic. SmallVector outShape; outShape.resize(3, ShapedType::kDynamicSize); - if (lhsTy.hasRank()) { - outShape[0] = lhsTy.getDimSize(0); - outShape[1] = lhsTy.getDimSize(1); + if (lhsShape.hasRank()) { + outShape[0] = lhsShape.getDimSize(0); + outShape[1] = lhsShape.getDimSize(1); } - if (rhsTy.hasRank()) { - outShape[0] = outShape[0] == ShapedType::kDynamicSize ? rhsTy.getDimSize(0) - : outShape[0]; - outShape[2] = rhsTy.getDimSize(2); + if (rhsShape.hasRank()) { + outShape[0] = outShape[0] == ShapedType::kDynamicSize + ? rhsShape.getDimSize(0) + : outShape[0]; + outShape[2] = rhsShape.getDimSize(2); } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); @@ -493,26 +495,26 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ShapedType inputTy = operands[0].getType().cast(); - ShapedType paddingTy = operands[1].getType().cast(); + ShapeAdaptor inputShape = operands.getShape(0); + ShapeAdaptor paddingShape = operands.getShape(1); SmallVector outputShape; // If both inputs have unknown shape, we cannot determine the shape of the // output. - if (!inputTy.hasRank() && !paddingTy.hasRank()) { + if (!inputShape.hasRank() && !paddingShape.hasRank()) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } // If the input rank is unknown we can info the output rank using the padding // shape's first dim. - if (!inputTy.hasRank()) { - if (paddingTy.isDynamicDim(0)) { + if (!inputShape.hasRank()) { + if (paddingShape.isDynamicDim(0)) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } - outputShape.resize(paddingTy.getDimSize(0), ShapedType::kDynamicSize); + outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamicSize); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -520,7 +522,7 @@ DenseIntElementsAttr paddings; // If the paddings value is not a constant, all dimensions must be dynamic. 
if (!matchPattern(operands[1], m_Constant(&paddings))) { - outputShape.resize(inputTy.getRank(), ShapedType::kDynamicSize); + outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -530,14 +532,14 @@ paddingValues.push_back(val.getSExtValue()); } - outputShape.reserve(inputTy.getRank()); - for (int i = 0, s = inputTy.getRank(); i < s; i++) { - if (inputTy.isDynamicDim(i)) { + outputShape.reserve(inputShape.getRank()); + for (int i = 0, s = inputShape.getRank(); i < s; i++) { + if (inputShape.isDynamicDim(i)) { outputShape.push_back(ShapedType::kDynamicSize); continue; } - outputShape.push_back(inputTy.getDimSize(i) + paddingValues[i * 2] + + outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] + paddingValues[i * 2 + 1]); } @@ -549,7 +551,7 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - auto sizes = attributes.get("size").cast().getValue(); + ArrayAttr sizes = SliceOpAdaptor(operands, attributes).size(); SmallVector outputShape; outputShape.reserve(sizes.size()); for (auto val : sizes) { @@ -564,14 +566,15 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ShapedType inputTy = operands[0].getType().cast(); + ShapeAdaptor inputShape = operands.getShape(0); - if (!inputTy.hasRank()) { + if (!inputShape.hasRank()) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } - inferredReturnShapes.push_back(inputTy.getShape()); + inferredReturnShapes.resize(1); + inputShape.getDims(inferredReturnShapes[0]); return success(); } @@ -579,10 +582,11 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - auto multiples = attributes.get("multiples").cast().getValue(); - ShapedType inputTy = operands[0].getType().cast(); + TileOpAdaptor adaptor(operands, attributes); + ArrayAttr multiples = adaptor.multiples(); + ShapeAdaptor inputShape = operands.getShape(0); SmallVector outputShape; - if (!inputTy.hasRank()) { + if (!inputShape.hasRank()) { outputShape.resize(multiples.size(), ShapedType::kDynamicSize); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); @@ -597,8 +601,8 @@ // Any non dynamic dimension can be multiplied to a known size. outputShape.reserve(multiples.size()); - for (int i = 0, s = inputTy.getRank(); i < s; i++) { - int dim = inputTy.getDimSize(i); + for (int i = 0, s = inputShape.getRank(); i < s; i++) { + int dim = inputShape.getDimSize(i); if (dim != ShapedType::kDynamicSize) dim *= multipleValues[i]; outputShape.push_back(dim); @@ -612,15 +616,16 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ShapedType type = operands.front().getType().cast(); + ReshapeOpAdaptor adaptor(operands, attributes); + ShapeAdaptor inputShape = operands.getShape(0); - auto newShape = attributes.get("new_shape").cast(); + ArrayAttr newShape = adaptor.new_shape(); llvm::SmallVector newShapeValue; getI64Values(newShape, newShapeValue); // We cannot infer from the total number of elements so we must take the // shape attribute as exact. 
- if (!type.hasRank() || !type.hasStaticShape()) { + if (!inputShape.hasRank() || !inputShape.hasStaticShape()) { inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue)); return success(); } @@ -628,7 +633,7 @@ // Determine the number of elements covered by the slice of all static // dimensions. This allows us to infer the length of the remaining dynamic // dimension. - int64_t numElements = type.getNumElements(); + int64_t numElements = inputShape.getNumElements(); int64_t staticMul = 1; for (auto val : newShapeValue) { if (val != ShapedType::kDynamicSize) { @@ -650,12 +655,13 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ShapedType inputTy = operands[0].getType().cast(); - ShapedType permsTy = operands[1].getType().cast(); + ShapeAdaptor inputShape = operands.getShape(0); + ShapeAdaptor permsShape = operands.getShape(1); // If input rank and permutation length is unknown, the output rank is // unknown. - if (!inputTy.hasRank() && (!permsTy.hasRank() || permsTy.isDynamicDim(0))) { + if (!inputShape.hasRank() && + (!permsShape.hasRank() || permsShape.isDynamicDim(0))) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } @@ -663,22 +669,22 @@ // Without the input dims we cannot determine the output dim sizes but we // can determine the output rank. SmallVector outputShape; - if (!inputTy.hasRank()) { - outputShape.resize(permsTy.getDimSize(0), ShapedType::kDynamicSize); + if (!inputShape.hasRank()) { + outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamicSize); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } // Rank-0 means no permutations matter. - if (inputTy.getRank() == 0) { + if (inputShape.getRank() == 0) { inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } // Check whether the input dimensions are all the same. bool allTheSame = true; - for (int i = 1, s = inputTy.getRank(); i < s; i++) { - if (inputTy.getDimSize(0) != inputTy.getDimSize(i)) { + for (int i = 1, s = inputShape.getRank(); i < s; i++) { + if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) { allTheSame = false; break; } @@ -687,24 +693,18 @@ // If all of the input dimensions are the same we don't care about the // permutation. if (allTheSame) { - outputShape.resize(inputTy.getRank(), inputTy.getDimSize(0)); + outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0)); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } - DenseIntElementsAttr perms; - outputShape.resize(inputTy.getRank(), ShapedType::kDynamicSize); + outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize); // If the permuations are a constant we can directly determine the output // shape. 
- if (matchPattern(operands[1], m_Constant(&perms))) { - llvm::SmallVector permValues; - for (auto val : perms) { - permValues.push_back(val.getSExtValue()); - } - - outputShape.reserve(inputTy.getRank()); - for (int i = 0, s = inputTy.getRank(); i < s; i++) { - outputShape[i] = inputTy.getDimSize(permValues[i]); + if (ShapeAdaptor permShape = operands.getValueAsShape(1)) { + outputShape.reserve(inputShape.getRank()); + for (int i = 0, s = inputShape.getRank(); i < s; i++) { + outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i)); } } @@ -719,16 +719,18 @@ llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamicSize); - if (auto ty = operands[0].getType().dyn_cast()) { - outputShape[0] = ty.getDimSize(0); - outputShape[2] = ty.getDimSize(2); + ShapeAdaptor valuesShape = operands.getShape(0); + if (valuesShape.hasRank()) { + outputShape[0] = valuesShape.getDimSize(0); + outputShape[2] = valuesShape.getDimSize(2); } - if (auto ty = operands[1].getType().dyn_cast()) { + ShapeAdaptor indicesShape = operands.getShape(1); + if (indicesShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamicSize) - outputShape[0] = ty.getDimSize(0); + outputShape[0] = indicesShape.getDimSize(0); if (outputShape[1] == ShapedType::kDynamicSize) - outputShape[1] = ty.getDimSize(1); + outputShape[1] = indicesShape.getDimSize(1); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); @@ -739,24 +741,25 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { + ResizeOpAdaptor adaptor(operands, attributes); llvm::SmallVector outputShape; outputShape.resize(4, ShapedType::kDynamicSize); int32_t inHeight = ShapedType::kDynamicSize; int32_t inWidth = ShapedType::kDynamicSize; - if (auto ty = operands[0].getType().dyn_cast()) { - outputShape[0] = ty.getDimSize(0); - outputShape[3] = ty.getDimSize(3); + ShapeAdaptor inputShape = operands.getShape(adaptor.input()); + if (inputShape.hasRank()) { + outputShape[0] = inputShape.getDimSize(0); + outputShape[3] = inputShape.getDimSize(3); - inHeight = ty.getDimSize(1); - inWidth = ty.getDimSize(2); + inHeight = inputShape.getDimSize(1); + inWidth = inputShape.getDimSize(2); } - int32_t shift = - attributes.get("shift").cast().getValue().getSExtValue(); + int32_t shift = adaptor.shift().getValue().getSExtValue(); llvm::SmallVector newShape; - getI64Values(attributes.get("output_size").cast(), newShape); + getI64Values(adaptor.output_size(), newShape); outputShape[1] = newShape[0]; outputShape[2] = newShape[1]; @@ -764,10 +767,10 @@ llvm::SmallVector offsetInt; llvm::SmallVector strideFp; llvm::SmallVector offsetFp; - getI64Values(attributes.get("offset").cast(), offsetInt); - getF64Values(attributes.get("offset_fp").cast(), offsetFp); - getI64Values(attributes.get("stride").cast(), strideInt); - getF64Values(attributes.get("stride_fp").cast(), strideFp); + getI64Values(adaptor.offset(), offsetInt); + getF64Values(adaptor.offset_fp(), offsetFp); + getI64Values(adaptor.stride(), strideInt); + getF64Values(adaptor.stride_fp(), strideFp); // If we have a 0 zero in integers we know that the resize indexing needs to // be performed in floating point. 
Use the floating point varient to compute @@ -812,22 +815,25 @@ llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamicSize); - if (auto ty = operands[0].getType().dyn_cast()) { - outputShape[0] = ty.getDimSize(0); - outputShape[1] = ty.getDimSize(1); - outputShape[2] = ty.getDimSize(2); + ShapeAdaptor valuesInShape = operands.getShape(0); + if (valuesInShape.hasRank()) { + outputShape[0] = valuesInShape.getDimSize(0); + outputShape[1] = valuesInShape.getDimSize(1); + outputShape[2] = valuesInShape.getDimSize(2); } - if (auto ty = operands[1].getType().dyn_cast()) { + ShapeAdaptor indicesShape = operands.getShape(1); + if (indicesShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamicSize) - outputShape[0] = ty.getDimSize(0); + outputShape[0] = indicesShape.getDimSize(0); } - if (auto ty = operands[2].getType().dyn_cast()) { + ShapeAdaptor inputShape = operands.getShape(2); + if (inputShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamicSize) - outputShape[0] = ty.getDimSize(0); + outputShape[0] = inputShape.getDimSize(0); if (outputShape[2] == ShapedType::kDynamicSize) - outputShape[2] = ty.getDimSize(2); + outputShape[2] = inputShape.getDimSize(2); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); @@ -835,21 +841,16 @@ } static LogicalResult ReduceInferReturnTypes( - Value operand, IntegerAttr axis, + ShapeAdaptor operandShape, IntegerAttr axis, SmallVectorImpl &inferredReturnShapes) { - auto operandTy = operand.getType().cast(); - if (!operandTy.hasRank()) { + if (!operandShape.hasRank()) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } - int64_t axisVal = axis.getValue().getSExtValue(); SmallVector outputShape; - outputShape.reserve(operandTy.getRank()); - for (auto dim : operandTy.getShape()) { - outputShape.push_back(dim); - } - + operandShape.getDims(outputShape); + int64_t axisVal = axis.getValue().getSExtValue(); outputShape[axisVal] = 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); @@ -861,7 +862,7 @@ ValueShapeRange operands, DictionaryAttr attributes, \ RegionRange regions, \ SmallVectorImpl &inferredReturnShapes) { \ - return ReduceInferReturnTypes(operands[0], \ + return ReduceInferReturnTypes(operands.getShape(0), \ attributes.get("axis").cast(), \ inferredReturnShapes); \ } @@ -874,26 +875,26 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp) #undef REDUCE_SHAPE_INFER -static LogicalResult resolveBroadcastShape(ValueRange operands, +static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector &outShape) { int64_t outRank = 0; - for (auto operand : operands) { - auto type = operand.getType().cast(); - if (!type.hasRank()) + for (int i = 0, e = operands.size(); i != e; ++i) { + auto shape = operands.getShape(i); + if (!shape.hasRank()) { return failure(); - outRank = std::max(outRank, type.getRank()); + } + outRank = std::max(outRank, shape.getRank()); } outShape.resize(outRank, 1); - for (auto operand : operands) { - auto type = operand.getType().cast(); - auto shape = type.getShape(); - auto rankDiff = outShape.size() - shape.size(); + for (int i = 0, e = operands.size(); i != e; ++i) { + auto shape = operands.getShape(i); + auto rankDiff = outShape.size() - shape.getRank(); - for (size_t i = 0; i < shape.size(); i++) { + for (size_t i = 0, e = shape.getRank(); i < e; ++i) { auto dim1 = outShape[i + rankDiff]; - auto dim2 = shape[i]; + auto dim2 = shape.getDimSize(i); auto resolvedDim = dim1; if (dim1 == 1) { @@ -911,7 +912,7 @@ } static 
LogicalResult NAryInferReturnTypes( - ValueRange operands, + const ValueShapeRange &operands, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outShape; if (resolveBroadcastShape(operands, outShape).failed()) { @@ -973,24 +974,24 @@ #undef PRED_SHAPE_INFER static LogicalResult poolingInferReturnTypes( - ValueRange operands, DictionaryAttr attributes, + const ValueShapeRange &operands, DictionaryAttr attributes, SmallVectorImpl &inferredReturnShapes) { - RankedTensorType inputTy = operands[0].getType().dyn_cast(); + ShapeAdaptor inputShape = operands.getShape(0); llvm::SmallVector outputShape; outputShape.resize(4, -1); // We only know the rank if the input type is unranked. - if (!inputTy) { + if (!inputShape) { inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } // Batch and number of channels are identical for pooling layer. - outputShape[0] = inputTy.getDimSize(0); - outputShape[3] = inputTy.getDimSize(3); + outputShape[0] = inputShape.getDimSize(0); + outputShape[3] = inputShape.getDimSize(3); - int32_t height = inputTy.getDimSize(1); - int32_t width = inputTy.getDimSize(2); + int32_t height = inputShape.getDimSize(1); + int32_t width = inputShape.getDimSize(2); llvm::SmallVector kernel; llvm::SmallVector stride; @@ -1019,7 +1020,7 @@ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamicSize); - Conv2DOp::Adaptor adaptor(operands.getValues()); + Conv2DOp::Adaptor adaptor(operands.getValues(), attributes); int32_t inputWidth = ShapedType::kDynamicSize; int32_t inputHeight = ShapedType::kDynamicSize; @@ -1027,23 +1028,27 @@ int32_t weightHeight = ShapedType::kDynamicSize; // Input shape describes input width/height and batch. - if (auto inputTy = adaptor.input().getType().dyn_cast()) { - outputShape[0] = inputTy.getDimSize(0); - inputHeight = inputTy.getDimSize(1); - inputWidth = inputTy.getDimSize(2); + + ShapeAdaptor inputShape = operands.getShape(adaptor.input()); + if (inputShape.hasRank()) { + outputShape[0] = inputShape.getDimSize(0); + inputHeight = inputShape.getDimSize(1); + inputWidth = inputShape.getDimSize(2); } // Weight shapes describes the filter width/height and the output channels. - if (auto weightTy = adaptor.weight().getType().dyn_cast()) { - outputShape[3] = weightTy.getDimSize(0); - weightHeight = weightTy.getDimSize(1); - weightWidth = weightTy.getDimSize(2); + ShapeAdaptor weightShape = operands.getShape(adaptor.weight()); + if (weightShape.hasRank()) { + outputShape[3] = weightShape.getDimSize(0); + weightHeight = weightShape.getDimSize(1); + weightWidth = weightShape.getDimSize(2); } // Bias shape can describe the output channels. - if (auto biasTy = adaptor.bias().getType().dyn_cast()) { + ShapeAdaptor biasShape = operands.getShape(adaptor.bias()); + if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) - ? biasTy.getDimSize(0) + ? 
biasShape.getDimSize(0) : outputShape[3]; } @@ -1051,9 +1056,9 @@ llvm::SmallVector padding; llvm::SmallVector stride; - getI64Values(attributes.get("dilation").cast(), dilation); - getI64Values(attributes.get("pad").cast(), padding); - getI64Values(attributes.get("stride").cast(), stride); + getI64Values(adaptor.dilation(), dilation); + getI64Values(adaptor.pad(), padding); + getI64Values(adaptor.stride(), stride); if (!ShapedType::isDynamic(inputHeight) && !ShapedType::isDynamic(weightHeight)) { @@ -1080,7 +1085,7 @@ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(5, ShapedType::kDynamicSize); - Conv2DOp::Adaptor adaptor(operands.getValues()); + Conv2DOp::Adaptor adaptor(operands.getValues(), attributes); int32_t inputWidth = ShapedType::kDynamicSize; int32_t inputHeight = ShapedType::kDynamicSize; @@ -1091,34 +1096,37 @@ int32_t weightDepth = ShapedType::kDynamicSize; // Input shape describes input width/height and batch. - if (auto inputTy = adaptor.input().getType().dyn_cast()) { - outputShape[0] = inputTy.getDimSize(0); - inputHeight = inputTy.getDimSize(1); - inputWidth = inputTy.getDimSize(2); - inputDepth = inputTy.getDimSize(3); + ShapeAdaptor inputShape = operands.getShape(adaptor.input()); + if (inputShape.hasRank()) { + outputShape[0] = inputShape.getDimSize(0); + inputHeight = inputShape.getDimSize(1); + inputWidth = inputShape.getDimSize(2); + inputDepth = inputShape.getDimSize(3); } // Weight shapes describes the filter width/height and the output channels. - if (auto weightTy = adaptor.weight().getType().dyn_cast()) { - outputShape[4] = weightTy.getDimSize(0); - weightHeight = weightTy.getDimSize(1); - weightWidth = weightTy.getDimSize(2); - weightDepth = weightTy.getDimSize(3); + ShapeAdaptor weightShape = operands.getShape(adaptor.weight()); + if (weightShape.hasRank()) { + outputShape[4] = weightShape.getDimSize(0); + weightHeight = weightShape.getDimSize(1); + weightWidth = weightShape.getDimSize(2); + weightDepth = weightShape.getDimSize(3); } // Bias shape can describe the output channels. - if (auto biasTy = adaptor.bias().getType().dyn_cast()) { + ShapeAdaptor biasShape = operands.getShape(adaptor.bias()); + if (biasShape.hasRank()) { outputShape[4] = - (outputShape[4] == -1) ? biasTy.getDimSize(0) : outputShape[4]; + (outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4]; } llvm::SmallVector dilation; llvm::SmallVector padding; llvm::SmallVector stride; - getI64Values(attributes.get("dilation").cast(), dilation); - getI64Values(attributes.get("pad").cast(), padding); - getI64Values(attributes.get("stride").cast(), stride); + getI64Values(adaptor.dilation(), dilation); + getI64Values(adaptor.pad(), padding); + getI64Values(adaptor.stride(), stride); if (!ShapedType::isDynamic(inputHeight) && !ShapedType::isDynamic(weightHeight)) { @@ -1167,7 +1175,7 @@ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamicSize); - DepthwiseConv2DOp::Adaptor adaptor(operands.getValues()); + DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes); int32_t inputWidth = ShapedType::kDynamicSize; int32_t inputHeight = ShapedType::kDynamicSize; @@ -1178,21 +1186,23 @@ int32_t depthChannels = ShapedType::kDynamicSize; // Input shape describes input width/height and batch. 
- if (auto inputTy = adaptor.input().getType().dyn_cast()) { - outputShape[0] = inputTy.getDimSize(0); - inputHeight = inputTy.getDimSize(1); - inputWidth = inputTy.getDimSize(2); - inputChannels = inputTy.getDimSize(3); + ShapeAdaptor inputShape = operands.getShape(adaptor.input()); + if (inputShape.hasRank()) { + outputShape[0] = inputShape.getDimSize(0); + inputHeight = inputShape.getDimSize(1); + inputWidth = inputShape.getDimSize(2); + inputChannels = inputShape.getDimSize(3); } // Weight shapes describes the filter width/height and the output channels. - if (auto weightTy = adaptor.weight().getType().dyn_cast()) { - weightHeight = weightTy.getDimSize(0); - weightWidth = weightTy.getDimSize(1); + ShapeAdaptor weightShape = operands.getShape(adaptor.weight()); + if (weightShape.hasRank()) { + weightHeight = weightShape.getDimSize(0); + weightWidth = weightShape.getDimSize(1); inputChannels = ShapedType::isDynamic(inputChannels) - ? weightTy.getDimSize(2) + ? weightShape.getDimSize(2) : inputChannels; - depthChannels = weightTy.getDimSize(3); + depthChannels = weightShape.getDimSize(3); } // If both inputChannels and depthChannels are available we can determine @@ -1203,9 +1213,10 @@ } // Bias shape can describe the output channels. - if (auto biasTy = adaptor.bias().getType().dyn_cast()) { + ShapeAdaptor biasShape = operands.getShape(adaptor.bias()); + if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) - ? biasTy.getDimSize(0) + ? biasShape.getDimSize(0) : outputShape[3]; } @@ -1213,9 +1224,9 @@ llvm::SmallVector padding; llvm::SmallVector stride; - getI64Values(attributes.get("dilation").cast(), dilation); - getI64Values(attributes.get("pad").cast(), padding); - getI64Values(attributes.get("stride").cast(), stride); + getI64Values(adaptor.dilation(), dilation); + getI64Values(adaptor.pad(), padding); + getI64Values(adaptor.stride(), stride); if (!ShapedType::isDynamic(inputHeight) && !ShapedType::isDynamic(weightHeight)) { @@ -1241,9 +1252,9 @@ MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - TransposeConv2DOp::Adaptor adaptor(operands.getValues()); + TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes); llvm::SmallVector outputShape; - getI64Values(attributes.get("out_shape").cast(), outputShape); + getI64Values(adaptor.out_shape(), outputShape); int32_t inputWidth = ShapedType::kDynamicSize; int32_t inputHeight = ShapedType::kDynamicSize; @@ -1251,27 +1262,30 @@ int32_t weightHeight = ShapedType::kDynamicSize; // Input shape describes input width/height and batch. - if (auto inputTy = adaptor.input().getType().dyn_cast()) { + ShapeAdaptor inputShape = operands.getShape(adaptor.input()); + if (inputShape.hasRank()) { outputShape[0] = ShapedType::isDynamic(outputShape[0]) - ? inputTy.getDimSize(0) + ? inputShape.getDimSize(0) : outputShape[0]; - inputHeight = inputTy.getDimSize(1); - inputWidth = inputTy.getDimSize(2); + inputHeight = inputShape.getDimSize(1); + inputWidth = inputShape.getDimSize(2); } // Weight shapes describes the filter width/height and the output channels. - if (auto weightTy = adaptor.filter().getType().dyn_cast()) { + ShapeAdaptor weightShape = operands.getShape(adaptor.input()); + if (weightShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) - ? weightTy.getDimSize(0) + ? 
weightShape.getDimSize(0) : outputShape[3]; - weightHeight = weightTy.getDimSize(1); - weightWidth = weightTy.getDimSize(2); + weightHeight = weightShape.getDimSize(1); + weightWidth = weightShape.getDimSize(2); } // Bias shape can describe the output channels. - if (auto biasTy = adaptor.bias().getType().dyn_cast()) { + ShapeAdaptor biasShape = operands.getShape(adaptor.input()); + if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) - ? biasTy.getDimSize(0) + ? biasShape.getDimSize(0) : outputShape[3]; } @@ -1279,9 +1293,9 @@ llvm::SmallVector padding; llvm::SmallVector stride; - getI64Values(attributes.get("dilation").cast(), dilation); - getI64Values(attributes.get("out_pad").cast(), padding); - getI64Values(attributes.get("stride").cast(), stride); + getI64Values(adaptor.dilation(), dilation); + getI64Values(adaptor.out_pad(), padding); + getI64Values(adaptor.stride(), stride); if (!ShapedType::isDynamic(inputHeight) && !ShapedType::isDynamic(weightHeight)) { @@ -1339,7 +1353,7 @@ } } - for (auto result : resultKnowledge) { + for (const ValueKnowledge &result : resultKnowledge) { if (result.hasRank) { inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes)); } else { diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp @@ -25,6 +25,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace mlir::tosa; @@ -62,6 +63,21 @@ } void propagateShapesInRegion(Region ®ion) { + DenseMap shapesStorage; + auto setShapes = [&](Value val, Type t) { + if (auto st = t.dyn_cast()) + shapesStorage[val] = st; + else + shapesStorage[val] = t; + }; + auto operandShape = [&](Value val) -> ShapeAdaptor { + // Query the WIP mapping rather than the type if set. + auto it = shapesStorage.find(val); + if (it == shapesStorage.end()) + return nullptr; + return it->second; + }; + for (auto &block : region) { for (Operation &op : block) { if (op.getDialect()->getNamespace() != @@ -76,10 +92,12 @@ continue; SmallVector returnedShapes; + + ValueShapeRange range(op.getOperands(), operandShape); if (shapeInterface - .inferReturnTypeComponents( - op.getContext(), op.getLoc(), op.getOperands(), - op.getAttrDictionary(), op.getRegions(), returnedShapes) + .inferReturnTypeComponents(op.getContext(), op.getLoc(), range, + op.getAttrDictionary(), + op.getRegions(), returnedShapes) .succeeded()) { for (auto it : llvm::zip(op.getResults(), returnedShapes)) { Value result = std::get<0>(it); @@ -99,6 +117,7 @@ } // Determine the knowledge based on the output type. + // TODO: should also query WIP type probably Type resultTy = result.getType(); auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy); @@ -122,11 +141,20 @@ ValueKnowledge::join(currentKnowledge, inferredKnowledge); if (!newKnowledge) continue; - result.setType(newKnowledge.getType()); + setShapes(result, newKnowledge.getType()); } } } } + + // Actually update types with updated shape knowledge. + for (auto it : shapesStorage) { + auto result = it.second; + if (result.hasRank()) { + Type t = it.first.getType().cast().clone(result.getDims()); + it.first.setType(t); + } + } } /// Pass that performs shape propagation across TOSA operations. 
This includes
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -13,6 +13,8 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 
@@ -20,6 +22,160 @@
 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
 } // namespace mlir
 
+bool ShapeAdaptor::hasRank() const {
+  if (val.isNull())
+    return false;
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().hasRank();
+  if (val.is<Attribute>())
+    return true;
+  return val.get<ShapedTypeComponents *>()->hasRank();
+}
+
+Type ShapeAdaptor::getElementType() const {
+  if (val.isNull())
+    return nullptr;
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().getElementType();
+  if (val.is<Attribute>())
+    return nullptr;
+  return val.get<ShapedTypeComponents *>()->getElementType();
+}
+
+void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
+  assert(hasRank());
+  if (auto t = val.dyn_cast<Type>()) {
+    ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
+    res.assign(vals.begin(), vals.end());
+  } else if (auto attr = val.dyn_cast<Attribute>()) {
+    auto dattr = attr.cast<DenseIntElementsAttr>();
+    res.clear();
+    res.reserve(dattr.size());
+    for (auto it : dattr.getIntValues())
+      res.push_back(it.getSExtValue());
+  } else {
+    auto vals = val.get<ShapedTypeComponents *>()->getDims();
+    res.assign(vals.begin(), vals.end());
+  }
+}
+
+void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
+  assert(hasRank());
+  res.ranked = true;
+  getDims(res.dims);
+}
+
+int64_t ShapeAdaptor::getDimSize(int index) const {
+  assert(hasRank());
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().getDimSize(index);
+  if (auto attr = val.dyn_cast<Attribute>())
+    return attr.cast<DenseIntElementsAttr>()
+        .getValue<APInt>({static_cast<uint64_t>(index)})
+        .getSExtValue();
+  auto *stc = val.get<ShapedTypeComponents *>();
+  return stc->getDims()[index];
+}
+
+int64_t ShapeAdaptor::getRank() const {
+  assert(hasRank());
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().getRank();
+  if (auto attr = val.dyn_cast<Attribute>())
+    return attr.cast<DenseIntElementsAttr>().size();
+  return val.get<ShapedTypeComponents *>()->getDims().size();
+}
+
+bool ShapeAdaptor::hasStaticShape() const {
+  if (!hasRank())
+    return false;
+
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().hasStaticShape();
+  if (auto attr = val.dyn_cast<Attribute>()) {
+    auto dattr = attr.cast<DenseIntElementsAttr>();
+    for (auto index : dattr.getIntValues())
+      if (ShapedType::isDynamic(index.getSExtValue()))
+        return false;
+    return true;
+  }
+  auto *stc = val.get<ShapedTypeComponents *>();
+  for (int64_t dim : stc->getDims())
+    if (ShapedType::isDynamic(dim))
+      return false;
+  return true;
+}
+
+int64_t ShapeAdaptor::getNumElements() const {
+  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
+
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().getNumElements();
+
+  if (auto attr = val.dyn_cast<Attribute>()) {
+    auto dattr = attr.cast<DenseIntElementsAttr>();
+    int64_t num = 1;
+    for (auto index : dattr.getIntValues()) {
+      num *= index.getZExtValue();
+      assert(num >= 0 && "integer overflow in element count computation");
+    }
+    return num;
+  }
+
+  auto *stc = val.get<ShapedTypeComponents *>();
+  int64_t num = 1;
+  for (int64_t dim : stc->getDims()) {
+    num *= dim;
+    assert(num >= 0 && "integer overflow in element count computation");
+  }
+  return num;
+}
+
+void ShapeAdaptor::dump() const {
+  if (!hasRank()) {
+    llvm::errs() << "<<unranked>>\n";
+    return;
+  }
+
+  SmallVector<int64_t> dims;
+  getDims(dims);
+  auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
+    if (ShapedType::isDynamic(dim))
+      return "?";
+    return llvm::formatv("{0}", dim).str();
+  });
+  llvm::errs() << "rank = " << getRank() << " dims = [";
+  llvm::interleave(mapped, llvm::errs(), "x");
+  llvm::errs() << "]\n";
+}
+
+ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
+  Value val = operator[](index);
+  if (valueToShape)
+    if (ShapeAdaptor ret = valueToShape(val))
+      return ret;
+
+  DenseIntElementsAttr attr;
+  if (!matchPattern(val, m_Constant(&attr)))
+    return nullptr;
+  if (attr.getType().getRank() != 1)
+    return nullptr;
+  return attr;
+}
+
+ShapeAdaptor ValueShapeRange::getShape(Value val) const {
+  if (operandShape)
+    if (ShapeAdaptor ret = operandShape(val))
+      return ret;
+  return val.getType();
+}
+
+ShapeAdaptor ValueShapeRange::getShape(int index) const {
+  if (index < 0 || static_cast<size_t>(index) >= size())
+    return nullptr;
+  return getShape(operator[](index));
+}
+
 LogicalResult mlir::detail::inferReturnTensorTypes(
     function_ref<LogicalResult(
         MLIRContext *, Optional<Location> location, ValueShapeRange operands,
diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -1,10 +1,13 @@
 add_mlir_unittest(MLIRInterfacesTests
   DataLayoutInterfacesTest.cpp
+  InferTypeOpInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
   PRIVATE
   MLIRDataLayoutInterfaces
   MLIRDLTI
+  MLIRInferTypeOpInterface
   MLIRParser
+  MLIRStandard
 )
diff --git a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
new file
--- /dev/null
+++ b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
@@ -0,0 +1,103 @@
+//===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+
+class ValueShapeRangeTest : public testing::Test {
+protected:
+  void SetUp() override {
+    const char *ir = R"MLIR(
+      func @map(%arg : tensor<1xi64>) {
+        %0 = constant dense<[10]> : tensor<1xi64>
+        %1 = addi %arg, %0 : tensor<1xi64>
+        return
+      }
+    )MLIR";
+
+    registry.insert<StandardOpsDialect>();
+    ctx.appendDialectRegistry(registry);
+    module = parseSourceString(ir, &ctx);
+    mapFn = cast<FuncOp>(module->front());
+  }
+
+  // Create ValueShapeRange on the addi operation.
+  ValueShapeRange addiRange() {
+    auto &fnBody = mapFn.body();
+    return std::next(fnBody.front().begin())->getOperands();
+  }
+
+  DialectRegistry registry;
+  MLIRContext ctx;
+  OwningModuleRef module;
+  FuncOp mapFn;
+};
+
+TEST_F(ValueShapeRangeTest, ShapesFromValues) {
+  ValueShapeRange range = addiRange();
+
+  EXPECT_FALSE(range.getValueAsShape(0));
+  ASSERT_TRUE(range.getValueAsShape(1));
+  EXPECT_TRUE(range.getValueAsShape(1).hasRank());
+  EXPECT_EQ(range.getValueAsShape(1).getRank(), 1);
+  EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10);
+  EXPECT_EQ(range.getShape(1).getRank(), 1);
+  EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
+}
+
+TEST_F(ValueShapeRangeTest, MapValuesToShapes) {
+  ValueShapeRange range = addiRange();
+  ShapedTypeComponents fixed(SmallVector<int64_t>{30});
+  auto mapping = [&](Value val) -> ShapeAdaptor {
+    if (val == mapFn.getArgument(0))
+      return &fixed;
+    return nullptr;
+  };
+  range.setValueToShapeMapping(mapping);
+
+  ASSERT_TRUE(range.getValueAsShape(0));
+  EXPECT_TRUE(range.getValueAsShape(0).hasRank());
+  EXPECT_EQ(range.getValueAsShape(0).getRank(), 1);
+  EXPECT_EQ(range.getValueAsShape(0).getDimSize(0), 30);
+  ASSERT_TRUE(range.getValueAsShape(1));
+  EXPECT_TRUE(range.getValueAsShape(1).hasRank());
+  EXPECT_EQ(range.getValueAsShape(1).getRank(), 1);
+  EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10);
+}
+
+TEST_F(ValueShapeRangeTest, SettingShapes) {
+  ShapedTypeComponents shape(SmallVector<int64_t>{10, 20});
+  ValueShapeRange range = addiRange();
+  auto mapping = [&](Value val) -> ShapeAdaptor {
+    if (val == mapFn.getArgument(0))
+      return &shape;
+    return nullptr;
+  };
+  range.setOperandShapeMapping(mapping);
+
+  ASSERT_TRUE(range.getShape(0));
+  EXPECT_EQ(range.getShape(0).getRank(), 2);
+  EXPECT_EQ(range.getShape(0).getDimSize(0), 10);
+  EXPECT_EQ(range.getShape(0).getDimSize(1), 20);
+  ASSERT_TRUE(range.getShape(1));
+  EXPECT_EQ(range.getShape(1).getRank(), 1);
+  EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
+  EXPECT_FALSE(range.getShape(2));
+}
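
Example usage (a minimal sketch, not part of the patch): the snippet below shows how an op's inferReturnTypeComponents hook can use the ValueShapeRange / ShapeAdaptor API introduced above, following the same pattern as the TOSA changes. "ExampleReshapeOp" and its (input, shape) operand layout are hypothetical; only accessors added in this patch are assumed.

// Sketch only: "ExampleReshapeOp" is a hypothetical op with operands
// (input, shape); the accessors used are the ones added in this patch.
LogicalResult ExampleReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // If the shape operand is a constant vector, interpret its elements as the
  // result shape (getValueAsShape also consults the valueToShape mapping).
  if (ShapeAdaptor shapeValue = operands.getValueAsShape(1)) {
    SmallVector<int64_t> outShape;
    shapeValue.getDims(outShape);
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
    return success();
  }

  // Otherwise only the result rank may be known, taken from the static extent
  // of the 1-D shape operand; getShape(1) returns the possibly refined shape
  // rather than the operand's declared type.
  ShapeAdaptor shapeOperand = operands.getShape(1);
  if (!shapeOperand.hasRank() || shapeOperand.isDynamicDim(0)) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }
  SmallVector<int64_t> outShape(shapeOperand.getDimSize(0),
                                ShapedType::kDynamicSize);
  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}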