diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -128,7 +128,9 @@ static constexpr LenType singleton() { return 1; } /// Character has a LEN value which is not a compile-time known constant. - static constexpr LenType unknownLen() { return mlir::ShapedType::kDynamic; } + static constexpr LenType unknownLen() { + return mlir::RankedShapedType::kDynamic; + } /// Character LEN is a runtime value. bool hasDynamicLen() { return getLen() == unknownLen(); } @@ -485,7 +487,7 @@ // The value `kDynamic` represents an unknown extent for a dimension static constexpr Extent getUnknownExtent() { - return mlir::ShapedType::kDynamic; + return mlir::RankedShapedType::kDynamic; } }]; } diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td --- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td @@ -76,7 +76,7 @@ isPolymorphic()); } static constexpr int64_t getUnknownExtent() { - return mlir::ShapedType::kDynamic; + return mlir::RankedShapedType::kDynamic; } }]; diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp --- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp @@ -410,8 +410,8 @@ auto affineApply = rewriter.create(acoOp.getLoc(), affineMap, indexArgs); auto arrayElementType = coordinateArrayElement(acoOp); - auto newType = - mlir::MemRefType::get({mlir::ShapedType::kDynamic}, arrayElementType); + auto newType = mlir::MemRefType::get({mlir::RankedShapedType::kDynamic}, + arrayElementType); auto arrayConvert = rewriter.create(acoOp.getLoc(), newType, acoOp.getMemref()); return std::make_pair(affineApply, arrayConvert); diff --git 
a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -159,44 +159,47 @@ // Shaped type. //===----------------------------------------------------------------------===// -/// Checks whether the given type is a Shaped type. +/// Checks whether the given type is a ShapedType. MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type); +/// Checks whether the given type is a RankedShapedType. +MLIR_CAPI_EXPORTED bool mlirTypeIsARankedShaped(MlirType type); + /// Returns the element type of the shaped type. MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type); -/// Checks whether the given shaped type is ranked. -MLIR_CAPI_EXPORTED bool mlirShapedTypeHasRank(MlirType type); - /// Returns the rank of the given ranked shaped type. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type); +MLIR_CAPI_EXPORTED int64_t mlirRankedShapedTypeGetRank(MlirType type); -/// Checks whether the given shaped type has a static shape. -MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type); +/// Checks whether the given ranked shaped type has a static shape. +MLIR_CAPI_EXPORTED bool mlirRankedShapedTypeHasStaticShape(MlirType type); -/// Checks wither the dim-th dimension of the given shaped type is dynamic. -MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim); +/// Checks whether the dim-th dimension of the given ranked shaped type is +/// dynamic. +MLIR_CAPI_EXPORTED bool mlirRankedShapedTypeIsDynamicDim(MlirType type, + intptr_t dim); /// Returns the dim-th dimension of the given ranked shaped type. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, - intptr_t dim); +MLIR_CAPI_EXPORTED int64_t mlirRankedShapedTypeGetDimSize(MlirType type, + intptr_t dim); /// Checks whether the given value is used as a placeholder for dynamic sizes -/// in shaped types.
-MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size); +/// in ranked shaped types. +MLIR_CAPI_EXPORTED bool mlirRankedShapedTypeIsDynamicSize(int64_t size); -/// Returns the value indicating a dynamic size in a shaped type. Prefer -/// mlirShapedTypeIsDynamicSize to direct comparisons with this value. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void); +/// Returns the value indicating a dynamic size in a ranked shaped type. Prefer +/// mlirRankedShapedTypeIsDynamicSize to direct comparisons with this value. +MLIR_CAPI_EXPORTED int64_t mlirRankedShapedTypeGetDynamicSize(void); /// Checks whether the given value is used as a placeholder for dynamic strides -/// and offsets in shaped types. -MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val); - -/// Returns the value indicating a dynamic stride or offset in a shaped type. -/// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with -/// this value. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void); +/// and offsets in ranked shaped types. +MLIR_CAPI_EXPORTED bool +mlirRankedShapedTypeIsDynamicStrideOrOffset(int64_t val); + +/// Returns the value indicating a dynamic stride or offset in a ranked shaped +/// type. Prefer mlirRankedShapedTypeIsDynamicStrideOrOffset to direct +/// comparisons with this value. +MLIR_CAPI_EXPORTED int64_t mlirRankedShapedTypeGetDynamicStrideOrOffset(void); //===----------------------------------------------------------------------===// // Vector type.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -125,7 +125,7 @@ ArrayRef shape = getType().getShape(); return std::count_if( shape.begin(), shape.begin() + idx, - [&](int64_t size) { return ShapedType::isDynamic(size); }); + [&](int64_t size) { return RankedShapedType::isDynamic(size); }); } // Return the Value of the dynamic size of the tensor at dimension diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -307,7 +307,7 @@ /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); if (auto shapedType = - opOperand->get().getType().template dyn_cast()) + opOperand->get().getType().template dyn_cast()) return shapedType.getRank(); return 0; }] @@ -359,7 +359,7 @@ /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); if (auto shapedType = - opOperand->get().getType().template dyn_cast()) + opOperand->get().getType().template dyn_cast()) return shapedType.getShape(); return {}; }] @@ -544,7 +544,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::any_of(getStaticShape(), ShapedType::isDynamic); + return llvm::any_of(getStaticShape(), RankedShapedType::isDynamic); }] >, InterfaceMethod< diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -234,7 +234,7 @@ //===----------------------------------------------------------------------===// def 
TensorOrMemref : - AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; + AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::RankedShapedType">; def MapOp : LinalgStructuredBase_Op<"map", [ DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1828,8 +1828,8 @@ The representation based on offsets, sizes and strides support a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. A special - sentinel value ShapedType::kDynamic encodes that the corresponding entry has - a dynamic value. + sentinel value RankedShapedType::kDynamic encodes that the corresponding + entry has a dynamic value. A subview operation may additionally reduce the rank of the resulting view by removing dimensions that are statically known to be of size 1. diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -35,7 +35,7 @@ /// Alias type for extent tensors. RankedTensorType getExtentTensorType(MLIRContext *ctx, - int64_t rank = ShapedType::kDynamic); + int64_t rank = RankedShapedType::kDynamic); // Check if a type is an extent tensor, e.g., tensor. 
bool isExtentTensorType(Type); diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -54,7 +54,7 @@ assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch"); } - SparseTensorType(ShapedType stp, SparseTensorEncodingAttr enc) + SparseTensorType(RankedShapedType stp, SparseTensorEncodingAttr enc) : SparseTensorType( RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {} @@ -181,7 +181,7 @@ ArrayRef getDimShape() const { return rtp.getShape(); } /// Safely looks up the requested dimension-DynSize. If you intend - /// to check the result with `ShapedType::isDynamic`, then see the + /// to check the result with `RankedShapedType::isDynamic`, then see the /// `getStaticDimSize` method instead. DynSize getDynamicDimSize(Dimension d) const { assert(d < getDimRank() && "Dimension is out of bounds"); @@ -192,8 +192,8 @@ /// sizes to `std::nullopt`. std::optional getStaticDimSize(Dimension d) const { const DynSize sh = getDynamicDimSize(d); - return ShapedType::isDynamic(sh) ? std::nullopt - : std::optional(sh); + return RankedShapedType::isDynamic(sh) ? std::nullopt + : std::optional(sh); } /// Returns true if no dimension has dynamic size. @@ -208,7 +208,7 @@ bool isDynamicDim(Dimension d) const { // We don't use `rtp.isDynamicDim(d)` because we want the // OOB error message to be consistent with `getDynamicDimSize`. - return ShapedType::isDynamic(getDynamicDimSize(d)); + return RankedShapedType::isDynamic(getDynamicDimSize(d)); } /// Returns the number of dimensions which have dynamic sizes. 
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -281,8 +281,8 @@ The representation based on offsets, sizes and strides support a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. A special - sentinel value ShapedType::kDynamic encodes that the corresponding entry has - a dynamic value. + sentinel value RankedShapedType::kDynamic encodes that the corresponding + entry has a dynamic value. After buffer allocation, the "extract_slice" op is expected to lower into a memref.subview op. @@ -769,8 +769,8 @@ The representation based on offsets, sizes and strides support a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. A special - sentinel value ShapedType::kDynamic encodes that the corresponding entry has - a dynamic value. + sentinel value RankedShapedType::kDynamic encodes that the corresponding + entry has a dynamic value. After buffer allocation, the "insert_slice" op is expected to lower into a memref.subview op. @@ -1266,7 +1266,7 @@ unsigned numDynamic = 0; unsigned count = staticAttrs.size(); for (unsigned idx = 0; idx < count; ++idx) { - if (ShapedType::isDynamic(staticAttrs[idx])) + if (RankedShapedType::isDynamic(staticAttrs[idx])) res.push_back(values[numDynamic++]); else res.push_back(builder.getI64IntegerAttr(staticAttrs[idx])); @@ -1378,8 +1378,8 @@ The representation based on offsets, sizes and strides support a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. A special - sentinel value ShapedType::kDynamic encodes that the corresponding entry has - a dynamic value. 
+ sentinel value RankedShapedType::kDynamic encodes that the corresponding + entry has a dynamic value. After buffer allocation, the "parallel_insert_slice" op is expected to lower into a memref.subview op. diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -48,10 +48,10 @@ std::optional> checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef params) { - SmallVector dynTypes; + SmallVector dynTypes; SmallVector dynamicDims; for (const Value ¶m : params) { - auto paramTy = param.getType().cast(); + auto paramTy = param.getType().cast(); if (!paramTy.hasStaticShape()) dynTypes.push_back(paramTy); } @@ -59,8 +59,9 @@ if (dynTypes.empty()) return dynamicDims; - for (const ShapedType &dynTy : dynTypes) { - if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) { + for (const RankedShapedType &dynTy : dynTypes) { + if (llvm::any_of(dynTy.getShape().drop_front(), + RankedShapedType::isDynamic)) { (void)rewriter.notifyMatchFailure( op, "input can only be dynamic for batch size"); return std::nullopt; diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h @@ -45,10 +45,10 @@ static ValueKnowledge getKnowledgeFromType(Type type) { ValueKnowledge result = getPessimisticValueState(); if (auto shapedType = type.dyn_cast()) { - if (shapedType.hasRank()) { + if (auto rankedShapedType = dyn_cast(shapedType)) { result.hasRank = true; - result.sizes.reserve(shapedType.getRank()); - for (auto dim : shapedType.getShape()) + result.sizes.reserve(rankedShapedType.getRank()); + for (auto dim : rankedShapedType.getShape()) result.sizes.push_back(dim); } result.dtype = 
shapedType.getElementType(); @@ -111,14 +111,14 @@ return result; result.hasRank = true; - result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamic); + result.sizes.resize(lhs.sizes.size(), RankedShapedType::kDynamic); for (auto i : llvm::seq(0, result.sizes.size())) { int64_t lhsSize = lhs.sizes[i]; int64_t rhsSize = rhs.sizes[i]; int64_t &resultSize = result.sizes[i]; - if (lhsSize == ShapedType::kDynamic) { + if (lhsSize == RankedShapedType::kDynamic) { resultSize = rhsSize; - } else if (rhsSize == ShapedType::kDynamic) { + } else if (rhsSize == RankedShapedType::kDynamic) { resultSize = lhsSize; } else if (lhsSize == rhsSize) { resultSize = lhsSize; @@ -155,7 +155,7 @@ } result.hasRank = true; - result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamic); + result.sizes.resize(lhs.sizes.size(), RankedShapedType::kDynamic); for (int i = 0, e = lhs.sizes.size(); i < e; i++) { if (lhs.sizes[i] == rhs.sizes[i]) { result.sizes[i] = lhs.sizes[i]; diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -68,7 +68,8 @@ /// the target type when possible. Return std::nullopt when this computation /// failed. std::optional> -getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType); +getReassociationIndicesForReshape(RankedShapedType sourceType, + RankedShapedType targetType); /// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if /// possible. 
@@ -155,8 +156,9 @@ ArrayRef reassociationMaps, bool isExpandingReshape); template -static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType, - ShapedType expandedType, +static LogicalResult verifyReshapeLikeShapes(OpTy op, + RankedShapedType collapsedType, + RankedShapedType expandedType, bool isExpandingReshape) { return reshapeLikeShapesAreCompatible( [&](const Twine &msg) { return op->emitOpError(msg); }, @@ -235,8 +237,8 @@ if (!expandOp) return failure(); - ShapedType srcType = expandOp.getSrcType(); - ShapedType resultType = collapseOp.getResultType(); + RankedShapedType srcType = expandOp.getSrcType(); + RankedShapedType resultType = collapseOp.getResultType(); if (hasNonIdentityLayout(collapseOp.getSrc().getType()) || hasNonIdentityLayout(expandOp.getSrc().getType()) || @@ -301,8 +303,8 @@ if (!collapseOp) return failure(); - ShapedType srcType = collapseOp.getSrcType(); - ShapedType resultType = expandOp.getResultType(); + RankedShapedType srcType = collapseOp.getSrcType(); + RankedShapedType resultType = expandOp.getResultType(); if (hasNonIdentityLayout(expandOp.getSrc().getType()) || hasNonIdentityLayout(collapseOp.getSrc().getType()) || diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -105,8 +105,8 @@ // ArrayRef valueOrAttrVec); /// Return a vector of OpFoldResults with the same size a staticValues, but -/// all elements for which ShapedType::isDynamic is true, will be replaced by -/// dynamicValues. +/// all elements for which RankedShapedType::isDynamic is true, will be replaced +/// by dynamicValues. 
SmallVector getMixedValues(ArrayRef staticValues, ValueRange dynamicValues, Builder &b); diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -134,7 +134,7 @@ /// Build the default minor identity map suitable for a vector transfer. This /// also handles the case memref<... x vector<...>> -> vector<...> in which the /// rank of the identity map must take the vector element type into account. -AffineMap getTransferMinorIdentityMap(ShapedType shapedType, +AffineMap getTransferMinorIdentityMap(RankedShapedType shapedType, VectorType vectorType); /// Return true if the transfer_write fully writes the data accessed by the diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -160,7 +160,7 @@ }]>, InterfaceMethod<[{ Returns the shaped type of the elements attribute. - }], "::mlir::ShapedType", "getType"> + }], "::mlir::RankedShapedType", "getType"> ]; string ElementsAttrInterfaceAccessors = [{ @@ -312,7 +312,7 @@ bool isValidIndex(ArrayRef index) const { return isValidIndex(*this, index); } - static bool isValidIndex(ShapedType type, ArrayRef index); + static bool isValidIndex(RankedShapedType type, ArrayRef index); static bool isValidIndex(ElementsAttr elementsAttr, ArrayRef index); diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -103,7 +103,8 @@ /// Each element attribute value is expected to be an element of 'type'. /// 'type' must be a vector or tensor with static shape. If the element of /// `type` is non-integer/index/float it is assumed to be a string type. 
- static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, + ArrayRef values); /// Constructs a dense integer elements attribute from an array of integer /// or floating-point values. Each value is expected to be the same bitwidth @@ -112,7 +113,8 @@ template ::is_integer || is_valid_cpp_fp_type::value>> - static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { + static DenseElementsAttr get(const RankedShapedType &type, + ArrayRef values) { const char *data = reinterpret_cast(values.data()); return getRawIntOrFloat( type, ArrayRef(data, values.size() * sizeof(T)), sizeof(T), @@ -124,7 +126,7 @@ typename = std::enable_if_t::is_integer || is_valid_cpp_fp_type::value || detail::is_complex_t::value>> - static DenseElementsAttr get(const ShapedType &type, T value) { + static DenseElementsAttr get(const RankedShapedType &type, T value) { return get(type, llvm::ArrayRef(value)); } @@ -136,7 +138,8 @@ typename = std::enable_if_t::value && (std::numeric_limits::is_integer || is_valid_cpp_fp_type::value)>> - static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { + static DenseElementsAttr get(const RankedShapedType &type, + ArrayRef values) { const char *data = reinterpret_cast(values.data()); return getRawComplex(type, ArrayRef(data, values.size() * sizeof(T)), sizeof(T), std::numeric_limits::is_integer, @@ -144,43 +147,44 @@ } /// Overload of the above 'get' method that is specialized for boolean values. - static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, ArrayRef values); /// Overload of the above 'get' method that is specialized for StringRef /// values. - static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, + ArrayRef values); /// Constructs a dense integer elements attribute from an array of APInt /// values. 
Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. - static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, ArrayRef values); /// Constructs a dense complex elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. - static DenseElementsAttr get(ShapedType type, + static DenseElementsAttr get(RankedShapedType type, ArrayRef> values); /// Constructs a dense float elements attribute from an array of APFloat /// values. Each APFloat value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. - static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, ArrayRef values); /// Constructs a dense complex elements attribute from an array of APFloat /// values. Each APFloat value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. - static DenseElementsAttr get(ShapedType type, + static DenseElementsAttr get(RankedShapedType type, ArrayRef> values); /// Construct a dense elements attribute for an initializer_list of values. /// Each value is expected to be the same bitwidth of the element type of /// 'type'. 'type' must be a vector or tensor with static shape. template - static DenseElementsAttr get(const ShapedType &type, + static DenseElementsAttr get(const RankedShapedType &type, const std::initializer_list &list) { return get(type, ArrayRef(list)); } @@ -198,7 +202,7 @@ /// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to /// the linear order of the shape type from MSB to LSB, padded to on the /// right. 
- static DenseElementsAttr getFromRawBuffer(ShapedType type, + static DenseElementsAttr getFromRawBuffer(RankedShapedType type, ArrayRef rawBuffer); /// Returns true if the given buffer is a valid raw buffer for the given type. @@ -210,7 +214,7 @@ /// /// User code should be prepared for additional, conformant patterns to be /// identified as splats in the future. - static bool isValidRawBuffer(ShapedType type, ArrayRef rawBuffer, + static bool isValidRawBuffer(RankedShapedType type, ArrayRef rawBuffer, bool &detectedSplat); //===--------------------------------------------------------------------===// @@ -590,7 +594,7 @@ /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor /// with static shape. - ShapedType getType() const; + RankedShapedType getType() const; /// Return the element type of this DenseElementsAttr. Type getElementType() const; @@ -611,12 +615,12 @@ /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. - DenseElementsAttr reshape(ShapedType newType); + DenseElementsAttr reshape(RankedShapedType newType); /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but with a different shape for a splat type. The new type must /// have the same element type. - DenseElementsAttr resizeSplat(ShapedType newType); + DenseElementsAttr resizeSplat(RankedShapedType newType); /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has bitcast elements to 'newElType'. The new type must have @@ -656,14 +660,15 @@ /// Overload of the raw 'get' method that asserts that the given type is of /// complex type. This method is used to verify type invariants that the /// templatized 'get' method cannot. 
- static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, + static DenseElementsAttr getRawComplex(RankedShapedType type, + ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); /// Overload of the raw 'get' method that asserts that the given type is of /// integer or floating-point type. This method is used to verify type /// invariants that the templatized 'get' method cannot. - static DenseElementsAttr getRawIntOrFloat(ShapedType type, + static DenseElementsAttr getRawIntOrFloat(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); @@ -781,7 +786,7 @@ /// resource, but may be changed if necessary to ensure uniqueness during /// insertion. static DenseResourceElementsAttrBase - get(ShapedType type, StringRef blobName, AsmResourceBlob blob); + get(RankedShapedType type, StringRef blobName, AsmResourceBlob blob); /// Return the data of this attribute as an ArrayRef if it is present, /// returns std::nullopt otherwise. @@ -910,12 +915,12 @@ /// Get an instance of a DenseFPElementsAttr with the given arguments. This /// simply wraps the DenseElementsAttr::get calls. template - static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) { + static DenseFPElementsAttr get(const RankedShapedType &type, Arg &&arg) { return DenseElementsAttr::get(type, llvm::ArrayRef(arg)) .template cast(); } template - static DenseFPElementsAttr get(const ShapedType &type, + static DenseFPElementsAttr get(const RankedShapedType &type, const std::initializer_list &list) { return DenseElementsAttr::get(type, list) .template cast(); @@ -952,12 +957,12 @@ /// Get an instance of a DenseIntElementsAttr with the given arguments. This /// simply wraps the DenseElementsAttr::get calls. 
template - static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) { + static DenseIntElementsAttr get(const RankedShapedType &type, Arg &&arg) { return DenseElementsAttr::get(type, llvm::ArrayRef(arg)) .template cast(); } template - static DenseIntElementsAttr get(const ShapedType &type, + static DenseIntElementsAttr get(const RankedShapedType &type, const std::initializer_list &list) { return DenseElementsAttr::get(type, list) .template cast(); @@ -1034,7 +1039,7 @@ namespace mlir { -/// Given a list of strides (in which ShapedType::kDynamic +/// Given a list of strides (in which RankedShapedType::kDynamic /// represents a dynamic value), return the single result AffineMap which /// represents the linearized strided layout map. Dimensions correspond to the /// offset followed by the strides in order. Symbols are inserted for each diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -246,8 +246,10 @@ dense<[10.0, 11.0]> : tensor<2xf32> ``` }]; - let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, - "ArrayRef":$rawData); + let parameters = (ins + AttributeSelfTypeParameter<"", "RankedShapedType">:$type, + "ArrayRef":$rawData + ); let extraClassDeclaration = [{ using DenseElementsAttr::empty; using DenseElementsAttr::getNumElements; @@ -292,7 +294,7 @@ static void convertEndianOfArrayRefForBEmachine(ArrayRef inRawData, MutableArrayRef outRawData, - ShapedType type); + RankedShapedType type); /// Convert endianess of input for big-endian(BE) machines. The number of /// elements of `inRawData` is `numElements`, and each element has @@ -314,7 +316,7 @@ /// /// If the `values` array only has a single element, then this constructs /// splat of that value. 
- static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, + static DenseElementsAttr getRaw(RankedShapedType type, size_t storageWidth, ArrayRef values); /// Constructs a dense elements attribute from an array of raw APInt values. @@ -323,7 +325,7 @@ /// /// If the `values` array only has a single element, then this constructs /// splat of that value. - static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, + static DenseElementsAttr getRaw(RankedShapedType type, size_t storageWidth, ArrayRef values); /// Get or create a new dense elements attribute instance with the given raw @@ -331,19 +333,19 @@ /// /// If the `values` array only has a single element, then this constructs /// splat of that value. - static DenseElementsAttr getRaw(ShapedType type, ArrayRef data); + static DenseElementsAttr getRaw(RankedShapedType type, ArrayRef data); /// Overload of the raw 'get' method that asserts that the given type is of /// complex type. This method is used to verify type invariants that the /// templatized 'get' method cannot. - static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, + static DenseElementsAttr getRawComplex(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); /// Overload of the raw 'get' method that asserts that the given type is of /// integer or floating-point type. This method is used to verify type /// invariants that the templatized 'get' method cannot. 
- static DenseElementsAttr getRawIntOrFloat(ShapedType type, + static DenseElementsAttr getRawIntOrFloat(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); @@ -386,10 +388,12 @@ dense<["example1", "example2"]> : tensor<2x!foo.string> ``` }]; - let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, - "ArrayRef":$value); + let parameters = (ins + AttributeSelfTypeParameter<"", "RankedShapedType">:$type, + "ArrayRef":$value + ); let builders = [ - AttrBuilderWithInferredContext<(ins "ShapedType":$type, + AttrBuilderWithInferredContext<(ins "RankedShapedType":$type, "ArrayRef":$values), [{ return $_get(type.getContext(), type, values, /* isSplat */(values.size() == 1)); @@ -460,12 +464,12 @@ ``` }]; let parameters = (ins - AttributeSelfTypeParameter<"", "ShapedType">:$type, + AttributeSelfTypeParameter<"", "RankedShapedType">:$type, ResourceHandleParameter<"DenseResourceElementsHandle">:$rawHandle ); let builders = [ AttrBuilderWithInferredContext<(ins - "ShapedType":$type, "DenseResourceElementsHandle":$handle + "RankedShapedType":$type, "DenseResourceElementsHandle":$handle )> ]; let extraClassDeclaration = [{ @@ -476,7 +480,7 @@ /// for the key of the new handle for the `blob` resource, but may be /// changed if necessary to ensure uniqueness during insertion. 
static DenseResourceElementsAttr get( - ShapedType type, StringRef blobName, AsmResourceBlob blob + RankedShapedType type, StringRef blobName, AsmResourceBlob blob ); public: @@ -839,11 +843,13 @@ ``` }]; - let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, - "DenseIntElementsAttr":$indices, - "DenseElementsAttr":$values); + let parameters = (ins + AttributeSelfTypeParameter<"", "RankedShapedType">:$type, + "DenseIntElementsAttr":$indices, + "DenseElementsAttr":$values + ); let builders = [ - AttrBuilderWithInferredContext<(ins "ShapedType":$type, + AttrBuilderWithInferredContext<(ins "RankedShapedType":$type, "DenseElementsAttr":$indices, "DenseElementsAttr":$values), [{ assert(indices.getType().getElementType().isInteger(64) && @@ -986,7 +992,7 @@ Strides must be positive and the offset must be non-negative. Both the strides and the offset may be _dynamic_, i.e. their value may not be known at compile time. This is expressed as a `?` in the assembly syntax and as - `ShapedType::kDynamic` in the code. Stride and offset values + `RankedShapedType::kDynamic` in the code. Stride and offset values must satisfy the constraints above at runtime, the behavior is undefined otherwise. diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -48,14 +48,12 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> { let cppNamespace = "::mlir"; let description = [{ - This interface provides a common API for interacting with multi-dimensional - container types. These types contain a shape and an element type. + This interface provides a common API for interacting with ranked or unranked + container types. These types have an element type and may have a shape. - A shape is a list of sizes corresponding to the dimensions of the container. 
- If the number of dimensions in the shape is unknown, the shape is "unranked". - If the number of dimensions is known, the shape "ranked". The sizes of the - dimensions of the shape must be positive, or kDynamic (in which case the - size of the dimension is dynamic, or not statically known). + If the number of dimensions in the shape is unknown, the shape is + "unranked". If the number of dimensions is known, the shape is "ranked". Ranked + types should implement `RankedShapedType`, a sub-interface of `ShapedType`. }]; let methods = [ InterfaceMethod<[{ @@ -71,108 +69,121 @@ Returns the element type of this shaped type. }], "::mlir::Type", "getElementType">, + ]; + let extraSharedClassDeclaration = [{ + /// Return a clone of this type with the given new shape and element type. + auto clone(::llvm::ArrayRef shape, Type elementType) { + return $_type.cloneWith(shape, elementType); + } + /// Return a clone of this type with the given new shape. + auto clone(::llvm::ArrayRef shape) { + return $_type.cloneWith(shape, $_type.getElementType()); + } + /// Return a clone of this type with the given new element type. + auto clone(::mlir::Type elementType) { + return $_type.cloneWith(/*shape=*/std::nullopt, elementType); + } - InterfaceMethod<[{ - Returns if this type is ranked, i.e. it has a known number of dimensions. - }], - "bool", "hasRank">, + /// If an element type is an integer or a float, return its width. + /// Otherwise, abort.
+ unsigned getElementTypeBitWidth() const { + return $_type.getElementType().getIntOrFloatBitWidth(); + } + }]; +} + +//===----------------------------------------------------------------------===// +// RankedShapedType +//===----------------------------------------------------------------------===// +def RankedShapedTypeInterface + : TypeInterface<"RankedShapedType", [ShapedTypeInterface]> { + let cppNamespace = "::mlir"; + let description = [{ + This interface provides a common API for interacting with multi-dimensional + container types that have a known rank. These types contain a shape and an + element type. + + A shape is a list of sizes corresponding to the dimensions of the container. + The sizes of the dimensions of the shape must be positive, or kDynamic (in + which case the size of the dimension is dynamic). + }]; + let methods = [ InterfaceMethod<[{ - Returns the shape of this type if it is ranked, otherwise asserts. - }], - "::llvm::ArrayRef", "getShape">, + Returns the shape of this type. + }], "::llvm::ArrayRef", "getShape">, ]; let extraClassDeclaration = [{ static constexpr int64_t kDynamic = std::numeric_limits::min(); - /// Whether the given dimension size indicates a dynamic dimension. + /// Return "true" if the given dimension size indicates a dynamic dimension. static constexpr bool isDynamic(int64_t dValue) { - return dValue == kDynamic; + return dValue == kDynamic; } - /// Whether the given shape has any size that indicates a dynamic dimension. + /// Return "true" if the given shape has any size that indicates a dynamic + /// dimension. static bool isDynamicShape(ArrayRef dSizes) { return any_of(dSizes, [](int64_t dSize) { return isDynamic(dSize); }); } - /// Return the number of elements present in the given shape. + /// Return the number of elements present in the given shape. Asserts that + /// all dimensions are static. 
static int64_t getNumElements(ArrayRef shape); }]; let extraSharedClassDeclaration = [{ - /// Return a clone of this type with the given new shape and element type. - auto clone(::llvm::ArrayRef shape, Type elementType) { - return $_type.cloneWith(shape, elementType); - } - /// Return a clone of this type with the given new shape. - auto clone(::llvm::ArrayRef shape) { - return $_type.cloneWith(shape, $_type.getElementType()); - } - /// Return a clone of this type with the given new element type. - auto clone(::mlir::Type elementType) { - return $_type.cloneWith(/*shape=*/std::nullopt, elementType); - } - - /// If an element type is an integer or a float, return its width. Otherwise, - /// abort. - unsigned getElementTypeBitWidth() const { - return $_type.getElementType().getIntOrFloatBitWidth(); - } - - /// If this is a ranked type, return the rank. Otherwise, abort. + /// Return the rank of this type. int64_t getRank() const { - assert($_type.hasRank() && "cannot query rank of unranked shaped type"); return $_type.getShape().size(); } /// If it has static shape, return the number of elements. Otherwise, abort. int64_t getNumElements() const { - assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); - return ::mlir::ShapedType::getNumElements($_type.getShape()); + assert(hasStaticShape() + && "cannot get element count of dynamically shaped type"); + return ::mlir::RankedShapedType::getNumElements($_type.getShape()); } - /// Returns true if this dimension has a dynamic size (for ranked types); - /// aborts for unranked types. + /// Return "true" if this dimension has a dynamic size. bool isDynamicDim(unsigned idx) const { assert(idx < getRank() && "invalid index for shaped type"); - return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]); + return ::mlir::RankedShapedType::isDynamic($_type.getShape()[idx]); } - /// Returns if this type has a static shape, i.e. if the type is ranked and - /// all dimensions have known size (>= 0). 
+ /// Return "true" if this type has a static shape, i.e., if all dimensions + /// have known size (>= 0). bool hasStaticShape() const { - return $_type.hasRank() && - !::mlir::ShapedType::isDynamicShape($_type.getShape()); + return !::mlir::RankedShapedType::isDynamicShape($_type.getShape()); } - /// Returns if this type has a static shape and the shape is equal to - /// `shape` return true. + /// Return "true" if this type has a static shape and the shape is equal to + /// `shape`. bool hasStaticShape(::llvm::ArrayRef shape) const { return hasStaticShape() && $_type.getShape() == shape; } - /// If this is a ranked type, return the number of dimensions with dynamic - /// size. Otherwise, abort. + /// Return the number of dimensions with dynamic size. int64_t getNumDynamicDims() const { - return llvm::count_if($_type.getShape(), ::mlir::ShapedType::isDynamic); + return llvm::count_if($_type.getShape(), ::mlir::RankedShapedType::isDynamic); } - /// If this is ranked type, return the size of the specified dimension. - /// Otherwise, abort. + /// Return the size of the specified dimension. int64_t getDimSize(unsigned idx) const { assert(idx < getRank() && "invalid index for shaped type"); return $_type.getShape()[idx]; } - /// Returns the position of the dynamic dimension relative to just the dynamic - /// dimensions, given its `index` within the shape. + /// Return the position of the dynamic dimension relative to just the + /// dynamic dimensions, given its `index` within the shape. 
unsigned getDynamicDimIndex(unsigned index) const { assert(index < getRank() && "invalid index"); - assert(::mlir::ShapedType::isDynamic(getDimSize(index)) && "invalid index"); + assert(::mlir::RankedShapedType::isDynamic(getDimSize(index)) + && "invalid index"); return llvm::count_if($_type.getShape().take_front(index), - ::mlir::ShapedType::isDynamic); + ::mlir::RankedShapedType::isDynamic); } }]; } diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -87,9 +87,6 @@ /// Returns if this type is ranked, i.e. it has a known number of dimensions. bool hasRank() const; - /// Returns the shape of this tensor type. - ArrayRef getShape() const; - /// Clone this type with the given shape and element type. If the /// provided shape is `None`, the current shape of the type is used. TensorType cloneWith(std::optional> shape, @@ -123,9 +120,6 @@ /// Returns if this type is ranked, i.e. it has a known number of dimensions. bool hasRank() const; - /// Returns the shape of this memref type. - ArrayRef getShape() const; - /// Clone this type with the given shape and element type. If the /// provided shape is `None`, the current shape of the type is used. 
BaseMemRefType cloneWith(std::optional> shape, diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -376,9 +376,8 @@ // MemRefType //===----------------------------------------------------------------------===// -def Builtin_MemRef : Builtin_Type<"MemRef", [ - ShapedTypeInterface - ], "BaseMemRefType"> { +def Builtin_MemRef + : Builtin_Type<"MemRef", [RankedShapedTypeInterface], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; let description = [{ Syntax: @@ -629,15 +628,15 @@ "unsigned":$memorySpaceInd)> ]; let extraClassDeclaration = [{ - using ShapedType::Trait::clone; - using ShapedType::Trait::getElementTypeBitWidth; - using ShapedType::Trait::getRank; - using ShapedType::Trait::getNumElements; - using ShapedType::Trait::isDynamicDim; - using ShapedType::Trait::hasStaticShape; - using ShapedType::Trait::getNumDynamicDims; - using ShapedType::Trait::getDimSize; - using ShapedType::Trait::getDynamicDimIndex; + using ShapedType::Trait::clone; + using ShapedType::Trait::getElementTypeBitWidth; + using RankedShapedType::Trait::getRank; + using RankedShapedType::Trait::getNumElements; + using RankedShapedType::Trait::isDynamicDim; + using RankedShapedType::Trait::hasStaticShape; + using RankedShapedType::Trait::getNumDynamicDims; + using RankedShapedType::Trait::getDimSize; + using RankedShapedType::Trait::getDynamicDimIndex; /// This is a builder type that keeps local references to arguments. /// Arguments that are passed into the builder must outlive the builder. 
@@ -711,9 +710,8 @@ // RankedTensorType //===----------------------------------------------------------------------===// -def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [ - ShapedTypeInterface - ], "TensorType"> { +def Builtin_RankedTensor + : Builtin_Type<"RankedTensor", [RankedShapedTypeInterface], "TensorType"> { let summary = "Multi-dimensional array with a fixed number of dimensions"; let description = [{ Syntax: @@ -794,15 +792,15 @@ }]> ]; let extraClassDeclaration = [{ - using ShapedType::Trait::clone; - using ShapedType::Trait::getElementTypeBitWidth; - using ShapedType::Trait::getRank; - using ShapedType::Trait::getNumElements; - using ShapedType::Trait::isDynamicDim; - using ShapedType::Trait::hasStaticShape; - using ShapedType::Trait::getNumDynamicDims; - using ShapedType::Trait::getDimSize; - using ShapedType::Trait::getDynamicDimIndex; + using ShapedType::Trait::clone; + using ShapedType::Trait::getElementTypeBitWidth; + using RankedShapedType::Trait::getRank; + using RankedShapedType::Trait::getNumElements; + using RankedShapedType::Trait::isDynamicDim; + using RankedShapedType::Trait::hasStaticShape; + using RankedShapedType::Trait::getNumDynamicDims; + using RankedShapedType::Trait::getDimSize; + using RankedShapedType::Trait::getDynamicDimIndex; /// This is a builder type that keeps local references to arguments. /// Arguments that are passed into the builder must outlive the builder. 
@@ -883,9 +881,8 @@ // UnrankedMemRefType //===----------------------------------------------------------------------===// -def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [ - ShapedTypeInterface - ], "BaseMemRefType"> { +def Builtin_UnrankedMemRef + : Builtin_Type<"UnrankedMemRef", [ShapedTypeInterface], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; let description = [{ Syntax: @@ -933,15 +930,6 @@ let extraClassDeclaration = [{ using ShapedType::Trait::clone; using ShapedType::Trait::getElementTypeBitWidth; - using ShapedType::Trait::getRank; - using ShapedType::Trait::getNumElements; - using ShapedType::Trait::isDynamicDim; - using ShapedType::Trait::hasStaticShape; - using ShapedType::Trait::getNumDynamicDims; - using ShapedType::Trait::getDimSize; - using ShapedType::Trait::getDynamicDimIndex; - - ArrayRef getShape() const { return std::nullopt; } /// [deprecated] Returns the memory space in old raw integer representation. /// New `Attribute getMemorySpace()` method should be used instead. 
@@ -955,9 +943,8 @@ // UnrankedTensorType //===----------------------------------------------------------------------===// -def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [ - ShapedTypeInterface - ], "TensorType"> { +def Builtin_UnrankedTensor + : Builtin_Type<"UnrankedTensor", [ShapedTypeInterface], "TensorType"> { let summary = "Multi-dimensional array with unknown dimensions"; let description = [{ Syntax: @@ -986,15 +973,6 @@ let extraClassDeclaration = [{ using ShapedType::Trait::clone; using ShapedType::Trait::getElementTypeBitWidth; - using ShapedType::Trait::getRank; - using ShapedType::Trait::getNumElements; - using ShapedType::Trait::isDynamicDim; - using ShapedType::Trait::hasStaticShape; - using ShapedType::Trait::getNumDynamicDims; - using ShapedType::Trait::getDimSize; - using ShapedType::Trait::getDynamicDimIndex; - - ArrayRef getShape() const { return std::nullopt; } }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; @@ -1004,7 +982,8 @@ // VectorType //===----------------------------------------------------------------------===// -def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> { +def Builtin_Vector + : Builtin_Type<"Vector", [RankedShapedTypeInterface], "Type"> { let summary = "Multi-dimensional SIMD vector type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -253,7 +253,8 @@ // For a ShapedType, verify that it has a static shape. def HasStaticShapePred : - CPred<"$_self.cast<::mlir::ShapedType>().hasStaticShape()">; + And<[CPred<"$_self.isa<::mlir::RankedShapedType>()">, + CPred<"$_self.cast<::mlir::RankedShapedType>().hasStaticShape()">]>; // Whether a type is a TupleType. def IsTupleTypePred : CPred<"$_self.isa<::mlir::TupleType>()">; @@ -548,20 +549,20 @@ descr # " of " # AnyTypeOf.summary # " values", cppClassName>; // Whether a shaped type is ranked. 
-def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">; +def HasRankPred : CPred<"$_self.isa<::mlir::RankedShapedType>()">; // Whether a shaped type has one of the specified ranks. class HasAnyRankOfPred ranks> : And<[ HasRankPred, Or().getRank() + CPred<[{$_self.cast<::mlir::RankedShapedType>().getRank() == }] # rank>)>]>; // Whether a shaped type has a rank greater than or equal of the specified rank. class HasRankGreaterOrEqualPred : And<[ HasRankPred, - CPred<[{$_self.cast<::mlir::ShapedType>().getRank() >= }] # rank> + CPred<[{$_self.cast<::mlir::RankedShapedType>().getRank() >= }] # rank> ]>; // Vector types. @@ -1464,8 +1465,7 @@ CPred<"$_self.isa<::mlir::DenseFPElementsAttr>() &&" "$_self.cast<::mlir::DenseFPElementsAttr>().getType()." "getElementType().isF" # width # "() && " - // Check that this is ranked and has the specified shape. - "$_self.cast<::mlir::DenseFPElementsAttr>().getType().hasRank() && " + // Check that this has the specified shape. "$_self.cast<::mlir::DenseFPElementsAttr>().getType().getShape() == " "::mlir::ArrayRef({" # !interleave(dims, ", ") # "})">, width # "-bit float elements attribute of shape [" # @@ -2432,13 +2432,15 @@ // TODO: Improve the autogenerated error messages. 
class Rank : - StrFunc<"$" # name # ".getType().cast<::mlir::ShapedType>().getRank()">; + StrFunc<"$" # name # ".getType().cast<::mlir::RankedShapedType>()" + ".getRank()">; class Shape : - StrFunc<"$" # name # ".getType().cast<::mlir::ShapedType>().getShape()">; + StrFunc<"$" # name # ".getType().cast<::mlir::RankedShapedType>()" + ".getShape()">; class ElementCount : - StrFunc<"$" # name # ".getType().cast<::mlir::ShapedType>()" + StrFunc<"$" # name # ".getType().cast<::mlir::RankedShapedType>()" ".getNumElements()">; class ElementType : StrFunc<"getElementTypeOrSelf($" # name # ")">; diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -70,7 +70,7 @@ /// Returns whether the index'th dimension is dynamic. /// Requires: shape is ranked. bool isDynamicDim(int index) const { - return ShapedType::isDynamic(getDimSize(index)); + return RankedShapedType::isDynamic(getDimSize(index)); } /// Returns whether the shape is fully static. 
@@ -114,10 +114,10 @@ ShapedTypeComponents(Type elementType) : elementType(elementType), attr(nullptr), ranked(false) {} ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) { - ranked = shapedType.hasRank(); + ranked = isa(shapedType); elementType = shapedType.getElementType(); if (ranked) - dims = llvm::to_vector<4>(shapedType.getShape()); + dims = llvm::to_vector<4>(cast(shapedType).getShape()); } ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) { ranked = adaptor.hasRank(); diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -150,13 +150,13 @@ /*defaultImplementation=*/ >, InterfaceMethod< - /*desc=*/"Return the ShapedType.", - /*retTy=*/"::mlir::ShapedType", + /*desc=*/"Return the RankedShapedType.", + /*retTy=*/"::mlir::RankedShapedType", /*methodName=*/"getShapedType", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/ - "return $_op.getSource().getType().template cast<::mlir::ShapedType>();" + "return $_op.getSource().getType().template cast<::mlir::RankedShapedType>();" >, InterfaceMethod< /*desc=*/"Return the VectorType.", diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -50,7 +50,7 @@ `getArrayAttrMaxRanks()`[0] (resp. [1], [2]). 3. if an entry of `static_offsets` (resp. `static_sizes`, `static_strides`) is equal to a special sentinel value, namely - `ShapedType::kDynamic`, then the corresponding entry is a dynamic + `RankedShapedType::kDynamic`, then the corresponding entry is a dynamic offset (resp. size, stride). 4. a variadic `offset` (resp. `sizes`, `strides`) operand must be present for each dynamic offset (resp. size, stride). 
@@ -204,7 +204,7 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::ShapedType::isDynamic(static_offsets()[idx]); + return ::mlir::RankedShapedType::isDynamic(static_offsets()[idx]); }] >, InterfaceMethod< @@ -214,7 +214,7 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::ShapedType::isDynamic(static_sizes()[idx]); + return ::mlir::RankedShapedType::isDynamic(static_sizes()[idx]); }] >, InterfaceMethod< @@ -224,7 +224,7 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::ShapedType::isDynamic(static_strides()[idx]); + return ::mlir::RankedShapedType::isDynamic(static_strides()[idx]); }] >, InterfaceMethod< @@ -280,7 +280,7 @@ assert($_op.isDynamicOffset(idx) && "expected dynamic offset"); auto numDynamic = getNumDynamicEntriesUpToIdx( static_offsets(), - ::mlir::ShapedType::isDynamic, + ::mlir::RankedShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + numDynamic; }] @@ -297,7 +297,7 @@ /*defaultImplementation=*/[{ assert($_op.isDynamicSize(idx) && "expected dynamic size"); auto numDynamic = getNumDynamicEntriesUpToIdx( - static_sizes(), ::mlir::ShapedType::isDynamic, idx); + static_sizes(), ::mlir::RankedShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + offsets().size() + numDynamic; }] @@ -315,7 +315,7 @@ assert($_op.isDynamicStride(idx) && "expected dynamic stride"); auto numDynamic = getNumDynamicEntriesUpToIdx( static_strides(), - ::mlir::ShapedType::isDynamic, + ::mlir::RankedShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + offsets().size() + sizes().size() + numDynamic; diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -472,7 +472,7 @@ /// Build a dense attribute instance with 
the parsed elements and the given /// shaped type. - DenseElementsAttr getAttr(SMLoc loc, ShapedType type); + DenseElementsAttr getAttr(SMLoc loc, RankedShapedType type); ArrayRef getShape() const { return shape; } @@ -489,7 +489,7 @@ DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); /// Build a Dense attribute with hex data for the given type. - DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); + DenseElementsAttr getHexAttr(SMLoc loc, RankedShapedType type); /// Parse a single element, returning failure if it isn't a valid element /// literal. For example: @@ -538,7 +538,8 @@ /// Build a dense attribute instance with the parsed elements and the given /// shaped type. -DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { +DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, + RankedShapedType type) { Type eltType = type.getElementType(); // Check to see if we parse the literal from a hex string. @@ -709,7 +710,8 @@ } /// Build a Dense attribute with hex data for the given type. -DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) { +DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, + RankedShapedType type) { Type elementType = type.getElementType(); if (!elementType.isIntOrIndexOrFloat() && !elementType.isa()) { p.emitError(loc) @@ -1039,7 +1041,7 @@ /// elements-literal-type ::= vector-type | ranked-tensor-type /// /// This method also checks the type has static shape. -ShapedType Parser::parseElementsLiteralType(Type type) { +RankedShapedType Parser::parseElementsLiteralType(Type type) { // If the user didn't provide a type, parse the colon type for the literal. 
if (!type) { if (parseToken(Token::colon, "expected ':'")) @@ -1048,15 +1050,13 @@ return nullptr; } - auto sType = type.dyn_cast(); - if (!sType) { - emitError("elements literal must be a shaped type"); + auto sType = type.dyn_cast(); + if (!sType || !sType.hasStaticShape()) { + emitError( + "elements literal type must be a ranked shaped type with static shape"); return nullptr; } - if (!sType.hasStaticShape()) - return (emitError("elements literal type must have static shape"), nullptr); - return sType; } @@ -1072,7 +1072,7 @@ // of the type. Type indiceEltType = builder.getIntegerType(64); if (consumeIf(Token::greater)) { - ShapedType type = parseElementsLiteralType(attrType); + RankedShapedType type = parseElementsLiteralType(attrType); if (!type) return nullptr; @@ -1114,7 +1114,7 @@ // 2-dimensional shape where the second dimension is the rank of the type. // Given that the parsed indices is a splat, we know that we only have one // indice and thus one for the first dimension. - ShapedType indicesType; + RankedShapedType indicesType; if (indiceParser.getShape().empty()) { indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); } else { @@ -1152,7 +1152,7 @@ // must fit into int64_t limits. auto parseStrideOrOffset = [&]() -> std::optional { if (consumeIf(Token::question)) - return ShapedType::kDynamic; + return RankedShapedType::kDynamic; SMLoc loc = getToken().getLoc(); auto emitWrongTokenError = [&] { diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -262,7 +262,7 @@ /// Parse a dense elements attribute. Attribute parseDenseElementsAttr(Type attrType); - ShapedType parseElementsLiteralType(Type type); + RankedShapedType parseElementsLiteralType(Type type); /// Parse a dense resource elements attribute. 
Attribute parseDenseResourceElementsAttr(Type attrType); diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -528,7 +528,7 @@ if (consumeIf(Token::question)) { if (!allowDynamic) return emitError(loc, "expected static shape"); - dimensions.push_back(ShapedType::kDynamic); + dimensions.push_back(RankedShapedType::kDynamic); } else { int64_t value; if (failed(parseIntegerInDimensionList(value))) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -669,10 +669,10 @@ message.append(py::repr(py::cast(elementAttr))); throw SetPyError(PyExc_ValueError, message); } - if (!mlirTypeIsAShaped(shapedType) || - !mlirShapedTypeHasStaticShape(shapedType)) { + if (!mlirTypeIsARankedShaped(shapedType) || + !mlirRankedShapedTypeHasStaticShape(shapedType)) { std::string message = - "Expected a static ShapedType for the shaped_type parameter: "; + "Expected a static RankedShapedType for the shaped_type parameter: "; message.append(py::repr(py::cast(shapedType))); throw SetPyError(PyExc_ValueError, message); } @@ -810,7 +810,7 @@ template py::buffer_info bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { - intptr_t rank = mlirShapedTypeGetRank(shapedType); + intptr_t rank = mlirRankedShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. // Buffer is configured for read-only access below. Type *data = static_cast( @@ -818,7 +818,7 @@ // Prepare the shape for the buffer_info. SmallVector shape; for (intptr_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); + shape.push_back(mlirRankedShapedTypeGetDimSize(shapedType, i)); // Prepare the strides for the buffer_info. 
SmallVector strides; if (mlirDenseElementsAttrIsSplat(*this)) { @@ -828,7 +828,7 @@ for (intptr_t i = 1; i < rank; ++i) { intptr_t strideFactor = 1; for (intptr_t j = i; j < rank; ++j) - strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + strideFactor *= mlirRankedShapedTypeGetDimSize(shapedType, j); strides.push_back(sizeof(Type) * strideFactor); } strides.push_back(sizeof(Type)); @@ -1071,7 +1071,7 @@ c.def_static( "get_fully_dynamic", [](int64_t rank, DefaultingPyMlirContext ctx) { - auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); + auto dynamic = mlirRankedShapedTypeGetDynamicStrideOrOffset(); std::vector strides(rank); std::fill(strides.begin(), strides.end(), dynamic); MlirAttribute attr = mlirStridedLayoutAttrGet( diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -323,6 +323,12 @@ using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { + c.def_property_readonly( + "has_rank", + [](PyShapedType &self) -> bool { + return mlirTypeIsARankedShaped(self); + }, + "Returns whether the given shaped type is ranked."); c.def_property_readonly( "element_type", [](PyShapedType &self) { @@ -330,91 +336,84 @@ return PyType(self.getContext(), t); }, "Returns the element type of the shaped type."); - c.def_property_readonly( - "has_rank", - [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, - "Returns whether the given shaped type is ranked."); + } +}; + +class PyRankedShapedType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedShaped; + static constexpr const char *pyClassName = "RankedShapedType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { c.def_property_readonly( "rank", - [](PyShapedType &self) { - self.requireHasRank(); - return mlirShapedTypeGetRank(self); - }, + [](PyShapedType &self) { return 
mlirRankedShapedTypeGetRank(self); }, "Returns the rank of the given ranked shaped type."); c.def_property_readonly( "has_static_shape", [](PyShapedType &self) -> bool { - return mlirShapedTypeHasStaticShape(self); + return mlirRankedShapedTypeHasStaticShape(self); }, - "Returns whether the given shaped type has a static shape."); + "Returns whether the given ranked shaped type has a static shape."); c.def( "is_dynamic_dim", [](PyShapedType &self, intptr_t dim) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicDim(self, dim); + return mlirRankedShapedTypeIsDynamicDim(self, dim); }, py::arg("dim"), - "Returns whether the dim-th dimension of the given shaped type is " - "dynamic."); + "Returns whether the dim-th dimension of the given ranked shaped type " + "is dynamic."); c.def( "get_dim_size", [](PyShapedType &self, intptr_t dim) { - self.requireHasRank(); - return mlirShapedTypeGetDimSize(self, dim); + return mlirRankedShapedTypeGetDimSize(self, dim); }, py::arg("dim"), "Returns the dim-th dimension of the given ranked shaped type."); c.def_static( "is_dynamic_size", - [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + [](int64_t size) -> bool { + return mlirRankedShapedTypeIsDynamicSize(size); + }, py::arg("dim_size"), "Returns whether the given dimension size indicates a dynamic " "dimension."); c.def( "is_dynamic_stride_or_offset", [](PyShapedType &self, int64_t val) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicStrideOrOffset(val); + return mlirRankedShapedTypeIsDynamicStrideOrOffset(val); }, py::arg("dim_size"), "Returns whether the given value is used as a placeholder for dynamic " - "strides and offsets in shaped types."); + "strides and offsets in ranked shaped types."); c.def_property_readonly( "shape", [](PyShapedType &self) { - self.requireHasRank(); - std::vector shape; - int64_t rank = mlirShapedTypeGetRank(self); + int64_t rank = mlirRankedShapedTypeGetRank(self); shape.reserve(rank); for 
(int64_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(self, i)); + shape.push_back(mlirRankedShapedTypeGetDimSize(self, i)); return shape; }, "Returns the shape of the ranked shaped type as a list of integers."); c.def_static( - "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, - "Returns the value used to indicate dynamic dimensions in shaped " - "types."); + "get_dynamic_size", + []() { return mlirRankedShapedTypeGetDynamicSize(); }, + "Returns the value used to indicate dynamic dimensions in ranked " + "shaped types."); c.def_static( "get_dynamic_stride_or_offset", - []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, + []() { return mlirRankedShapedTypeGetDynamicStrideOrOffset(); }, "Returns the value used to indicate dynamic strides or offsets in " - "shaped types."); - } - -private: - void requireHasRank() { - if (!mlirShapedTypeHasRank(*this)) { - throw SetPyError( - PyExc_ValueError, - "calling this method requires that the type has a rank."); - } + "ranked shaped types."); } }; /// Vector Type subclass - VectorType. -class PyVectorType : public PyConcreteType { +class PyVectorType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; static constexpr const char *pyClassName = "VectorType"; @@ -439,7 +438,7 @@ /// Ranked Tensor Type subclass - RankedTensorType. class PyRankedTensorType - : public PyConcreteType { + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; static constexpr const char *pyClassName = "RankedTensorType"; @@ -495,7 +494,7 @@ }; /// Ranked MemRef Type subclass - MemRefType. 
-class PyMemRefType : public PyConcreteType { +class PyMemRefType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; static constexpr const char *pyClassName = "MemRefType"; @@ -726,6 +725,7 @@ PyNoneType::bind(m); PyComplexType::bind(m); PyShapedType::bind(m); + PyRankedShapedType::bind(m); PyVectorType::bind(m); PyRankedTensorType::bind(m); PyUnrankedTensorType::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -164,43 +164,46 @@ bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsARankedShaped(MlirType type) { + return unwrap(type).isa(); +} + MlirType mlirShapedTypeGetElementType(MlirType type) { return wrap(unwrap(type).cast().getElementType()); } -bool mlirShapedTypeHasRank(MlirType type) { - return unwrap(type).cast().hasRank(); +int64_t mlirRankedShapedTypeGetRank(MlirType type) { + return unwrap(type).cast().getRank(); } -int64_t mlirShapedTypeGetRank(MlirType type) { - return unwrap(type).cast().getRank(); +bool mlirRankedShapedTypeHasStaticShape(MlirType type) { + return unwrap(type).cast().hasStaticShape(); } -bool mlirShapedTypeHasStaticShape(MlirType type) { - return unwrap(type).cast().hasStaticShape(); +bool mlirRankedShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { + return unwrap(type).cast().isDynamicDim( + static_cast(dim)); } -bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { - return unwrap(type).cast().isDynamicDim( +int64_t mlirRankedShapedTypeGetDimSize(MlirType type, intptr_t dim) { + return unwrap(type).cast().getDimSize( static_cast(dim)); } -int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { - return unwrap(type).cast().getDimSize(static_cast(dim)); +int64_t mlirRankedShapedTypeGetDynamicSize() { + return RankedShapedType::kDynamic; } -int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; 
} - -bool mlirShapedTypeIsDynamicSize(int64_t size) { - return ShapedType::isDynamic(size); +bool mlirRankedShapedTypeIsDynamicSize(int64_t size) { + return RankedShapedType::isDynamic(size); } -bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { - return ShapedType::isDynamic(val); +bool mlirRankedShapedTypeIsDynamicStrideOrOffset(int64_t val) { + return RankedShapedType::isDynamic(val); } -int64_t mlirShapedTypeGetDynamicStrideOrOffset() { - return ShapedType::kDynamic; +int64_t mlirRankedShapedTypeGetDynamicStrideOrOffset() { + return RankedShapedType::kDynamic; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -204,7 +204,7 @@ size_t i = pair.index(); Value index = pair.value(); Value strideOp; - if (ShapedType::isDynamic(strides[i])) { + if (RankedShapedType::isDynamic(strides[i])) { strideOp = rewriter.create( loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst); } else { @@ -226,7 +226,7 @@ Value sgprOffset = adaptor.getSgprOffset(); if (!sgprOffset) sgprOffset = createI32Constant(rewriter, loc, 0); - if (ShapedType::isDynamic(offset)) + if (RankedShapedType::isDynamic(offset)) sgprOffset = rewriter.create( loc, memrefDescriptor.offset(rewriter, loc), sgprOffset); else if (offset > 0) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -369,7 +369,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite( arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto srcType = constOp.getType().dyn_cast(); + auto srcType = constOp.getType().dyn_cast(); if 
(!srcType || srcType.getNumElements() == 1) return failure(); @@ -435,10 +435,13 @@ // attributes; element attributes only works with builtin types. So we need // to prepare another converted builtin types for the destination elements // attribute. - if (dstAttrType.isa()) - dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); - else - dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); + if (auto rankedTensoType = dstAttrType.dyn_cast()) { + dstAttrType = + RankedTensorType::get(rankedTensoType.getShape(), dstElemType); + } else { + dstAttrType = VectorType::get( + dstAttrType.cast().getShape(), dstElemType); + } dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); } @@ -456,7 +459,7 @@ arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type srcType = constOp.getType(); - if (auto shapedType = srcType.dyn_cast()) { + if (auto shapedType = srcType.dyn_cast()) { if (shapedType.getNumElements() != 1) return failure(); srcType = shapedType.getElementType(); diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -54,9 +54,8 @@ // Extract all strides and offsets and verify they are static. 
auto [strides, offset] = getStridesAndOffset(type); - assert(!ShapedType::isDynamic(offset) && - "expected static offset"); - assert(!llvm::any_of(strides, ShapedType::isDynamic) && + assert(!RankedShapedType::isDynamic(offset) && "expected static offset"); + assert(!llvm::any_of(strides, RankedShapedType::isDynamic) && "expected static strides"); auto convertedType = typeConverter.convertType(type); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -76,14 +76,14 @@ Value index; if (offset != 0) // Skip if offset is zero. - index = ShapedType::isDynamic(offset) + index = RankedShapedType::isDynamic(offset) ? memRefDescriptor.offset(rewriter, loc) : createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { Value increment = indices[i]; if (strides[i] != 1) { // Skip if stride is 1. - Value stride = ShapedType::isDynamic(strides[i]) + Value stride = RankedShapedType::isDynamic(strides[i]) ? memRefDescriptor.stride(rewriter, loc, i) : createIndexConstant(rewriter, loc, strides[i]); increment = rewriter.create(loc, increment, stride); @@ -124,14 +124,14 @@ SmallVectorImpl &strides, Value &sizeBytes) const { assert(isConvertibleAndHasIdentityMaps(memRefType) && "layout maps must have been normalized away"); - assert(count(memRefType.getShape(), ShapedType::kDynamic) == + assert(count(memRefType.getShape(), RankedShapedType::kDynamic) == static_cast(dynamicSizes.size()) && "dynamicSizes size doesn't match dynamic sizes count in memref shape"); sizes.reserve(memRefType.getRank()); unsigned dynamicIndex = 0; for (int64_t size : memRefType.getShape()) { - sizes.push_back(size == ShapedType::kDynamic + sizes.push_back(size == RankedShapedType::kDynamic ? 
dynamicSizes[dynamicIndex++] : createIndexConstant(rewriter, loc, size)); } @@ -147,14 +147,14 @@ if (size == 0) continue; bool useSizeAsStride = stride == 1; - if (size == ShapedType::kDynamic) - stride = ShapedType::kDynamic; - if (stride != ShapedType::kDynamic) + if (size == RankedShapedType::kDynamic) + stride = RankedShapedType::kDynamic; + if (stride != RankedShapedType::kDynamic) stride *= size; if (useSizeAsStride) runningStride = sizes[i]; - else if (stride == ShapedType::kDynamic) + else if (stride == RankedShapedType::kDynamic) runningStride = rewriter.create(loc, runningStride, sizes[i]); else diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -428,10 +428,10 @@ return false; for (int64_t stride : strides) - if (ShapedType::isDynamic(stride)) + if (RankedShapedType::isDynamic(stride)) return false; - return !ShapedType::isDynamic(offset); + return !RankedShapedType::isDynamic(offset); } /// Convert a memref type to a bare pointer to the memref element type. 
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -27,8 +27,8 @@ static MemRefType makeStridedLayoutDynamic(MemRefType type) { return MemRefType::Builder(type).setLayout(StridedLayoutAttr::get( - type.getContext(), ShapedType::kDynamic, - SmallVector(type.getRank(), ShapedType::kDynamic))); + type.getContext(), RankedShapedType::kDynamic, + SmallVector(type.getRank(), RankedShapedType::kDynamic))); } /// Helper function to extract the operand types that are passed to the diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -36,7 +36,7 @@ namespace { bool isStaticStrideOrOffset(int64_t strideOrOffset) { - return !ShapedType::isDynamic(strideOrOffset); + return !RankedShapedType::isDynamic(strideOrOffset); } LLVM::LLVMFuncOp getFreeFn(LLVMTypeConverter *typeConverter, ModuleOp module) { @@ -1380,7 +1380,7 @@ Value stride = nullptr; int64_t targetRank = targetMemRefType.getRank(); for (auto i : llvm::reverse(llvm::seq(0, targetRank))) { - if (!ShapedType::isDynamic(strides[i])) { + if (!RankedShapedType::isDynamic(strides[i])) { // If the stride for this dimension is dynamic, then use the product // of the sizes of the inner dimensions. 
stride = createIndexConstant(rewriter, loc, strides[i]); @@ -1634,11 +1634,11 @@ ArrayRef shape, ValueRange dynamicSizes, unsigned idx) const { assert(idx < shape.size()); - if (!ShapedType::isDynamic(shape[idx])) + if (!RankedShapedType::isDynamic(shape[idx])) return createIndexConstant(rewriter, loc, shape[idx]); // Count the number of dynamic dims in range [0, idx] unsigned nDynamic = - llvm::count_if(shape.take_front(idx), ShapedType::isDynamic); + llvm::count_if(shape.take_front(idx), RankedShapedType::isDynamic); return dynamicSizes[nDynamic]; } @@ -1650,7 +1650,7 @@ ArrayRef strides, Value nextSize, Value runningStride, unsigned idx) const { assert(idx < strides.size()); - if (!ShapedType::isDynamic(strides[idx])) + if (!RankedShapedType::isDynamic(strides[idx])) return createIndexConstant(rewriter, loc, strides[idx]); if (nextSize) return runningStride diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -178,7 +178,7 @@ auto storageAttr = spirv::StorageClassAttr::get(memRefType.getContext(), *storage); if (auto rankedType = memRefType.dyn_cast()) { - return MemRefType::get(memRefType.getShape(), memRefType.getElementType(), + return MemRefType::get(rankedType.getShape(), memRefType.getElementType(), rankedType.getLayout(), storageAttr); } return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -520,7 +520,8 @@ "All TOSA elementwise ops should only return a single result."); auto results = operation->getResults(); - auto resultTy = operation->getResult(0).getType().dyn_cast(); + auto 
resultTy = + operation->getResult(0).getType().dyn_cast(); if (!resultTy) return rewriter.notifyMatchFailure(operation, @@ -538,10 +539,10 @@ SmallVector emptyTensors; SmallVector dynDims; - dynDims.resize(results.front().getType().cast().getRank()); + dynDims.resize(results.front().getType().cast().getRank()); for (auto arg : operation->getOperands()) { - auto operandTy = arg.getType().cast(); + auto operandTy = arg.getType().cast(); for (int i = 0; i < operandTy.getRank(); i++) { if (operandTy.isDynamicDim(i) && !dynDims[i]) dynDims[i] = rewriter.create(loc, arg, i); @@ -551,7 +552,7 @@ SmallVector filteredDims = condenseValues(dynDims); for (auto result : results) { - auto resultTy = result.getType().template cast(); + auto resultTy = result.getType().template cast(); emptyTensors.push_back(rewriter.create( loc, resultTy.getShape(), resultTy.getElementType(), filteredDims)); opResultTypes.push_back(result.getType()); @@ -566,7 +567,7 @@ // Input indexing maps may be broadcasted. for (Value operand : operation->getOperands()) { - ShapedType type = operand.getType().cast(); + auto type = operand.getType().cast(); if (type.getShape() == resultTy.getShape()) { operands.push_back(operand); @@ -733,7 +734,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter) { auto loc = op->getLoc(); - auto inputTy = op->getOperand(0).getType().template cast(); + auto inputTy = op->getOperand(0).getType().template cast(); auto resultTy = op->getResult(0).getType().template cast(); auto elementTy = resultTy.getElementType(); Value input = op->getOperand(0); @@ -799,7 +800,7 @@ SmallVector reassociationMap; uint64_t expandInputRank = - linalgOp.getResults()[0].getType().cast().getRank(); + linalgOp.getResults()[0].getType().cast().getRank(); reassociationMap.resize(expandInputRank); for (uint64_t i = 0; i < expandInputRank; i++) { @@ -848,18 +849,19 @@ auto loc = op.getLoc(); auto input = op->getOperand(0); - auto resultTy = 
op.getType().cast(); + auto resultTy = op.getType().cast(); SmallVector dynDims; - dynDims.resize(op->getResult(0).getType().cast().getRank()); + dynDims.resize( + op->getResult(0).getType().cast().getRank()); SmallVector inputExprs; inputExprs.resize(resultTy.getRank()); - auto operandTy = input.getType().cast(); + auto operandTy = input.getType().dyn_cast(); for (const auto &permutation : llvm::enumerate(perms.getValues())) { auto index = permutation.index(); auto value = permutation.value().getZExtValue(); - if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) { + if (!operandTy || operandTy.isDynamicDim(index)) { dynDims[value] = rewriter.create(loc, input, index); } inputExprs[value] = rewriter.getAffineDimExpr(index); @@ -893,8 +895,8 @@ PatternRewriter &rewriter) const final { auto loc = op.getLoc(); auto input = op.getInput(); - auto inputTy = op.getInput().getType().cast(); - auto outputTy = op.getOutput().getType().cast(); + auto inputTy = op.getInput().getType().cast(); + auto outputTy = op.getOutput().getType().cast(); unsigned rank = inputTy.getRank(); // This is an illegal configuration. terminate and log an error @@ -1135,7 +1137,8 @@ outputDynSize.push_back(builder.create(input, 3)); // Generate the elementwise operation for casting scaling the input value. 
- auto genericTy = collapseTy.clone(resultTy.getElementType()); + auto genericTy = + collapseTy.clone(resultTy.getElementType()).cast(); Value empty = builder.create( genericTy.getShape(), resultTy.getElementType(), outputDynSize); auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank()); @@ -1282,8 +1285,8 @@ Location loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); auto input = op.getInput(); - auto inputTy = input.getType().cast(); - auto resultTy = op.getType().cast(); + auto inputTy = input.getType().cast(); + auto resultTy = op.getType().cast(); auto resultETy = resultTy.getElementType(); auto imageH = inputTy.getShape()[1]; @@ -1573,8 +1576,8 @@ PatternRewriter &rewriter) const final { auto loc = op.getLoc(); Value input = op.getInput(); - auto inputTy = input.getType().template cast(); - auto resultTy = op.getType().template cast(); + auto inputTy = input.getType().template cast(); + auto resultTy = op.getType().template cast(); auto axis = op.getAxis(); SmallVector dynDims; @@ -1635,9 +1638,9 @@ ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto input = op.getInput1(); - auto inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); auto inputShape = inputTy.getShape(); - auto resultTy = op.getType().cast(); + auto resultTy = op.getType().cast(); auto elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); @@ -1647,7 +1650,7 @@ SmallVector genericShape; for (int i = 0; i < rank; i++) { int64_t dim = multiples[i]; - genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim); + genericShape.push_back(dim == -1 ? 
RankedShapedType::kDynamic : dim); genericShape.push_back(inputShape[i]); } @@ -1710,8 +1713,8 @@ PatternRewriter &rewriter) const final { auto loc = argmaxOp.getLoc(); Value input = argmaxOp.getInput(); - auto inputTy = input.getType().cast(); - auto resultTy = argmaxOp.getOutput().getType().cast(); + auto inputTy = input.getType().cast(); + auto resultTy = argmaxOp.getOutput().getType().cast(); auto inElementTy = inputTy.getElementType(); auto outElementTy = resultTy.getElementType(); int axis = argmaxOp.getAxis(); @@ -1831,7 +1834,7 @@ auto valuesTy = op.getValues().getType().dyn_cast_or_null(); - auto resultTy = op.getType().cast(); + auto resultTy = op.getType().cast(); if (!valuesTy) return rewriter.notifyMatchFailure(op, "unranked tensors not supported"); @@ -1904,9 +1907,9 @@ auto loc = op.getLoc(); Value input = op.getInput(); Value table = op.getTable(); - auto inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); auto tableTy = table.getType().cast(); - auto resultTy = op.getType().cast(); + auto resultTy = op.getType().cast(); auto inputElementTy = inputTy.getElementType(); auto tableElementTy = tableTy.getElementType(); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -36,7 +36,7 @@ if (llvm::all_of(pad, [](int64_t p) { return p == 0; })) return input; - ShapedType inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); Type inputETy = inputTy.getElementType(); auto inputShape = inputTy.getShape(); @@ -48,7 +48,7 @@ for (int i = 0, s = inputShape.size(); i < s; i++) { auto lowPad = pad[i * 2]; auto highPad = pad[i * 2 + 1]; - if (ShapedType::isDynamic(inputShape[i])) + if (RankedShapedType::isDynamic(inputShape[i])) paddedShape.push_back(inputShape[i]); else paddedShape.push_back(inputShape[i] + highPad + lowPad); 
@@ -67,7 +67,7 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value conv, Value result, ArrayRef indexingMaps) { - ShapedType resultTy = conv.getType().cast(); + auto resultTy = conv.getType().cast(); return rewriter .create( loc, resultTy, ValueRange({bias, conv}), result, indexingMaps, @@ -121,11 +121,11 @@ // Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D static SmallVector inferDynamicDimsForConv( - Location loc, Value input, Value weight, ShapedType resultTy, + Location loc, Value input, Value weight, RankedShapedType resultTy, ArrayRef padAttr, ArrayRef strideAttr, ArrayRef dilationAttr, ArrayRef inputSizeDims, ArrayRef kernelSizeDims, OpBuilder &rewriter) { - ShapedType inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); Type inputETy = inputTy.getElementType(); int64_t inputRank = inputTy.getRank(); @@ -187,11 +187,11 @@ Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().template cast(); - ShapedType weightTy = weight.getType().template cast(); - ShapedType biasTy = bias.getType().template cast(); - ShapedType resultTy = - op->getResult(0).getType().template cast(); + auto inputTy = input.getType().template cast(); + auto weightTy = weight.getType().template cast(); + auto biasTy = bias.getType().template cast(); + auto resultTy = + op->getResult(0).getType().template cast(); Type inputETy = inputTy.getElementType(); Type resultETy = resultTy.getElementType(); @@ -353,10 +353,10 @@ Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + auto inputTy = input.getType().cast(); + auto weightTy = weight.getType().cast(); + auto biasTy = bias.getType().cast(); + auto resultTy = op->getResult(0).getType().cast(); 
int64_t resultRank = resultTy.getRank(); Type inputETy = inputTy.getElementType(); @@ -426,7 +426,7 @@ // Create the convolution op. auto strideAttr = rewriter.getI64TensorAttr(stride); auto dilationAttr = rewriter.getI64TensorAttr(dilation); - ShapedType linalgConvTy = + RankedShapedType linalgConvTy = RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2], weightShape[2], weightShape[3]}, resultETy); @@ -505,24 +505,27 @@ ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); - auto outputTy = op.getType().cast(); + auto outputTy = op.getType().cast(); auto outputElementTy = outputTy.getElementType(); - auto firstOperandTy = op->getOperand(0).getType().cast(); - auto secondOperandTy = op->getOperand(1).getType().cast(); + auto firstOperandTy = + op->getOperand(0).getType().dyn_cast(); + auto secondOperandTy = + op->getOperand(1).getType().dyn_cast(); SmallVector dynDims; - dynDims.resize(op->getResult(0).getType().cast().getRank()); + dynDims.resize( + op->getResult(0).getType().cast().getRank()); - if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) { + if (!firstOperandTy || firstOperandTy.isDynamicDim(0)) { dynDims[0] = rewriter.create(loc, op->getOperand(0), 0); } - if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) { + if (!firstOperandTy || firstOperandTy.isDynamicDim(1)) { dynDims[1] = rewriter.create(loc, op->getOperand(0), 1); } - if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) { + if (!secondOperandTy || secondOperandTy.isDynamicDim(2)) { dynDims[2] = rewriter.create(loc, op->getOperand(1), 2); } @@ -564,26 +567,27 @@ matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); - auto outputTy = op.getType().cast(); + auto outputTy = op.getType().cast(); auto input = op.getInput(); - auto inputTy = input.getType().cast(); + auto inputTy = input.getType().dyn_cast(); auto bias = op.getBias(); 
auto weight = op.getWeight(); - auto weightTy = weight.getType().cast(); + auto weightTy = weight.getType().cast(); auto weightShape = weightTy.getShape(); auto outputETy = outputTy.getElementType(); SmallVector dynDims; - dynDims.resize(op->getResult(0).getType().cast().getRank()); + dynDims.resize( + op->getResult(0).getType().cast().getRank()); - if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) { + if (!inputTy || inputTy.isDynamicDim(0)) { dynDims[0] = rewriter.create(loc, input, 0); } - if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) { + if (weightTy.isDynamicDim(0)) { dynDims[1] = rewriter.create(loc, weight, 0); } @@ -678,7 +682,7 @@ Value input = op.getInput(); ShapedType inputTy = input.getType().cast(); - ShapedType resultTy = op.getType().template cast(); + auto resultTy = op.getType().template cast(); Type resultETy = inputTy.getElementType(); auto dynamicDimsOr = @@ -750,12 +754,12 @@ ShapedType inputTy = input.getType().cast(); Type inElementTy = inputTy.getElementType(); - ShapedType resultTy = op.getType().template cast(); + auto resultTy = op.getType().template cast(); Type resultETy = op.getType().cast().getElementType(); Type accETy = inElementTy.isa() ? 
rewriter.getI32Type() : inElementTy; - ShapedType accTy = resultTy.clone(accETy); + auto accTy = resultTy.clone(accETy).cast(); auto dynamicDimsOr = checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()}); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -26,7 +26,7 @@ bool isDynamic) { if (isDynamic) { // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1 - intermediateShape = {ShapedType::kDynamic}; + intermediateShape = {RankedShapedType::kDynamic}; return true; } @@ -134,8 +134,8 @@ LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); + auto operandTy = adaptor.getInput1().getType().cast(); + auto resultTy = reshape.getType().template cast(); bool isDynamic = !operandTy.hasStaticShape(); if (isDynamic && resultTy.getRank() != 1) { @@ -172,8 +172,8 @@ LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); + auto operandTy = adaptor.getInput1().getType().cast(); + auto resultTy = reshape.getType().template cast(); bool isDynamic = !operandTy.hasStaticShape(); if (isDynamic && operandTy.getRank() != 1) { @@ -211,8 +211,8 @@ LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); + auto operandTy = adaptor.getInput1().getType().cast(); + auto resultTy = reshape.getType().template cast(); bool 
isDynamic = !operandTy.hasStaticShape(); SmallVector intermediateShape; @@ -247,14 +247,15 @@ Value input = adaptor.getInput(); SmallVector strides, sizes; ArrayRef starts = sliceOp.getStart(); - strides.resize(sliceOp.getType().template cast().getRank(), 1); + strides.resize( + sliceOp.getType().template cast().getRank(), 1); SmallVector dynSizes; for (const auto &i : llvm::enumerate(sliceOp.getSize())) { int64_t size = i.value(); size_t index = i.index(); - sizes.push_back(size == -1 ? ShapedType::kDynamic : size); - if (!ShapedType::isDynamic(sizes.back())) + sizes.push_back(size == -1 ? RankedShapedType::kDynamic : size); + if (!RankedShapedType::isDynamic(sizes.back())) continue; auto dim = rewriter.create(loc, input, index); @@ -284,7 +285,7 @@ auto input = padOp.getInput1(); auto padding = padOp.getPadding(); - ShapedType inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); Type elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); @@ -355,7 +356,8 @@ LogicalResult matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto inputType = op.getOperand(0).getType().template cast(); + auto inputType = + op.getOperand(0).getType().template cast(); auto resultType = op.getType().dyn_cast(); Location loc = op.getLoc(); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -135,7 +135,7 @@ strides.back() != 1) return std::nullopt; int64_t stride = strides[strides.size() - 2]; - if (stride == ShapedType::kDynamic) + if (stride == RankedShapedType::kDynamic) return std::nullopt; return stride; } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1313,9 +1313,9 @@ // layout. auto sizes = memRefType.getShape(); for (int index = 0, e = strides.size() - 1; index < e; ++index) { - if (ShapedType::isDynamic(sizes[index + 1]) || - ShapedType::isDynamic(strides[index]) || - ShapedType::isDynamic(strides[index + 1])) + if (RankedShapedType::isDynamic(sizes[index + 1]) || + RankedShapedType::isDynamic(strides[index]) || + RankedShapedType::isDynamic(strides[index + 1])) return std::nullopt; if (strides[index] != strides[index + 1] * sizes[index + 1]) return std::nullopt; @@ -1360,7 +1360,7 @@ if (!targetStrides) return failure(); // Only support static strides for now, regardless of contiguity. - if (llvm::any_of(*targetStrides, ShapedType::isDynamic)) + if (llvm::any_of(*targetStrides, RankedShapedType::isDynamic)) return failure(); auto int64Ty = IntegerType::get(rewriter.getContext(), 64); diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -828,7 +828,7 @@ for (unsigned r = 0; r < rank; r++) { cstWithShapeBounds.addBound(BoundType::LB, r, 0); int64_t dimSize = memRefType.getDimSize(r); - if (ShapedType::isDynamic(dimSize)) + if (RankedShapedType::isDynamic(dimSize)) continue; cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1); } @@ -850,7 +850,7 @@ // If no constant bound is found, then it can always be bound by the // memref's dim size if the latter has a constant size along this dim. auto dimSize = memRefType.getDimSize(d); - if (dimSize == ShapedType::kDynamic) + if (dimSize == RankedShapedType::kDynamic) return std::nullopt; diffConstant = dimSize; // Lower bound becomes 0. 
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -92,7 +92,7 @@ // Put together alloc operands for any dynamic dimensions of the memref. SmallVector allocOperands; for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) { - if (dim.value() == ShapedType::kDynamic) + if (dim.value() == RankedShapedType::kDynamic) allocOperands.push_back(bOuter.createOrFold( forOp.getLoc(), oldMemRef, dim.index())); } diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -54,7 +54,7 @@ continue; } - assert(cast(value.getType()).isDynamicDim(*dim) && + assert(cast(value.getType()).isDynamicDim(*dim) && "expected dynamic dim"); if (isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1821,7 +1821,7 @@ bool isDynDim = isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context); if (isDynDim) { - newShape[d] = ShapedType::kDynamic; + newShape[d] = RankedShapedType::kDynamic; } else { // The lower bound for the shape is always zero. 
std::optional ubConst = fac.getConstantBound64(BoundType::UB, d); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -145,7 +145,7 @@ auto &shape = resultDims[shapedValue.cast().getResultNumber()]; for (const auto &dim : enumerate(tensorType.getShape())) - if (ShapedType::isDynamic(dim.value())) + if (RankedShapedType::isDynamic(dim.value())) dynamicSizes.push_back(shape[dim.index()].get()); } } @@ -838,9 +838,9 @@ // Case 2: Ranked memref type. auto rankedTensorType = tensorType.cast(); - int64_t dynamicOffset = ShapedType::kDynamic; + int64_t dynamicOffset = RankedShapedType::kDynamic; SmallVector dynamicStrides(rankedTensorType.getRank(), - ShapedType::kDynamic); + RankedShapedType::kDynamic); auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(), dynamicOffset, dynamicStrides); return MemRefType::get(rankedTensorType.getShape(), diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -47,7 +47,7 @@ failed(getStridesAndOffset(target, targetStrides, targetOffset))) return false; auto dynamicToStatic = [](int64_t a, int64_t b) { - return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b); + return RankedShapedType::isDynamic(a) && !RankedShapedType::isDynamic(b); }; if (dynamicToStatic(sourceOffset, targetOffset)) return false; @@ -69,7 +69,7 @@ auto loc = value.getLoc(); SmallVector dynamicOperands; for (int i = 0; i < destType.getRank(); ++i) { - if (destType.getShape()[i] != ShapedType::kDynamic) + if (destType.getShape()[i] != RankedShapedType::kDynamic) continue; auto index = b.createOrFold(loc, i); Value size = b.create(loc, 
value, index); @@ -132,7 +132,7 @@ void mlir::bufferization::populateDynamicDimSizes( OpBuilder &b, Location loc, Value shapedValue, SmallVector &dynamicDims) { - auto shapedType = shapedValue.getType().cast(); + auto shapedType = shapedValue.getType().cast(); for (int64_t i = 0; i < shapedType.getRank(); ++i) { if (shapedType.isDynamicDim(i)) { if (shapedType.isa()) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp @@ -43,7 +43,7 @@ /// exceed the stack space. static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, unsigned maxRankOfAllocatedMemRef) { - auto type = alloc.getType().dyn_cast(); + auto type = alloc.getType().dyn_cast(); if (!type || !alloc.getDefiningOp()) return false; if (!type.hasStaticShape()) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -28,9 +28,9 @@ SmallVector strides; if (failed(getStridesAndOffset(type, strides, offset))) return false; - if (!llvm::all_of(strides, ShapedType::isDynamic)) + if (!llvm::all_of(strides, RankedShapedType::isDynamic)) return false; - if (!ShapedType::isDynamic(offset)) + if (!RankedShapedType::isDynamic(offset)) return false; return true; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -151,8 +151,8 @@ // and the second // non-dynamic (scalable vector). 
if (dims.empty() || dims.size() > 2 || - ((dims.size() == 2) ^ (ShapedType::isDynamic(dims[0]))) || - (dims.size() == 2 && ShapedType::isDynamic(dims[1]))) { + ((dims.size() == 2) ^ (RankedShapedType::isDynamic(dims[0]))) || + (dims.size() == 2 && RankedShapedType::isDynamic(dims[1]))) { parser.emitError(dimPos) << "expected '? x x ' or ' x '"; return Type(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -514,8 +514,8 @@ } static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { - auto shapedType = source.getType().cast(); - if (!shapedType.hasRank() || shapedType.isDynamicDim(dim)) + auto shapedType = source.getType().dyn_cast(); + if (!shapedType || shapedType.isDynamicDim(dim)) return createOrFoldDimOp(b, loc, source, dim); return b.getIndexAttr(shapedType.getDimSize(dim)); } @@ -644,7 +644,7 @@ for (OpOperand *opOperand : getDpsInitOperands()) { SmallVector shapes; for (int64_t dim : llvm::seq(0, getRank(opOperand))) { - auto shapedType = opOperand->get().getType().cast(); + auto shapedType = opOperand->get().getType().cast(); if (!shapedType.isDynamicDim(dim)) { // Static dim: Return IntegerAttr. shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim))); @@ -731,7 +731,7 @@ // Verify only static cases since we can't get exact dimension sizes and loop // ranges for dynamic cases in this stage. 
- if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) { + if (llvm::none_of(endLoopRangeValues, RankedShapedType::isDynamic)) { for (int64_t &range : endLoopRangeValues) range -= 1; for (OpOperand &opOperand : linalgOp->getOpOperands()) { @@ -743,7 +743,7 @@ ArrayRef shape = linalgOp.getShape(&opOperand); for (auto dim : llvm::seq(0, shape.size())) { // Ignore dynamic dimension or the case that the dimension size is 0 - if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0) + if (RankedShapedType::isDynamic(shape[dim]) || shape[dim] == 0) continue; // The first index or last index should be the maximum or the minimum in diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1180,7 +1180,7 @@ // The shape of each input must match the shape of the output. auto outputShape = getInit().getType().getShape(); for (Type inputArgType : TypeRange{getInputs()}) { - auto inputElemShape = inputArgType.cast().getShape(); + auto inputElemShape = inputArgType.cast().getShape(); if (inputElemShape != outputShape) { return emitOpError() << "expected shape of input (" << inputElemShape << ") to match shape of output (" << outputShape @@ -1250,7 +1250,8 @@ } SmallVector ReduceOp::getIteratorTypesArray() { - int64_t inputRank = getInputs()[0].getType().cast().getRank(); + int64_t inputRank = + getInputs()[0].getType().cast().getRank(); SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); for (int64_t reductionDim : getDimensions()) @@ -1259,7 +1260,8 @@ } ArrayAttr ReduceOp::getIndexingMaps() { - int64_t inputRank = getInputs()[0].getType().cast().getRank(); + int64_t inputRank = + getInputs()[0].getType().cast().getRank(); SmallVector affineMaps( getNumDpsInputs(), AffineMap::getMultiDimIdentityMap(inputRank, getContext())); @@ -1360,8 +1362,8 @@ ArrayRef dimensionsRef = getDimensions(); for (int64_t i = 1; i < 
getNumDpsInputs(); ++i) { - if (getInputs()[i].getType().cast().getShape() != - getInputs()[0].getType().cast().getShape()) { + if (getInputs()[i].getType().cast().getShape() != + getInputs()[0].getType().cast().getShape()) { return emitOpError() << "expects all inputs to have the same shapes. " "Shape at input-index " << i @@ -1369,16 +1371,16 @@ } } for (int64_t i = 1; i < getNumDpsInits(); ++i) { - if (getInits()[i].getType().cast().getShape() != - getInits()[0].getType().cast().getShape()) { + if (getInits()[i].getType().cast().getShape() != + getInits()[0].getType().cast().getShape()) { return emitOpError() << "expects all outputs to have the same shapes. " "Shape at output-index " << i << " is not equal to the shape at output-index 0."; } } - auto inputType = getInputs()[0].getType().cast(); - auto initType = getInits()[0].getType().cast(); + auto inputType = getInputs()[0].getType().cast(); + auto initType = getInits()[0].getType().cast(); DenseSet dimensionsToReduce; for (int64_t dimension : dimensionsRef) { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2153,7 +2153,7 @@ } staticSplitPoint = - parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic); + parser.getBuilder().getI64IntegerAttr(RankedShapedType::kDynamic); } result.addAttribute( @@ -2166,7 +2166,7 @@ void SplitOp::print(OpAsmPrinter &printer) { printer << " " << getTarget() << " after "; int64_t staticSplitSize = static_cast(getStaticSplitPoint()); - if (staticSplitSize != ShapedType::kDynamic) + if (staticSplitSize != RankedShapedType::kDynamic) printer << staticSplitSize; else printer << getDynamicSplitPoint(); @@ -2174,12 +2174,13 @@ printer.printOptionalAttrDict(getOperation()->getAttrs(), {getStaticSplitPointAttrName()}); printer << " : " << getTarget().getType(); - 
if (staticSplitSize == ShapedType::kDynamic) + if (staticSplitSize == RankedShapedType::kDynamic) printer << ", " << getDynamicSplitPoint().getType(); } LogicalResult SplitOp::verify() { - if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamic) ^ + if ((static_cast(getStaticSplitPoint()) != + RankedShapedType::kDynamic) ^ (getDynamicSplitPoint() == nullptr)) { return emitOpError() << "expects either a dynamic or a static split " "point to be provided"; @@ -2528,7 +2529,7 @@ unsigned dynamicPos = 0; Builder builder(getContext()); for (int64_t size : tileSizes) { - if (size == ShapedType::kDynamic) { + if (size == RankedShapedType::kDynamic) { results.push_back(dynamic[dynamicPos++]); } else { results.push_back(builder.getIndexAttr(size)); @@ -2918,7 +2919,7 @@ unsigned dynamicPos = 0; Builder builder(getContext()); for (int64_t size : tileSizes) { - if (size == ShapedType::kDynamic) { + if (size == RankedShapedType::kDynamic) { results.push_back(dynamic[dynamicPos++]); } else { results.push_back(builder.getIndexAttr(size)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -64,7 +64,8 @@ if (genericOp.getNumDpsInits() != 1) return failure(); - auto outputType = genericOp.getResultTypes().front().dyn_cast(); + auto outputType = + genericOp.getResultTypes().front().dyn_cast(); // Require the output types to be static given that we are generating // constants. 
if (!outputType || !outputType.hasStaticShape()) @@ -174,7 +175,7 @@ auto inputShapes = llvm::to_vector<4>( llvm::map_range(genericOp.getInputs(), [](Value value) { - return value.getType().cast().getShape(); + return value.getType().cast().getShape(); })); // Given a `linearIndex`, remap it to a linear index to access linalg op diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -73,15 +73,16 @@ FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = convOp.getInputs()[0].getType().dyn_cast(); + auto filterType = + convOp.getInputs()[1].getType().dyn_cast(); + auto outputType = convOp.getOutputs()[0].getType().cast(); - if (!filterType.hasStaticShape()) + if (!filterType || !filterType.hasStaticShape()) return rewriter.notifyMatchFailure( convOp, "expected a static shape for the filter"); - if (!inputType.hasStaticShape()) + if (!inputType || !inputType.hasStaticShape()) return rewriter.notifyMatchFailure(convOp, "expected a static shape for the input"); @@ -359,15 +360,16 @@ FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = convOp.getInputs()[0].getType().dyn_cast(); + auto filterType = + convOp.getInputs()[1].getType().dyn_cast(); + auto outputType = convOp.getOutputs()[0].getType().cast(); - if (!filterType.hasStaticShape()) + if (!filterType || !filterType.hasStaticShape()) return 
rewriter.notifyMatchFailure( convOp, "expected a static shape for the filter"); - if (!inputType.hasStaticShape()) + if (!inputType || !inputType.hasStaticShape()) return rewriter.notifyMatchFailure(convOp, "expected a static shape for the input"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -49,7 +49,8 @@ /// /// Returns true if tensorType can be detensored. bool canBeDetensored(TensorType tensorType) { - return tensorType.hasRank() && tensorType.getRank() == 0; + auto rankedTensorType = tensorType.dyn_cast(); + return rankedTensorType && rankedTensorType.getRank() == 0; } bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -620,7 +620,7 @@ if (expandedShape.size() == 1) continue; for (int64_t shape : expandedShape.drop_front()) { - if (ShapedType::isDynamic(shape)) { + if (RankedShapedType::isDynamic(shape)) { return rewriter.notifyMatchFailure( genericOp, "cannot expand due to index semantics and dynamic dims"); } @@ -716,7 +716,7 @@ [&](int64_t dim) { return rewriter.create(loc, dim); }); Value newIndex = rewriter.create(loc, expandedDims.front()); for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) { - assert(!ShapedType::isDynamic(std::get<0>(it))); + assert(!RankedShapedType::isDynamic(std::get<0>(it))); AffineExpr idx, acc; bindDims(rewriter.getContext(), idx, acc); newIndex = rewriter.create( @@ -1512,8 +1512,9 @@ Value collapsedOpResult = collapsedGenericOp->getResult(originalResult.index()); auto originalResultType = - originalResult.value().getType().cast(); - auto 
collapsedOpResultType = collapsedOpResult.getType().cast(); + originalResult.value().getType().cast(); + auto collapsedOpResultType = + collapsedOpResult.getType().cast(); if (collapsedOpResultType.getRank() != originalResultType.getRank()) { AffineMap indexingMap = genericOp.getIndexingMapMatchingResult(originalResult.value()); @@ -1755,7 +1756,7 @@ modifiedOutput = true; SmallVector dynamicDims; for (const auto &dim : llvm::enumerate(operandType.getShape())) { - if (dim.value() != ShapedType::kDynamic) + if (dim.value() != RankedShapedType::kDynamic) continue; dynamicDims.push_back(rewriter.createOrFold( loc, operandVal, dim.index())); diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -92,7 +92,7 @@ sizes.reserve(offsets.size()); for (const auto &shape : llvm::enumerate( source.getType().cast().getShape())) { - if (ShapedType::isDynamic(shape.value())) { + if (RankedShapedType::isDynamic(shape.value())) { sizes.push_back( rewriter.create(loc, source, shape.index()) .getResult()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -155,11 +155,11 @@ if (!tensorType) continue; unsigned rank = tensorType.getRank(); - SmallVector staticOffsetsVector( - rank, ShapedType::kDynamic); - SmallVector staticSizesVector(rank, ShapedType::kDynamic); - SmallVector staticStridesVector( - rank, ShapedType::kDynamic); + SmallVector staticOffsetsVector(rank, + RankedShapedType::kDynamic); + SmallVector staticSizesVector(rank, RankedShapedType::kDynamic); + SmallVector staticStridesVector(rank, + RankedShapedType::kDynamic); 
resultTypes.push_back(tensor::ExtractSliceOp::inferResultType( tensorType, staticOffsetsVector, staticSizesVector, staticStridesVector)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -620,7 +620,8 @@ sizes = SmallVector(nPackedLoops, rewriter.getIndexAttr(1)); for (int64_t sz : transposedTensorType.getShape()) { // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor. - assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes"); + assert(!RankedShapedType::isDynamic(sz) && + "padded tensor needs static sizes"); sizes.push_back(rewriter.getIndexAttr(sz)); } // strides = [1 .. 1]. @@ -691,7 +692,7 @@ } // Create the packed tensor. - SmallVector packedShape(nPackedLoops, ShapedType::kDynamic); + SmallVector packedShape(nPackedLoops, RankedShapedType::kDynamic); // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor. llvm::append_range(packedShape, transposedTensorType->getShape()); auto hoistedPackedTensorType = RankedTensorType::get( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -68,7 +68,7 @@ // Fallback dynamic buffer. 
auto dynamicBufferType = - MemRefType::get(ShapedType::kDynamic, b.getIntegerType(8)); + MemRefType::get(RankedShapedType::kDynamic, b.getIntegerType(8)); Value mul = b.createOrFold( b.create(width), allocSize); if (options.useAlloca) @@ -95,7 +95,7 @@ Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize, layout, alignment); SmallVector dynSizes(boundingSubViewSize.size(), - ShapedType::kDynamic); + RankedShapedType::kDynamic); Value view = b.createOrFold( MemRefType::get(dynSizes, viewType.getElementType()), buffer, zero, boundingSubViewSize); @@ -247,7 +247,8 @@ partialSizes.push_back( b.createOrFold(loc, subView, resultDimIdx++)); } - SmallVector dynSizes(fullSizes.size(), ShapedType::kDynamic); + SmallVector dynSizes(fullSizes.size(), + RankedShapedType::kDynamic); // If a callback is not specified, then use the default implementation for // allocating the promoted buffer. std::optional fullLocalView = diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -51,7 +51,8 @@ } SmallVector loopRanges = op.getStaticLoopRanges(); int64_t reductionDimSize = loopRanges[reductionDim]; - if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0) + if (reductionDimSize == RankedShapedType::kDynamic || + reductionDimSize % ratio != 0) return b.notifyMatchFailure( op, "Reduction dimension not divisible by split ratio"); if (op.getNumDpsInits() != 1) @@ -262,7 +263,7 @@ unsigned reductionDimPos = dims[0]; SmallVector loopRanges = op.getStaticLoopRanges(); int64_t reductionDimSize = loopRanges[reductionDimPos]; - if (reductionDimSize == ShapedType::kDynamic || + if (reductionDimSize == RankedShapedType::kDynamic || reductionDimSize % splitFactor != 0 || insertSplitDimension >= loopRanges.size()) return b.notifyMatchFailure( diff --git 
a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -290,7 +290,7 @@ int64_t oldIdx = idx < insertSplitDimension ? idx : idx - 1; int64_t dim = oldShape[oldIdx]; newOutputShape.push_back(dim); - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) dynamicDims.push_back(b.createOrFold( loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx)); } @@ -366,7 +366,7 @@ // Then create a new reduction that only reduce the newly added dimension // from the previous op. int64_t intermRank = - partialReduce[0].getType().cast().getRank(); + partialReduce[0].getType().cast().getRank(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); SmallVector reductionIteratorTypes; SmallVector exprs; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -73,7 +73,7 @@ ? packPaddings[opOperand->getOperandNumber()] : false; bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) { - return ShapedType::isDynamic(shape[dim]); + return RankedShapedType::isDynamic(shape[dim]); }); if (!nofold && hasStaticShape) return opOperand->get(); @@ -312,7 +312,8 @@ } // Fail hoisting if the operand shape is not fully static. 
- if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) { + if (llvm::any_of(paddedOp.getShape(&opOperand), + RankedShapedType::isDynamic)) { (void)rewriter.notifyMatchFailure(linalgOp, "non static padding shape -- skip"); continue; @@ -810,8 +811,8 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const { - auto inputShapedType = padOp.getSource().getType().cast(); - auto resultShapedType = padOp.getResult().getType().cast(); + auto inputShapedType = padOp.getSource().getType(); + auto resultShapedType = padOp.getResult().getType(); // Bail on non-static shapes. if (!inputShapedType.hasStaticShape()) @@ -986,7 +987,7 @@ } Location loc = packOp.getLoc(); - ShapedType inputType = packOp.getSourceType(); + auto inputType = packOp.getSourceType(); int64_t inputRank = inputType.getRank(); assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank), [](int64_t val) { return val == 1; })); @@ -1058,7 +1059,7 @@ } readSizes.push_back(dimAndTileMapping[i]); readShape.push_back(getConstantIntValue(dimAndTileMapping[i]) - .value_or(ShapedType::kDynamic)); + .value_or(RankedShapedType::kDynamic)); } Type elemType = packOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShape, elemType); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -237,7 +237,7 @@ LinalgOp linalgOp) { // TODO: Support 0-d vectors. for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) { - if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) { + if (!RankedShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) { // Create constant index op for static dimensions. 
iterSpaceValueSizes.push_back(rewriter.create( linalgOp.getLoc(), iterSpaceStaticSizes[vecDim])); @@ -287,7 +287,7 @@ LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << "\n"); - if (ShapedType::isDynamicShape(canonicalVecShape)) + if (RankedShapedType::isDynamicShape(canonicalVecShape)) return failure(); // Initialize iteration space static sizes. @@ -864,7 +864,7 @@ targetShape.back() == 1) return VectorMemoryAccessKind::Gather; - auto inputShape = extractOp.getTensor().getType().cast(); + auto inputShape = extractOp.getTensor().getType(); // 2. Assume that it's a gather load when reading _from_ a tensor for which // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`. @@ -1302,7 +1302,7 @@ if (!inputVectorSizes.empty()) { assert(inputVectorSizes.size() == linalgOp.getNumLoops() && "Input vector sizes don't match the number of loops"); - assert(!ShapedType::isDynamicShape(inputVectorSizes) && + assert(!RankedShapedType::isDynamicShape(inputVectorSizes) && "Input vector sizes can't have dynamic dimensions"); assert( llvm::all_of( @@ -1310,7 +1310,7 @@ [](std::tuple sizePair) { int64_t staticSize = std::get<0>(sizePair); int64_t inputSize = std::get<1>(sizePair); - return ShapedType::isDynamic(staticSize) || + return RankedShapedType::isDynamic(staticSize) || staticSize <= inputSize; }) && "Input vector sizes must be greater than or equal to iteration space " @@ -1923,7 +1923,7 @@ if (!padValue) return failure(); // Dynamic shapes not supported. - if (!padOp.getResult().getType().cast().hasStaticShape()) + if (!padOp.getResult().getType().hasStaticShape()) return failure(); // Pad result not used as destination. 
if (insertOp.getDest() == padOp.getResult()) @@ -2164,17 +2164,18 @@ //===----------------------------------------------------------------------===// template -static void bindShapeDims(ShapedType shapedType) {} +static void bindShapeDims(RankedShapedType shapedType) {} template -static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) { +static void bindShapeDims(RankedShapedType shapedType, IntTy &val, + IntTy2 &...vals) { val = shapedType.getShape()[N]; bindShapeDims(shapedType, vals...); } /// Bind a pack of int& to the leading dimensions of shapedType.getShape(). template -static void bindShapeDims(ShapedType shapedType, IntTy &...vals) { +static void bindShapeDims(RankedShapedType shapedType, IntTy &...vals) { bindShapeDims<0>(shapedType, vals...); } @@ -2245,9 +2246,9 @@ lhsShaped = linalgOp.getDpsInputOperand(0)->get(); rhsShaped = linalgOp.getDpsInputOperand(1)->get(); resShaped = linalgOp.getDpsInitOperand(0)->get(); - lhsShapedType = lhsShaped.getType().dyn_cast(); - rhsShapedType = rhsShaped.getType().dyn_cast(); - resShapedType = resShaped.getType().dyn_cast(); + lhsShapedType = lhsShaped.getType().dyn_cast(); + rhsShapedType = rhsShaped.getType().dyn_cast(); + resShapedType = resShaped.getType().dyn_cast(); if (!lhsShapedType || !rhsShapedType || !resShapedType) return; // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR @@ -2822,7 +2823,7 @@ bool isPoolExt = false; int strideW, dilationW; Value lhsShaped, rhsShaped, resShaped; - ShapedType lhsShapedType, rhsShapedType, resShapedType; + RankedShapedType lhsShapedType, rhsShapedType, resShapedType; // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops. // Returns true iff it is a valid conv/pooling op. 
diff --git a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp @@ -52,20 +52,20 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) { - auto shapedType = val.getType().cast(); - if (!shapedType.hasRank() || shapedType.isDynamicDim(dim)) + auto shapedType = val.getType().dyn_cast(); + if (!shapedType || shapedType.isDynamicDim(dim)) return createOrFoldDimOp(b, loc, val, dim); return b.getIndexAttr(shapedType.getDimSize(dim)); } SmallVector createDynamicDimensions(OpBuilder &b, Location loc, Value val) { - auto shapedType = val.getType().cast(); - assert(shapedType.hasRank() && "`val` must have a static rank"); + auto shapedType = val.getType().dyn_cast(); + assert(shapedType && "`val` must have a static rank"); SmallVector res; res.reserve(shapedType.getRank()); for (const auto &dim : llvm::enumerate(shapedType.getShape())) { - if (dim.value() == ShapedType::kDynamic) + if (dim.value() == RankedShapedType::kDynamic) res.push_back(createOrFoldDimOp(b, loc, val, dim.index())); } return res; @@ -73,8 +73,8 @@ SmallVector getMixedDimensions(OpBuilder &b, Location loc, Value val) { - auto shapedType = val.getType().cast(); - assert(shapedType.hasRank() && "`val` must have a static rank"); + auto shapedType = val.getType().dyn_cast(); + assert(shapedType && "`val` must have a static rank"); SmallVector dynamicDims = createDynamicDimensions(b, loc, val); return getMixedValues(shapedType.getShape(), dynamicDims, b); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -683,7 +683,7 @@ ArrayRef lbs, ArrayRef ubs, ArrayRef subShapeSizes, bool omitPartialTileCheck) { - auto shapedType = valueToTile.getType().dyn_cast(); + auto shapedType = 
valueToTile.getType().dyn_cast(); assert(shapedType && "only shaped types can be tiled"); ArrayRef shape = shapedType.getShape(); int64_t rank = shapedType.getRank(); @@ -740,7 +740,7 @@ int64_t shapeSize = shape[r]; std::optional sizeCst = getConstantIntValue(size); auto hasTileSizeOne = sizeCst && *sizeCst == 1; - auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) && + auto dividesEvenly = sizeCst && !RankedShapedType::isDynamic(shapeSize) && ((shapeSize % *sizeCst) == 0); if (!hasTileSizeOne && !dividesEvenly) { LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -31,23 +31,20 @@ namespace saturated_arith { struct Wrapper { static Wrapper stride(int64_t v) { - return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} - : Wrapper{false, v}; + return (RankedShapedType::isDynamic(v)) ? Wrapper{true, 0} + : Wrapper{false, v}; } static Wrapper offset(int64_t v) { - return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} - : Wrapper{false, v}; + return (RankedShapedType::isDynamic(v)) ? Wrapper{true, 0} + : Wrapper{false, v}; } static Wrapper size(int64_t v) { - return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v}; - } - int64_t asOffset() { - return saturated ? ShapedType::kDynamic : v; - } - int64_t asSize() { return saturated ? ShapedType::kDynamic : v; } - int64_t asStride() { - return saturated ? ShapedType::kDynamic : v; + return (RankedShapedType::isDynamic(v)) ? Wrapper{true, 0} + : Wrapper{false, v}; } + int64_t asOffset() { return saturated ? RankedShapedType::kDynamic : v; } + int64_t asSize() { return saturated ? RankedShapedType::kDynamic : v; } + int64_t asStride() { return saturated ? 
RankedShapedType::kDynamic : v; } bool operator==(Wrapper other) { return (saturated && other.saturated) || (!saturated && !other.saturated && v == other.v); @@ -151,7 +148,7 @@ /// - `memRefTy == memref>` /// - `getAttributes == getConstantStrides` (i.e., a wrapper around /// `getStridesAndOffset`), and -/// - `isDynamic == ShapedType::isDynamic` +/// - `isDynamic == RankedShapedType::isDynamic` /// Will yield: `values == [2, 1]` static void constifyIndexValues( SmallVectorImpl &values, MemRefType memRefTy, @@ -303,7 +300,7 @@ for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { int64_t dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. - if (!ShapedType::isDynamic(dimSize)) { + if (!RankedShapedType::isDynamic(dimSize)) { newShapeConstants.push_back(dimSize); continue; } @@ -315,7 +312,7 @@ newShapeConstants.push_back(constSizeArg.getZExtValue()); } else { // Dynamic shape dimension not folded; copy dynamicSize from old memref. - newShapeConstants.push_back(ShapedType::kDynamic); + newShapeConstants.push_back(RankedShapedType::kDynamic); dynamicSizes.push_back(dynamicSize); } dynamicDimPos++; @@ -718,22 +715,21 @@ for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { auto ss = std::get<0>(it), st = std::get<1>(it); if (ss != st) - if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) + if (RankedShapedType::isDynamic(ss) && !RankedShapedType::isDynamic(st)) return false; } // If cast is towards more static offset along any dimension, don't fold. if (sourceOffset != resultOffset) - if (ShapedType::isDynamic(sourceOffset) && - !ShapedType::isDynamic(resultOffset)) + if (RankedShapedType::isDynamic(sourceOffset) && + !RankedShapedType::isDynamic(resultOffset)) return false; // If cast is towards more static strides along any dimension, don't fold. 
for (auto it : llvm::zip(sourceStrides, resultStrides)) { auto ss = std::get<0>(it), st = std::get<1>(it); if (ss != st) - if (ShapedType::isDynamic(ss) && - !ShapedType::isDynamic(st)) + if (RankedShapedType::isDynamic(ss) && !RankedShapedType::isDynamic(st)) return false; } @@ -766,8 +762,8 @@ // same. They are also compatible if either one is dynamic (see // description of MemRefCastOp for details). auto checkCompatible = [](int64_t a, int64_t b) { - return (ShapedType::isDynamic(a) || - ShapedType::isDynamic(b) || a == b); + return (RankedShapedType::isDynamic(a) || + RankedShapedType::isDynamic(b) || a == b); }; if (!checkCompatible(aOffset, bOffset)) return false; @@ -784,8 +780,8 @@ for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); - if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) && - aDim != bDim) + if (!RankedShapedType::isDynamic(aDim) && + !RankedShapedType::isDynamic(bDim) && aDim != bDim) return false; } return true; @@ -1441,7 +1437,7 @@ SmallVector ExtractStridedMetadataOp::getConstifiedMixedSizes() { SmallVector values = getAsOpFoldResult(getSizes()); constifyIndexValues(values, getSource().getType(), getContext(), - getConstantSizes, ShapedType::isDynamic); + getConstantSizes, RankedShapedType::isDynamic); return values; } @@ -1449,7 +1445,7 @@ ExtractStridedMetadataOp::getConstifiedMixedStrides() { SmallVector values = getAsOpFoldResult(getStrides()); constifyIndexValues(values, getSource().getType(), getContext(), - getConstantStrides, ShapedType::isDynamic); + getConstantStrides, RankedShapedType::isDynamic); return values; } @@ -1457,7 +1453,7 @@ OpFoldResult offsetOfr = getAsOpFoldResult(getOffset()); SmallVector values(1, offsetOfr); constifyIndexValues(values, getSource().getType(), getContext(), - getConstantOffset, ShapedType::isDynamic); + getConstantOffset, RankedShapedType::isDynamic); return values[0]; } @@ -1794,8 +1790,8 @@ OpFoldResult 
RankOp::fold(FoldAdaptor adaptor) { // Constant fold rank when the rank of the operand is known. auto type = getOperand().getType(); - auto shapedType = type.dyn_cast(); - if (shapedType && shapedType.hasRank()) + auto shapedType = type.dyn_cast(); + if (shapedType) return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); return IntegerAttr(); } @@ -1873,8 +1869,9 @@ // Match sizes in result memref type and in static_sizes attribute. for (auto [idx, resultSize, expectedSize] : llvm::enumerate(resultType.getShape(), getStaticSizes())) { - if (!ShapedType::isDynamic(resultSize) && - !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) + if (!RankedShapedType::isDynamic(resultSize) && + !RankedShapedType::isDynamic(expectedSize) && + resultSize != expectedSize) return emitError("expected result type with size = ") << expectedSize << " instead of " << resultSize << " in dim = " << idx; @@ -1891,8 +1888,8 @@ // Match offset in result memref type and in static_offsets attribute. int64_t expectedOffset = getStaticOffsets().front(); - if (!ShapedType::isDynamic(resultOffset) && - !ShapedType::isDynamic(expectedOffset) && + if (!RankedShapedType::isDynamic(resultOffset) && + !RankedShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset) return emitError("expected result type with offset = ") << expectedOffset << " instead of " << resultOffset; @@ -1900,8 +1897,8 @@ // Match strides in result memref type and in static_strides attribute. 
for (auto [idx, resultStride, expectedStride] : llvm::enumerate(resultStrides, getStaticStrides())) { - if (!ShapedType::isDynamic(resultStride) && - !ShapedType::isDynamic(expectedStride) && + if (!RankedShapedType::isDynamic(resultStride) && + !RankedShapedType::isDynamic(expectedStride) && resultStride != expectedStride) return emitError("expected result type with stride = ") << expectedStride << " instead of " << resultStride @@ -1944,14 +1941,14 @@ SmallVector ReinterpretCastOp::getConstifiedMixedSizes() { SmallVector values = getMixedSizes(); constifyIndexValues(values, getType(), getContext(), getConstantSizes, - ShapedType::isDynamic); + RankedShapedType::isDynamic); return values; } SmallVector ReinterpretCastOp::getConstifiedMixedStrides() { SmallVector values = getMixedStrides(); constifyIndexValues(values, getType(), getContext(), getConstantStrides, - ShapedType::isDynamic); + RankedShapedType::isDynamic); return values; } @@ -1960,7 +1957,7 @@ assert(values.size() == 1 && "reinterpret_cast must have one and only one offset"); constifyIndexValues(values, getType(), getContext(), getConstantOffset, - ShapedType::isDynamic); + RankedShapedType::isDynamic); return values[0]; } @@ -2120,7 +2117,7 @@ << expandedDim << " is out of bounds"; // Check if there are multiple dynamic dims in a reassociation group. - if (ShapedType::isDynamic(expandedShape[expandedDim])) { + if (RankedShapedType::isDynamic(expandedShape[expandedDim])) { if (foundDynamic && !allowMultipleDynamicDimsPerGroup) return op->emitOpError( "at most one dimension in a reassociation group may be dynamic"); @@ -2129,7 +2126,8 @@ } // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity. 
- if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic) + if (RankedShapedType::isDynamic(collapsedShape[collapsedDim]) != + foundDynamic) return op->emitOpError("collapsed dim (") << collapsedDim << ") must be dynamic if and only if reassociation group is " @@ -2323,14 +2321,14 @@ ArrayRef ref = llvm::ArrayRef(reassoc); while (srcShape[ref.back()] == 1 && ref.size() > 1) ref = ref.drop_back(); - if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { + if (!RankedShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { resultStrides.push_back(srcStrides[ref.back()]); } else { // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so // the corresponding stride may have to be skipped. (See above comment.) // Therefore, the result stride cannot be statically determined and must // be dynamic. - resultStrides.push_back(ShapedType::kDynamic); + resultStrides.push_back(RankedShapedType::kDynamic); } } @@ -2533,7 +2531,7 @@ if (resultMemRefType) { if (!resultMemRefType.getLayout().isIdentity()) return emitOpError("result memref type should have identity affine map"); - if (shapeSize == ShapedType::kDynamic) + if (shapeSize == RankedShapedType::kDynamic) return emitOpError("cannot use shape operand with dynamic length to " "reshape to statically-ranked memref type"); if (shapeSize != resultMemRefType.getRank()) @@ -3134,8 +3132,8 @@ } OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { - auto resultShapedType = getResult().getType().cast(); - auto sourceShapedType = getSource().getType().cast(); + auto resultShapedType = getResult().getType().cast(); + auto sourceShapedType = getSource().getType().cast(); if (resultShapedType.hasStaticShape() && resultShapedType == sourceShapedType) { @@ -3327,7 +3325,7 @@ for (unsigned dim = 0, e = rank; dim < e; ++dim) { int64_t dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. 
- if (!ShapedType::isDynamic(dimSize)) { + if (!RankedShapedType::isDynamic(dimSize)) { newShapeConstants.push_back(dimSize); continue; } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -81,7 +81,7 @@ bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols}); AffineExpr expr = symbols.front(); - values[0] = ShapedType::isDynamic(sourceOffset) + values[0] = RankedShapedType::isDynamic(sourceOffset) ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) : rewriter.getIndexAttr(sourceOffset); SmallVector subOffsets = subview.getMixedOffsets(); @@ -91,7 +91,7 @@ for (unsigned i = 0; i < sourceRank; ++i) { // Compute the stride. OpFoldResult origStride = - ShapedType::isDynamic(sourceStrides[i]) + RankedShapedType::isDynamic(sourceStrides[i]) ? origStrides[i] : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i])); strides.push_back(makeComposedFoldedAffineApply( @@ -257,7 +257,7 @@ // Fill up all the statically known sizes. 
for (unsigned i = 0; i < groupSize; ++i) { uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]); - if (ShapedType::isDynamic(dimSize)) { + if (RankedShapedType::isDynamic(dimSize)) { assert(!dynSizeIdx && "There must be at most one dynamic size per group"); dynSizeIdx = i; continue; @@ -325,7 +325,7 @@ for (int i = groupSize - 1; i >= 0; --i) { expandedStrides[i] = builder.getIndexAttr(currentStride); uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]); - if (ShapedType::isDynamic(dimSize)) { + if (RankedShapedType::isDynamic(dimSize)) { assert(!dynSizeIdx && "There must be at most one dynamic size per group"); dynSizeIdx = i; continue; @@ -339,7 +339,7 @@ auto sourceType = source.getType().cast(); auto [strides, offset] = getStridesAndOffset(sourceType); - OpFoldResult origStride = ShapedType::isDynamic(strides[groupId]) + OpFoldResult origStride = RankedShapedType::isDynamic(strides[groupId]) ? origStrides[groupId] : builder.getIndexAttr(strides[groupId]); @@ -349,7 +349,7 @@ // that dimension with the dynamic size. 
if (dynSizeIdx) { int64_t productOfAllStaticSizes = currentStride; - assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) && + assert(RankedShapedType::isDynamic(sourceType.getDimSize(groupId)) && "We shouldn't be able to change dynamicity"); OpFoldResult origSize = origSizes[groupId]; @@ -434,7 +434,7 @@ MemRefType collapseShapeType = collapseShape.getResultType(); uint64_t size = collapseShapeType.getDimSize(groupId); - if (!ShapedType::isDynamic(size)) { + if (!RankedShapedType::isDynamic(size)) { collapsedSize.push_back(builder.getIndexAttr(size)); return collapsedSize; } @@ -450,7 +450,7 @@ collapsedSize.push_back(getProductOfValues( reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(), - origSizes, ShapedType::isDynamic)); + origSizes, RankedShapedType::isDynamic)); return collapsedSize; } @@ -493,7 +493,7 @@ continue; int64_t currentStride = strides[currentDim]; - groupStrides.push_back(ShapedType::isDynamic(currentStride) + groupStrides.push_back(RankedShapedType::isDynamic(currentStride) ? origStrides[currentDim] : builder.getIndexAttr(currentStride)); } @@ -504,14 +504,14 @@ auto [collapsedStrides, collapsedOffset] = getStridesAndOffset(collapsedType); int64_t finalStride = collapsedStrides[groupId]; - if (ShapedType::isDynamic(finalStride)) { + if (RankedShapedType::isDynamic(finalStride)) { // Look for a dynamic stride. At this point we don't know which one is // desired, but they are all equally good/bad. for (int64_t currentDim : reassocGroup) { assert(srcShape[currentDim] == 1 && "We should be dealing with 1x1x...x1"); - if (ShapedType::isDynamic(strides[currentDim])) + if (RankedShapedType::isDynamic(strides[currentDim])) return {origStrides[currentDim]}; } llvm_unreachable("We should have found a dynamic stride"); @@ -572,7 +572,7 @@ unsigned reshapeRank = reshapeType.getRank(); OpFoldResult offsetOfr = - ShapedType::isDynamic(offset) + RankedShapedType::isDynamic(offset) ? 
getAsOpFoldResult(newExtractStridedMetadata.getOffset()) : rewriter.getIndexAttr(offset); @@ -663,7 +663,7 @@ sizes.reserve(rank); unsigned dynamicPos = 0; for (int64_t size : memRefType.getShape()) { - if (ShapedType::isDynamic(size)) + if (RankedShapedType::isDynamic(size)) sizes.push_back(dynamic[dynamicPos++]); else sizes.push_back(rewriter.getIndexAttr(size)); @@ -747,7 +747,7 @@ // Collect the sizes. ArrayRef sizes = memRefType.getShape(); - assert(!llvm::any_of(sizes, ShapedType::isDynamic) && + assert(!llvm::any_of(sizes, RankedShapedType::isDynamic) && "unexpected dynamic shape for result of `memref.get_global` op"); // Strides (just creates identity strides). diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -60,9 +60,9 @@ // strides from unranked memrefs, so cast the source to a type with fully // dynamic layout, from which we can then extract the offset and strides. // (Rank was already verified.) - int64_t dynamicOffset = ShapedType::kDynamic; + int64_t dynamicOffset = RankedShapedType::kDynamic; SmallVector dynamicShape(resultType.getRank(), - ShapedType::kDynamic); + RankedShapedType::kDynamic); auto stridedLayout = StridedLayoutAttr::get(builder.getContext(), dynamicOffset, dynamicShape); auto dynStridesType = @@ -102,7 +102,7 @@ return; // Check offset. - if (resultOffset != ShapedType::kDynamic) { + if (resultOffset != RankedShapedType::kDynamic) { // Static/dynamic offset -> dynamic offset does not need verification. Value srcOffset = metadataOp.getResult(1); Value resultOffsetVal = @@ -116,7 +116,7 @@ // Check strides. for (const auto &it : llvm::enumerate(resultStrides)) { // Static/dynamic stride -> dynamic stride does not need verification. 
- if (it.value() == ShapedType::kDynamic) + if (it.value() == RankedShapedType::kDynamic) continue; Value srcStride = diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -130,7 +130,7 @@ } if (quantizedType.isa()) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - ShapedType sType = quantizedType.cast(); + RankedShapedType sType = quantizedType.cast(); if (!sType.getElementType().isa()) { return nullptr; } @@ -156,7 +156,8 @@ return *this; } if (candidateType.isa()) { - ShapedType candidateShapedType = candidateType.cast(); + RankedShapedType candidateShapedType = + candidateType.cast(); if (candidateShapedType.getElementType() != getExpressedType()) { return nullptr; } @@ -185,7 +186,7 @@ } if (quantizedType.isa()) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - ShapedType sType = quantizedType.cast(); + RankedShapedType sType = quantizedType.cast(); if (!sType.getElementType().isa()) { return nullptr; } diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp --- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp @@ -61,7 +61,7 @@ UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) { // Creates the converter for each chunk. Normally the size of the // quantization dim is 3, so we can cache all the converters. 
- ShapedType type = attr.getType(); + RankedShapedType type = attr.getType(); size_t dimSize = type.getDimSize(quantizationDim); if (dimSize != scales.size()) { return {}; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1976,7 +1976,7 @@ if (valueType == opType) return success(); auto arrayType = opType.dyn_cast(); - auto shapedType = valueType.dyn_cast(); + auto shapedType = valueType.dyn_cast(); if (!arrayType) return op.emitOpError("result or element type (") << opType << ") does not match value type (" << valueType diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -185,9 +185,9 @@ return elementSize; auto dims = memRefType.getShape(); - if (llvm::is_contained(dims, ShapedType::kDynamic) || - ShapedType::isDynamic(offset) || - llvm::is_contained(strides, ShapedType::kDynamic)) + if (llvm::is_contained(dims, RankedShapedType::kDynamic) || + RankedShapedType::isDynamic(offset) || + llvm::is_contained(strides, RankedShapedType::kDynamic)) return std::nullopt; int64_t memrefSize = -1; @@ -197,7 +197,7 @@ return (offset + memrefSize) * *elementSize; } - if (auto tensorType = type.dyn_cast()) { + if (auto tensorType = type.dyn_cast()) { if (!tensorType.hasStaticShape()) return std::nullopt; @@ -342,7 +342,8 @@ const SPIRVConversionOptions &options, TensorType type) { // TODO: Handle dynamic shapes. 
- if (!type.hasStaticShape()) { + auto rankedTensorType = dyn_cast(type); + if (!rankedTensorType || !rankedTensorType.hasStaticShape()) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: dynamic shape unimplemented\n"); return nullptr; @@ -817,8 +818,8 @@ int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(baseType, strides, offset)) || - llvm::is_contained(strides, ShapedType::kDynamic) || - ShapedType::isDynamic(offset)) { + llvm::is_contained(strides, RankedShapedType::kDynamic) || + RankedShapedType::isDynamic(offset)) { return nullptr; } @@ -848,8 +849,8 @@ int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(baseType, strides, offset)) || - llvm::is_contained(strides, ShapedType::kDynamic) || - ShapedType::isDynamic(offset)) { + llvm::is_contained(strides, RankedShapedType::kDynamic) || + RankedShapedType::isDynamic(offset)) { return nullptr; } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -48,8 +48,8 @@ LogicalResult shape::getShapeVec(Value input, SmallVectorImpl &shapeValues) { if (auto inputOp = input.getDefiningOp()) { - auto type = inputOp.getArg().getType().cast(); - if (!type.hasRank()) + auto type = inputOp.getArg().getType().dyn_cast(); + if (!type) return failure(); llvm::append_range(shapeValues, type.getShape()); return success(); @@ -1078,8 +1078,8 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) { Type valType = getValue().getType(); - auto valShapedType = valType.dyn_cast(); - if (!valShapedType || !valShapedType.hasRank()) + auto valShapedType = valType.dyn_cast(); + if (!valShapedType) return nullptr; std::optional index = getConstantIndex(); if (!index.has_value()) @@ -1087,7 +1087,7 @@ if (index.value() >= valShapedType.getRank()) return nullptr; auto extent = valShapedType.getDimSize(*index); - if (ShapedType::isDynamic(extent)) + if (RankedShapedType::isDynamic(extent)) return 
nullptr; return IntegerAttr::get(IndexType::get(getContext()), extent); } @@ -1106,8 +1106,8 @@ } LogicalResult mlir::shape::DimOp::verify() { - auto st = getValue().getType().cast(); - if (!st.hasRank()) + auto st = getValue().getType().dyn_cast(); + if (!st) return success(); if (auto index = getConstantIndex()) { if (*index < 0 || *index >= st.getRank()) @@ -1439,9 +1439,9 @@ } else if (isExtentTensorType(l)) { auto rank1 = l.cast().getShape()[0]; auto rank2 = r.cast().getShape()[0]; - if (ShapedType::isDynamic(rank1)) + if (RankedShapedType::isDynamic(rank1)) acc = l; - else if (ShapedType::isDynamic(rank2)) + else if (RankedShapedType::isDynamic(rank2)) acc = r; else if (rank1 != rank2) return emitOptionalError(location, "unequal shape cardinality"); @@ -1696,7 +1696,7 @@ //===----------------------------------------------------------------------===// OpFoldResult ShapeOfOp::fold(FoldAdaptor) { - auto type = getOperand().getType().dyn_cast(); + auto type = getOperand().getType().dyn_cast(); if (!type || !type.hasStaticShape()) return nullptr; Builder builder(getContext()); @@ -1766,9 +1766,8 @@ if (operands[0].getType().isa()) inferredReturnTypes.assign({ShapeType::get(context)}); else { - auto shapedTy = operands[0].getType().cast(); - int64_t rank = - shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic; + auto shapedTy = operands[0].getType().dyn_cast(); + int64_t rank = shapedTy ? 
shapedTy.getRank() : RankedShapedType::kDynamic; Type indexTy = IndexType::get(context); Type extentTensorTy = RankedTensorType::get({rank}, indexTy); inferredReturnTypes.assign({extentTensorTy}); diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td --- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -11,7 +11,8 @@ }]>>; def HasStaticShape : Constraint().hasStaticShape() + $0.getType().isa() + && $0.getType().cast().hasStaticShape() }]>>; // Helper that takes the first element of a range. diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -666,8 +666,8 @@ const auto valuesNSE = valuesTp.getShape()[0]; const auto coordsNSE = coordinatesTp.getShape()[0]; - if (!ShapedType::isDynamic(valuesNSE) && !ShapedType::isDynamic(coordsNSE) && - valuesNSE != coordsNSE) + if (!RankedShapedType::isDynamic(valuesNSE) && + !RankedShapedType::isDynamic(coordsNSE) && valuesNSE != coordsNSE) return op->emitError("values/coordinates number-of-elements don't match"); // NOTE: We use `getLvlRank` because the `coordinatesTp` is for @@ -675,7 +675,7 @@ const DynSize coordsRank = coordinatesTp.getShape()[1]; const Level tensorRank = tensorTp.getLvlRank(); // FIXME: replace the `operator!=` with our backported `safelyNE`. - if (!ShapedType::isDynamic(coordsRank) && + if (!RankedShapedType::isDynamic(coordsRank) && coordsRank != static_cast(tensorRank)) return op->emitError("input/output level-ranks don't match"); @@ -712,7 +712,7 @@ // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 
for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) - if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) + if (shape1[d] != shape2[d] && shape2[d] != RankedShapedType::kDynamic) return emitError("unexpected conversion mismatch in dimension ") << d; return success(); } @@ -917,7 +917,7 @@ for (Dimension d = 0; d < dimRank; d++) { const DynSize dstSh = dstTp.getDimShape()[d]; if (d == concatDim) { - if (!ShapedType::isDynamic(dstSh)) { + if (!RankedShapedType::isDynamic(dstSh)) { // If we reach here, then all inputs have static shapes. So we // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` // to avoid redundant assertions in the loop. @@ -935,7 +935,7 @@ DynSize prev = dstSh; for (const auto src : getInputs()) { const auto sh = getSparseTensorType(src).getDimShape()[d]; - if (!ShapedType::isDynamic(prev) && sh != prev) + if (!RankedShapedType::isDynamic(prev) && sh != prev) return emitError("All dimensions (expect for the concatenating one) " "should be equal."); prev = sh; @@ -1078,7 +1078,7 @@ const DynSize sh = mtp.getShape()[0]; // We can't check the size of dynamic dimension at compile-time, but all // xs and ys should have a dimension not less than n at runtime. - if (n && !ShapedType::isDynamic(sh) && sh < n.value()) + if (n && !RankedShapedType::isDynamic(sh) && sh < n.value()) return emitError(llvm::formatv("xs and ys need to have a dimension >= n" ": {0} < {1}", sh, n.value())); @@ -1116,7 +1116,7 @@ // of this lambda. 
const auto checkDim = [&](Value v, StaticSize minSize, const char *message) { const DynSize sh = getMemRefType(v).getShape()[0]; - if (!ShapedType::isDynamic(sh) && sh < minSize) + if (!RankedShapedType::isDynamic(sh) && sh < minSize) emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -185,10 +185,10 @@ /// a memref with a standard unit stride zero offset layout is returned. inline MemRefType get1DMemRefType(Type etp, bool withLayout) { auto layout = withLayout ? StridedLayoutAttr::StridedLayoutAttr::get( - etp.getContext(), ShapedType::kDynamic, - {ShapedType::kDynamic}) + etp.getContext(), RankedShapedType::kDynamic, + {RankedShapedType::kDynamic}) : StridedLayoutAttr(); - return MemRefType::get(ShapedType::kDynamic, etp, layout); + return MemRefType::get(RankedShapedType::kDynamic, etp, layout); } /// Scans to top of generated loop. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -275,12 +275,12 @@ // expanded from the i-th dimension in srcShape. // For example, if srcDim = 8, then the expanded shape could be <2x?x2>, // but not <2x?x?>. - if (staticDstShape[j] == ShapedType::kDynamic) { + if (staticDstShape[j] == RankedShapedType::kDynamic) { // The expanded dimension has dynamic size. We compute the dimension // by dividing srcDim by the product of the static dimensions. 
StaticSize product = 1; for (unsigned k = start; k < start + map.size(); k++) { - if (staticDstShape[k] != ShapedType::kDynamic) { + if (staticDstShape[k] != RankedShapedType::kDynamic) { product *= staticDstShape[k]; } } @@ -391,7 +391,7 @@ Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) { - auto memTp = MemRefType::get({ShapedType::kDynamic}, tp); + auto memTp = MemRefType::get({RankedShapedType::kDynamic}, tp); return builder.create(loc, memTp, ValueRange{sz}); } @@ -420,7 +420,7 @@ auto memTp = MemRefType::get(shape, elemTp); SmallVector dynamicSizes; for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) { - if (shape[i] == ShapedType::kDynamic) + if (shape[i] == RankedShapedType::kDynamic) dynamicSizes.push_back(sizes[i]); } Value mem = builder.create(loc, memTp, dynamicSizes); @@ -584,7 +584,8 @@ const auto memTp = mem.getType().cast(); assert(memTp.getRank() == 1); const DynSize memSh = memTp.getDimSize(0); - assert(ShapedType::isDynamic(memSh) || memSh >= static_cast(size)); + assert(RankedShapedType::isDynamic(memSh) || + memSh >= static_cast(size)); assert(offsetIdx == 0 || offsetIdx < size); #endif // NDEBUG SmallVector vs; @@ -606,7 +607,8 @@ const auto memTp = mem.getType().cast(); assert(memTp.getRank() == 1); const DynSize memSh = memTp.getDimSize(0); - assert(ShapedType::isDynamic(memSh) || memSh >= static_cast(vsize)); + assert(RankedShapedType::isDynamic(memSh) || + memSh >= static_cast(vsize)); assert(offsetIdx == 0 || offsetIdx < vsize); #endif // NDEBUG for (const auto &v : llvm::enumerate(vs)) { @@ -639,7 +641,7 @@ const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp); lvlCoords = builder.create(loc, lvlSizesTp, lvlCoords); // Finally, create the ReshapeOp. 
- const SmallVector resShape(lvlRank, ShapedType::kDynamic); + const SmallVector resShape(lvlRank, RankedShapedType::kDynamic); const Type elemTp = getMemRefType(valuesBuffer).getElementType(); const auto resTp = MemRefType::get(resShape, elemTp); return builder.create(loc, resTp, valuesBuffer, lvlCoords); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -1217,7 +1217,7 @@ auto mtp = getMemRefType(v); if (!mtp.isDynamicDim(0)) { auto newMtp = - MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); + MemRefType::get({RankedShapedType::kDynamic}, mtp.getElementType()); v = rewriter.create(loc, newMtp, v); } operands.push_back(v); @@ -1313,7 +1313,7 @@ Value c2 = constantIndex(rewriter, loc, 2); auto bufferType = - MemRefType::get({ShapedType::kDynamic}, value.getType()); + MemRefType::get({RankedShapedType::kDynamic}, value.getType()); scf::IfOp ifOp = rewriter.create(loc, bufferType, cond, /*else=*/true); // True branch. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -207,7 +207,7 @@ dimSizes.reserve(dimRank); unsigned i = 0; // cumulative index into `dynSizes`. for (const DynSize sh : stt.getDimShape()) - dimSizes.push_back(ShapedType::isDynamic(sh) + dimSizes.push_back(RankedShapedType::isDynamic(sh) ? dynSizes[i++] : constantIndex(builder, loc, sh)); @@ -871,7 +871,7 @@ const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim); // Generate a memref for `sz` elements of type `t`. 
const auto genAlloc = [&](Type t) { - const auto memTp = MemRefType::get({ShapedType::kDynamic}, t); + const auto memTp = MemRefType::get({RankedShapedType::kDynamic}, t); return rewriter.create(loc, memTp, ValueRange{sz}); }; // Allocate temporary buffers for values/filled-switch and added. @@ -1310,7 +1310,7 @@ for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { // FIXME: dim/lvl confusion! const auto sh = stt.getDimShape()[lvl]; - assert(!ShapedType::isDynamic(sh)); + assert(!RankedShapedType::isDynamic(sh)); desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh)); if (lvl == 0) desc.setPosMemSize(rewriter, loc, lvl, constantIndex(rewriter, loc, 2)); @@ -1420,7 +1420,7 @@ createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {}, {reader, dimSizes}, EmitCInterface::On); for (const auto &d : llvm::enumerate(dstTp.getDimShape())) - if (ShapedType::isDynamic(d.value())) + if (RankedShapedType::isDynamic(d.value())) dynSizes.push_back(rewriter.create( loc, dimSizes, constantIndex(rewriter, loc, d.index()))); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -145,7 +145,7 @@ out.clear(); out.reserve(stt.getDimRank()); for (const DynSize sh : stt.getDimShape()) { - const auto s = ShapedType::isDynamic(sh) ? 0 : sh; + const auto s = RankedShapedType::isDynamic(sh) ? 0 : sh; out.push_back(constantIndex(builder, loc, s)); } } @@ -196,7 +196,7 @@ /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack, /// this buffer must be explicitly deallocated by client. 
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) { - auto memTp = MemRefType::get({ShapedType::kDynamic}, tp); + auto memTp = MemRefType::get({RankedShapedType::kDynamic}, tp); return rewriter.create(loc, memTp, ValueRange{sz}); } @@ -737,7 +737,7 @@ Value dimSizesBuffer; if (stt.hasDynamicDimShape()) { Type indexTp = rewriter.getIndexType(); - auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp); + auto memTp = MemRefType::get({RankedShapedType::kDynamic}, indexTp); dimSizesBuffer = createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp, reader, EmitCInterface::On) @@ -1114,7 +1114,7 @@ Location loc = op.getLoc(); // Query values array size for the actually stored values size. Type eltType = op.getTensor().getType().cast().getElementType(); - auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType); + auto resTp = MemRefType::get({RankedShapedType::kDynamic}, eltType); Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands()); rewriter.replaceOpWithNewOp(op, values, constantIndex(rewriter, loc, 0)); @@ -1308,9 +1308,9 @@ dstTensor = dst; // Get the values buffer for the sparse tensor and reshape it to the // corresponding dense tensor shape. - dst = genValuesCall(rewriter, loc, - MemRefType::get({ShapedType::kDynamic}, elemTp), - {dst}); + dst = genValuesCall( + rewriter, loc, + MemRefType::get({RankedShapedType::kDynamic}, elemTp), {dst}); // Pass the `dstDimCoords` buffer for `reshapeValuesToLevels` // to reuse for storing level-sizes (yes, "level-sizes"). 
// This is safe to do because `dstTp` is a dense-tensor type, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -116,10 +116,10 @@ /// Populates given sizes array from type (for static sizes) and from /// the tensor (for dynamic sizes). static void sizesForTensor(OpBuilder &builder, SmallVectorImpl &sizes, - Location loc, ShapedType stp, Value tensor) { + Location loc, RankedShapedType stp, Value tensor) { for (const auto &d : enumerate(stp.getShape())) { Value dim; - if (d.value() == ShapedType::kDynamic) + if (d.value() == RankedShapedType::kDynamic) dim = builder.create(loc, tensor, d.index()); else dim = constantIndex(builder, loc, d.value()); @@ -145,7 +145,7 @@ const SmallVectorImpl &sizes, SmallVectorImpl &dynSizes) { for (const auto &d : enumerate(tp.getShape())) { - if (d.value() == ShapedType::kDynamic) + if (d.value() == RankedShapedType::kDynamic) dynSizes.push_back(sizes[d.index()]); } } @@ -183,13 +183,13 @@ /// sizes) and from the source tensors (for dynamic sizes). static void concatSizesFromInputs(OpBuilder &builder, SmallVectorImpl &sizes, Location loc, - ShapedType dstTp, ValueRange srcs, + RankedShapedType dstTp, ValueRange srcs, unsigned dim) { auto dstShape = dstTp.getShape(); sizesFromSrc(builder, sizes, loc, srcs[0]); // Sum up on the `dim` if the dimension is dynamic. - if (dstShape[dim] != ShapedType::kDynamic) { + if (dstShape[dim] != RankedShapedType::kDynamic) { // Faithfully take the static size. 
sizes[dim] = constantIndex(builder, loc, dstShape[dim]); } else { @@ -375,7 +375,7 @@ genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape, op.getReassociationIndices()); for (auto [idx, shape] : llvm::enumerate(dstShape)) { - if (shape == ShapedType::kDynamic) + if (shape == RankedShapedType::kDynamic) dstDynSizes.push_back(dstSizes[idx]); } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp @@ -178,11 +178,14 @@ const Type metaDataType = StorageSpecifierType::get(stt.getEncoding()); // memref positions - const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType); + const Type posMemType = + MemRefType::get({RankedShapedType::kDynamic}, posType); // memref coordinates - const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType); + const Type crdMemType = + MemRefType::get({RankedShapedType::kDynamic}, crdType); // memref values - const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); + const Type valMemType = + MemRefType::get({RankedShapedType::kDynamic}, eltType); foreachFieldInSparseTensor( stt.getEncoding(), diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1241,7 +1241,7 @@ Value tensor = lhs->get(); Location loc = op.getLoc(); if (atStart) { - auto dynShape = {ShapedType::kDynamic}; + auto dynShape = {RankedShapedType::kDynamic}; Type etp = tensor.getType().cast().getElementType(); Type t1 = MemRefType::get(dynShape, etp); Type t2 = MemRefType::get(dynShape, builder.getI1Type()); diff --git 
a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -41,7 +41,7 @@ static OpFoldResult getCollapsedOutputDimFromInputShape( OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef dstStaticShape, ArrayRef reassociationMap) { - if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { + if (!RankedShapedType::isDynamic(dstStaticShape[dimIndex])) { return builder.getIndexAttr(dstStaticShape[dimIndex]); } AffineMap map = reassociationMap[dimIndex]; @@ -78,7 +78,7 @@ OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef dstStaticShape, ArrayRef reassociation, llvm::DenseMap &expandedDimToCollapsedDim) { - if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { + if (!RankedShapedType::isDynamic(dstStaticShape[dimIndex])) { return builder.getIndexAttr(dstStaticShape[dimIndex]); } unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex]; @@ -97,7 +97,7 @@ llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) { if (d.index() + startPos == static_cast(dimIndex)) continue; - assert(!ShapedType::isDynamic(d.value()) && + assert(!RankedShapedType::isDynamic(d.value()) && "single dimension cannot be expanded into multiple dynamic " "dimensions"); linearizedStaticDim *= d.value(); @@ -130,7 +130,8 @@ ArrayRef dstStaticShape, ArrayRef reassocation) { return dstStaticShape.size() > - static_cast(src.getType().cast().getRank()) + static_cast( + src.getType().cast().getRank()) ? 
getExpandedOutputShapeFromInputShape( builder, loc, src, dstStaticShape, reassocation) : getCollapsedOutputShapeFromInputShape( diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -64,6 +64,7 @@ FailureOr tensor::getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult) { auto tensorType = opResult.getType().dyn_cast(); + auto rankedTensorType = opResult.getType().dyn_cast(); assert(tensorType && "expected tensor type"); // If the op has a destination, it implements DestinationStyleOpInterface and @@ -78,7 +79,7 @@ // Compute sizes. SmallVector mixedSizes; - if (!tensorType.hasStaticShape()) { + if (!rankedTensorType || !rankedTensorType.hasStaticShape()) { // Dynamic shape: Query ReifyRankedShapedTypeOpInterface. ReifiedRankedShapedTypeDims reifiedShapes; if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes))) @@ -86,7 +87,7 @@ mixedSizes = reifiedShapes[opResult.getResultNumber()]; } else { // Static shape: Take static sizes directly. - for (int64_t sz : tensorType.getShape()) + for (int64_t sz : rankedTensorType.getShape()) mixedSizes.push_back(b.getIndexAttr(sz)); } @@ -180,8 +181,8 @@ // If cast is towards more static sizes along any dimension, don't fold. 
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) { - if (!ShapedType::isDynamic(std::get<0>(t)) && - ShapedType::isDynamic(std::get<1>(t))) + if (!RankedShapedType::isDynamic(std::get<0>(t)) && + RankedShapedType::isDynamic(std::get<1>(t))) return false; } @@ -281,29 +282,31 @@ static TensorType joinShapes(TensorType one, TensorType two) { assert(one.getElementType() == two.getElementType()); - if (!one.hasRank()) + auto oneRanked = one.dyn_cast(); + if (!oneRanked) return two; - if (!two.hasRank()) + auto twoRanked = two.dyn_cast(); + if (!twoRanked) return one; - int64_t rank = one.getRank(); - if (rank != two.getRank()) + int64_t rank = oneRanked.getRank(); + if (rank != twoRanked.getRank()) return {}; SmallVector join; join.reserve(rank); for (int64_t i = 0; i < rank; ++i) { - if (one.isDynamicDim(i)) { - join.push_back(two.getDimSize(i)); + if (oneRanked.isDynamicDim(i)) { + join.push_back(twoRanked.getDimSize(i)); continue; } - if (two.isDynamicDim(i)) { - join.push_back(one.getDimSize(i)); + if (twoRanked.isDynamicDim(i)) { + join.push_back(oneRanked.getDimSize(i)); continue; } - if (one.getDimSize(i) != two.getDimSize(i)) + if (oneRanked.getDimSize(i) != twoRanked.getDimSize(i)) return {}; - join.push_back(one.getDimSize(i)); + join.push_back(oneRanked.getDimSize(i)); } return RankedTensorType::get(join, one.getElementType()); } @@ -389,7 +392,7 @@ if (dimMask && dimMask->count(i)) continue; int64_t dim = rankedResultType.getShape()[dimIndex++]; - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) continue; sizes[i] = rewriter.getIndexAttr(dim); } @@ -473,12 +476,12 @@ fromElements.getResult().getType().cast(); // The case where the type encodes the size of the dimension is handled // above. - assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()])); + assert(RankedShapedType::isDynamic(resultType.getShape()[index.getInt()])); // Find the operand of the fromElements that corresponds to this index. 
auto dynExtents = fromElements.getDynamicExtents().begin(); for (auto dim : resultType.getShape().take_front(index.getInt())) - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) dynExtents++; return Value{*dynExtents}; @@ -533,7 +536,7 @@ ArrayRef staticShape, Type elementType, Attribute encoding) { assert(all_of(staticShape, - [](int64_t sz) { return !ShapedType::isDynamic(sz); }) && + [](int64_t sz) { return !RankedShapedType::isDynamic(sz); }) && "expected only static sizes"); build(builder, result, staticShape, elementType, ValueRange{}, encoding); } @@ -707,7 +710,7 @@ // Case 1: The empty tensor dim is static. Check that the tensor cast // result dim matches. if (auto attr = currDim.dyn_cast()) { - if (ShapedType::isDynamic(newDim) || + if (RankedShapedType::isDynamic(newDim) || newDim != attr.cast().getInt()) { // Something is off, the cast result shape cannot be more dynamic // than the empty tensor result shape (enforced by @@ -722,7 +725,7 @@ // Case 2 : The tensor cast shape is static, but empty tensor result // shape is dynamic. - if (!ShapedType::isDynamic(newDim)) { + if (!RankedShapedType::isDynamic(newDim)) { newMixedSizes.push_back(rewriter.getIndexAttr(newDim)); continue; } @@ -1129,13 +1132,13 @@ auto operandsIt = tensorFromElements.getDynamicExtents().begin(); for (int64_t dim : resultType.getShape()) { - if (!ShapedType::isDynamic(dim)) { + if (!RankedShapedType::isDynamic(dim)) { newShape.push_back(dim); continue; } APInt index; if (!matchPattern(*operandsIt, m_ConstantInt(&index))) { - newShape.push_back(ShapedType::kDynamic); + newShape.push_back(RankedShapedType::kDynamic); newOperands.push_back(*operandsIt++); continue; } @@ -1210,8 +1213,8 @@ OpFoldResult RankOp::fold(FoldAdaptor adaptor) { // Constant fold rank when the rank of the operand is known. 
auto type = getOperand().getType(); - auto shapedType = type.dyn_cast(); - if (shapedType && shapedType.hasRank()) + auto shapedType = type.dyn_cast(); + if (shapedType) return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); return IntegerAttr(); } @@ -1225,7 +1228,7 @@ setNameFn(getResult(), "reshape"); } -static int64_t getNumElements(ShapedType type) { +static int64_t getNumElements(RankedShapedType type) { int64_t numElements = 1; for (auto dim : type.getShape()) numElements *= dim; @@ -1252,7 +1255,7 @@ return emitOpError("source and destination tensor should have the " "same number of elements"); } - if (ShapedType::isDynamic(shapeSize)) + if (RankedShapedType::isDynamic(shapeSize)) return emitOpError("cannot use shape operand with dynamic length to " "reshape to statically-ranked tensor type"); if (shapeSize != resultRankedType.getRank()) @@ -1325,8 +1328,8 @@ unsigned dim = m.getNumResults(); auto band = shape.slice(currentDim, dim); int64_t size = 1; - if (llvm::is_contained(band, ShapedType::kDynamic)) - size = ShapedType::kDynamic; + if (llvm::is_contained(band, RankedShapedType::kDynamic)) + size = RankedShapedType::kDynamic; else for (unsigned d = 0; d < dim; ++d) size *= shape[currentDim + d]; @@ -1449,9 +1452,9 @@ if (!fromElements) return failure(); - auto shapedTy = reshapeOp.getType().template cast(); + auto shapedTy = reshapeOp.getType().template dyn_cast(); - if (!shapedTy.hasStaticShape()) + if (!shapedTy || !shapedTy.hasStaticShape()) return failure(); rewriter.replaceOpWithNewOp(reshapeOp, reshapeOp.getType(), @@ -1928,8 +1931,8 @@ return failure(); // Dynamic result shape is not supported. 
- auto sourceType = op.getSource().getType().cast(); - auto resultType = op.getResult().getType().cast(); + auto sourceType = op.getSource().getType().cast(); + auto resultType = op.getResult().getType().cast(); if (!sourceType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); @@ -1943,13 +1946,13 @@ // Check if there are any dynamic parts, which are not supported. auto offsets = op.getStaticOffsets(); - if (llvm::is_contained(offsets, ShapedType::kDynamic)) + if (llvm::is_contained(offsets, RankedShapedType::kDynamic)) return failure(); auto sizes = op.getStaticSizes(); - if (llvm::is_contained(sizes, ShapedType::kDynamic)) + if (llvm::is_contained(sizes, RankedShapedType::kDynamic)) return failure(); auto strides = op.getStaticStrides(); - if (llvm::is_contained(strides, ShapedType::kDynamic)) + if (llvm::is_contained(strides, RankedShapedType::kDynamic)) return failure(); // Compute the stride for each dimension. @@ -2035,7 +2038,7 @@ // static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, - ShapedType shapedType) { + RankedShapedType shapedType) { OpBuilder b(op.getContext()); for (OpFoldResult ofr : op.getMixedOffsets()) if (getConstantIntValue(ofr) != static_cast(0)) @@ -2069,8 +2072,8 @@ OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) { if (auto splat = adaptor.getSource().dyn_cast_or_null()) { - auto resultType = getResult().getType().cast(); - if (resultType.hasStaticShape()) + auto resultType = getResult().getType().dyn_cast(); + if (resultType && resultType.hasStaticShape()) return splat.resizeSplat(resultType); } if (getSourceType() == getType() && @@ -2537,14 +2540,15 @@ SmallVector inferredShape; for (auto i : llvm::seq(0, rank)) { - if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic || - staticHigh[i] == ShapedType::kDynamic) { - inferredShape.push_back(resultShape.empty() ? 
ShapedType::kDynamic + if (sourceType.isDynamicDim(i) || + staticLow[i] == RankedShapedType::kDynamic || + staticHigh[i] == RankedShapedType::kDynamic) { + inferredShape.push_back(resultShape.empty() ? RankedShapedType::kDynamic : resultShape[i]); } else { int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i]; assert((resultShape.empty() || size == resultShape[i] || - resultShape[i] == ShapedType::kDynamic) && + resultShape[i] == RankedShapedType::kDynamic) && "mismatch between inferred shape and result shape"); inferredShape.push_back(size); } @@ -2571,7 +2575,7 @@ ArrayRef attrs) { auto sourceType = source.getType().cast(); unsigned rank = sourceType.getRank(); - SmallVector staticVector(rank, ShapedType::kDynamic); + SmallVector staticVector(rank, RankedShapedType::kDynamic); build(b, result, resultType, source, staticVector, staticVector, low, high, nofold, attrs); } @@ -2838,7 +2842,7 @@ continue; OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()]; int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()]; - assert(!ShapedType::isDynamic(sourceSize) && + assert(!RankedShapedType::isDynamic(sourceSize) && "expected padded dimension to have a static size"); if (getConstantIntValue(sliceSize) != sourceSize) { return rewriter.notifyMatchFailure( @@ -2896,7 +2900,7 @@ for (auto operand : padTensorOp.getLow()) { APSInt intOp; if (!matchPattern(operand, m_ConstantInt(&intOp))) { - constOperandsLow.push_back(ShapedType::kDynamic); + constOperandsLow.push_back(RankedShapedType::kDynamic); continue; } constOperandsLow.push_back(intOp.getExtValue()); @@ -2905,7 +2909,7 @@ for (auto operand : padTensorOp.getHigh()) { APSInt intOp; if (!matchPattern(operand, m_ConstantInt(&intOp))) { - constOperandsHigh.push_back(ShapedType::kDynamic); + constOperandsHigh.push_back(RankedShapedType::kDynamic); continue; } constOperandsHigh.push_back(intOp.getExtValue()); @@ -2923,9 +2927,9 @@ auto lowCount = 0; auto highCount = 0; for (size_t 
i = 0; i < inputRank; i++) { - if (constLow[i] == ShapedType::kDynamic) + if (constLow[i] == RankedShapedType::kDynamic) constLow[i] = constOperandsLow[lowCount++]; - if (constHigh[i] == ShapedType::kDynamic) + if (constHigh[i] == RankedShapedType::kDynamic) constHigh[i] = constOperandsHigh[highCount++]; } @@ -2935,12 +2939,12 @@ // Calculate the output sizes with the static information. SmallVector newOutDims; for (size_t i = 0; i < inputRank; i++) { - if (outputDims[i] == ShapedType::kDynamic) { + if (outputDims[i] == RankedShapedType::kDynamic) { newOutDims.push_back( - (staticLow[i] == ShapedType::kDynamic || - staticHigh[i] == ShapedType::kDynamic || - inputDims[i] == ShapedType::kDynamic - ? ShapedType::kDynamic + (staticLow[i] == RankedShapedType::kDynamic || + staticHigh[i] == RankedShapedType::kDynamic || + inputDims[i] == RankedShapedType::kDynamic + ? RankedShapedType::kDynamic : inputDims[i] + staticLow[i] + staticHigh[i])); } else { newOutDims.push_back(outputDims[i]); @@ -2948,8 +2952,9 @@ } if (SmallVector(outputDims) == newOutDims || - llvm::all_of(newOutDims, - [&](int64_t x) { return x == ShapedType::kDynamic; })) + llvm::all_of(newOutDims, [&](int64_t x) { + return x == RankedShapedType::kDynamic; + })) return failure(); // Rewrite the op using the new static type. 
@@ -3196,7 +3201,7 @@ "applies to only pack or unpack operations"); int64_t destRank = op.getDestRank(); reifiedReturnShapes.resize(1, SmallVector(destRank)); - ShapedType resultType = op.getResult().getType().template cast(); + auto resultType = op.getResult().getType().template cast(); for (auto dim : llvm::seq(0, destRank)) { if (resultType.isDynamicDim(dim)) { reifiedReturnShapes[0][dim] = @@ -3232,7 +3237,7 @@ SmallVector mixedInnerTiles; unsigned dynamicValIndex = 0; for (int64_t staticTile : op.getStaticInnerTiles()) { - if (!ShapedType::isDynamic(staticTile)) + if (!RankedShapedType::isDynamic(staticTile)) mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile)); else mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]); @@ -3280,8 +3285,8 @@ llvm::zip(sourceShape, limitShape), [](std::tuple it) { int64_t sourceExtent = std::get<0>(it); int64_t limit = std::get<1>(it); - return ShapedType::isDynamic(sourceExtent) || - ShapedType::isDynamic(limit) || sourceExtent <= limit; + return RankedShapedType::isDynamic(sourceExtent) || + RankedShapedType::isDynamic(limit) || sourceExtent <= limit; }); } @@ -3327,9 +3332,9 @@ "tiling factors must equal the number of dimensions to tile"); } - ShapedType packedType = (std::is_same::value) - ? packOrUnPack.getDestType() - : packOrUnPack.getSourceType(); + RankedShapedType packedType = (std::is_same::value) + ? packOrUnPack.getDestType() + : packOrUnPack.getSourceType(); size_t packedRank = packedType.getRank(); // Require output rank to match input rank + number of blocking factors. if (unpackedRank + mixedTiles.size() != packedRank) { @@ -3357,9 +3362,9 @@ if (!constTileSize) { // If specified tile size is dynamic, output shape should // be dynamic too. - return ShapedType::isDynamic(shape); + return RankedShapedType::isDynamic(shape); } - if (ShapedType::isDynamic(shape)) { + if (RankedShapedType::isDynamic(shape)) { // For the shape being dynamic when tile size is // specified, return true. 
In canonical form a constant // tile size should lead to constant shape of the tiled @@ -3474,7 +3479,7 @@ ArrayRef innerDimsPos, ArrayRef innerTiles) { for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) { - if (ShapedType::isDynamic(inputShape[pos])) + if (RankedShapedType::isDynamic(inputShape[pos])) continue; std::optional constantTile = getConstantIntValue(tileSize); if (!constantTile) @@ -3517,9 +3522,10 @@ for (auto o : ofrs) { // Have to do this first, as getConstantIntValue special-cases constants. if (o.dyn_cast()) - result.push_back(ShapedType::kDynamic); + result.push_back(RankedShapedType::kDynamic); else - result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic)); + result.push_back( + getConstantIntValue(o).value_or(RankedShapedType::kDynamic)); } return result; } @@ -3532,10 +3538,10 @@ ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { SmallVector resultShape = llvm::to_vector(sourceShape); for (auto tiledDim : llvm::enumerate(innerDimsPos)) { - if (ShapedType::isDynamic(resultShape[tiledDim.value()])) + if (RankedShapedType::isDynamic(resultShape[tiledDim.value()])) continue; - if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) { - resultShape[tiledDim.value()] = ShapedType::kDynamic; + if (RankedShapedType::isDynamic(innerTileSizes[tiledDim.index()])) { + resultShape[tiledDim.value()] = RankedShapedType::kDynamic; continue; } resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()], @@ -3579,7 +3585,7 @@ // use dispatchIndexOpFoldResults on the result, and rely on exact number of // dynamic dims returned by that. 
for (unsigned i = 0; i < resultDims.size(); ++i) { - if (!ShapedType::isDynamic(resultTypeShape[i])) + if (!RankedShapedType::isDynamic(resultTypeShape[i])) continue; resultDims[i] = getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]); @@ -3612,7 +3618,7 @@ SmallVector mixedSizes; for (auto [index, value] : llvm::enumerate(source.getType().cast().getShape())) { - if (ShapedType::isDynamic(value)) + if (RankedShapedType::isDynamic(value)) mixedSizes.push_back(b.create(loc, source, index).getResult()); else mixedSizes.push_back(b.getIndexAttr(value)); @@ -3648,14 +3654,14 @@ bool areTilesAndTiledDimsAllConstant(OpTy op) { static_assert(llvm::is_one_of::value, "applies to only pack or unpack operations"); - ShapedType packedType = (std::is_same::value) - ? op.getDestType() - : op.getSourceType(); + RankedShapedType packedType = (std::is_same::value) + ? op.getDestType() + : op.getSourceType(); SmallVector mixedTiles = op.getMixedTiles(); for (auto [dimDest, tile] : llvm::zip( packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) { std::optional constTileSize = getConstantIntValue(tile); - if (!constTileSize || ShapedType::isDynamic(dimDest)) + if (!constTileSize || RankedShapedType::isDynamic(dimDest)) return false; } return true; diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -28,7 +28,7 @@ SmallVector high(type.getRank(), zero); for (const auto &en : enumerate(type.getShape())) { // Pad only the static dimensions of the result tensor type. - if (ShapedType::isDynamic(en.value())) + if (RankedShapedType::isDynamic(en.value())) continue; // Compute the padding width. 
AffineExpr d0; @@ -46,7 +46,7 @@ auto tensorTy = rankedTensor.getType().cast(); SmallVector dynamicDims; for (const auto &en : llvm::enumerate(tensorTy.getShape())) { - if (en.value() == ShapedType::kDynamic) + if (en.value() == RankedShapedType::kDynamic) dynamicDims.push_back( b.create(loc, rankedTensor, en.index())); } @@ -62,7 +62,7 @@ auto shape = tensorTy.getShape(); if (dim >= static_cast(shape.size())) return failure(); - if (ShapedType::isDynamic(shape[dim])) + if (RankedShapedType::isDynamic(shape[dim])) return OpFoldResult(b.createOrFold(loc, rankedTensor, dim)); return OpFoldResult(b.getIndexAttr(shape[dim])); } @@ -72,7 +72,7 @@ auto tensorTy = rankedTensor.getType().cast(); SmallVector dims; for (const auto &en : llvm::enumerate(tensorTy.getShape())) { - if (ShapedType::isDynamic(en.value())) { + if (RankedShapedType::isDynamic(en.value())) { dims.push_back( b.createOrFold(loc, rankedTensor, en.index())); } else { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -200,8 +200,8 @@ } auto input = op.getInput1(); - auto inputTy = input.getType().cast(); - if (!inputTy.hasRank()) + auto inputTy = input.getType().dyn_cast(); + if (!inputTy) return rewriter.notifyMatchFailure(op, "Unranked input."); int64_t numDynDims = 0; @@ -300,8 +300,8 @@ PatternRewriter &rewriter) const override { Value input = op.getInput(); Value output = op.getOutput(); - ShapedType inputType = input.getType().cast(); - ShapedType outputType = output.getType().cast(); + auto inputType = input.getType().cast(); + auto outputType = output.getType().cast(); if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) { return failure(); @@ -827,8 +827,8 @@ #define REDUCE_FOLDER(OP) \ OpFoldResult OP::fold(FoldAdaptor adaptor) { \ - ShapedType inputTy = getInput().getType().cast(); \ - if 
(!inputTy.hasRank()) \ + auto inputTy = getInput().getType().dyn_cast(); \ + if (!inputTy) \ return {}; \ if (inputTy.getDimSize(getAxis()) == 1) \ return getInput(); \ @@ -906,14 +906,14 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { auto operand = getInput(); - auto operandTy = operand.getType().cast(); + auto operandTy = operand.getType().dyn_cast(); auto axis = getAxis(); auto operandAttr = adaptor.getInput().dyn_cast_or_null(); if (operandAttr) return operandAttr; // If the dim-length is 1, tosa.reverse is a no-op. - if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1) + if (operandTy && operandTy.getDimSize(axis) == 1) return operand; return {}; @@ -974,8 +974,8 @@ } OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { - auto inputTy = getInput1().getType().cast(); - auto resultTy = getType().cast(); + auto inputTy = getInput1().getType().cast(); + auto resultTy = getType().cast(); // Transposing splat values just means reshaping. if (auto input = adaptor.getInput1().dyn_cast_or_null()) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -397,14 +397,14 @@ return failure(); llvm::SmallVector outputShape; - outputShape.resize(3, ShapedType::kDynamic); + outputShape.resize(3, RankedShapedType::kDynamic); outputShape[0] = inputShape.getDimSize(0); outputShape[1] = inputShape.getDimSize(1); int64_t inWidth = inputShape.getDimSize(2); // Note that we can support this calculation symbolically // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1] - if (inWidth != ShapedType::kDynamic) + if (inWidth != RankedShapedType::kDynamic) outputShape[2] = inWidth / 2 + 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); @@ -438,13 +438,13 @@ // Copy the Operand's rank. 
if (!hasRankedInput) - outputShape.resize(operandShape.getRank(), ShapedType::kDynamic); + outputShape.resize(operandShape.getRank(), RankedShapedType::kDynamic); // Copy shapes until the dim is non-dynamic. for (int i = 0, s = operandShape.getRank(); i < s; i++) { if (i == axis || operandShape.isDynamicDim(i)) continue; - if (outputShape[i] == ShapedType::kDynamic) + if (outputShape[i] == RankedShapedType::kDynamic) outputShape[i] = operandShape.getDimSize(i); if (outputShape[i] != operandShape.getDimSize(i)) return emitOptionalError(location, @@ -469,7 +469,7 @@ // We need to know the length of the concatenation axis of all inputs to // determine the dimension size of the output shape. if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) { - concatDimSize = ShapedType::kDynamic; + concatDimSize = RankedShapedType::kDynamic; break; } @@ -513,7 +513,7 @@ // All shapes are dynamic. SmallVector outShape; - outShape.resize(2, ShapedType::kDynamic); + outShape.resize(2, RankedShapedType::kDynamic); if (inputShape.hasRank()) { outShape[0] = inputShape.getDimSize(0); @@ -524,8 +524,9 @@ } if (biasShape.hasRank()) { - outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0) - : outShape[1]; + outShape[1] = outShape[1] == RankedShapedType::kDynamic + ? biasShape.getDimSize(0) + : outShape[1]; } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); @@ -543,7 +544,7 @@ // All shapes are dynamic. SmallVector outShape; - outShape.resize(3, ShapedType::kDynamic); + outShape.resize(3, RankedShapedType::kDynamic); if (lhsShape.hasRank()) { outShape[0] = lhsShape.getDimSize(0); @@ -551,8 +552,9 @@ } if (rhsShape.hasRank()) { - outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0) - : outShape[0]; + outShape[0] = outShape[0] == RankedShapedType::kDynamic + ? 
rhsShape.getDimSize(0) + : outShape[0]; outShape[2] = rhsShape.getDimSize(2); } @@ -583,7 +585,7 @@ return success(); } - outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic); + outputShape.resize(paddingShape.getDimSize(0), RankedShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -591,7 +593,7 @@ DenseIntElementsAttr paddings; // If the paddings value is not a constant, all dimensions must be dynamic. if (!matchPattern(operands[1], m_Constant(&paddings))) { - outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); + outputShape.resize(inputShape.getRank(), RankedShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -604,7 +606,7 @@ outputShape.reserve(inputShape.getRank()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { if (inputShape.isDynamicDim(i)) { - outputShape.push_back(ShapedType::kDynamic); + outputShape.push_back(RankedShapedType::kDynamic); continue; } @@ -618,7 +620,7 @@ static SmallVector convertToMlirShape(ArrayRef shape) { return to_vector(llvm::map_range(shape, [](int64_t dim) { - return dim == -1 ? ShapedType::kDynamic : dim; + return dim == -1 ? 
RankedShapedType::kDynamic : dim; })); } @@ -656,7 +658,7 @@ ShapeAdaptor inputShape = operands.getShape(0); SmallVector outputShape; if (!inputShape.hasRank()) { - outputShape.resize(multiples.size(), ShapedType::kDynamic); + outputShape.resize(multiples.size(), RankedShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -665,7 +667,7 @@ outputShape.reserve(multiples.size()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { int64_t dim = inputShape.getDimSize(i); - if (dim != ShapedType::kDynamic) + if (dim != RankedShapedType::kDynamic) dim *= multiples[i]; outputShape.push_back(dim); } @@ -696,14 +698,14 @@ int64_t numElements = inputShape.getNumElements(); int64_t staticMul = 1; for (auto val : newShapeValue) { - if (!ShapedType::isDynamic(val)) { + if (!RankedShapedType::isDynamic(val)) { staticMul *= val; } } // Determine the length of the dynamic dimension. for (auto &val : newShapeValue) { - if (ShapedType::isDynamic(val)) + if (RankedShapedType::isDynamic(val)) val = numElements / staticMul; } @@ -712,10 +714,11 @@ } mlir::LogicalResult tosa::ReshapeOp::verify() { - ShapedType inputType = getInput1().getType().cast(); - ShapedType outputType = getType().cast(); + auto inputType = getInput1().getType().dyn_cast(); + auto outputType = getType().dyn_cast(); - if (inputType.hasStaticShape() && outputType.hasStaticShape()) { + if (inputType && inputType.hasStaticShape() && outputType && + outputType.hasStaticShape()) { int64_t inputElementsNum = inputType.getNumElements(); int64_t outputElementsNum = outputType.getNumElements(); if (inputElementsNum != outputElementsNum) { @@ -765,7 +768,7 @@ // can determine the output rank. 
SmallVector outputShape; if (!inputShape.hasRank()) { - outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic); + outputShape.resize(permsShape.getDimSize(0), RankedShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -793,7 +796,7 @@ return success(); } - outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); + outputShape.resize(inputShape.getRank(), RankedShapedType::kDynamic); // If the permuations are a constant we can directly determine the output // shape. if (ShapeAdaptor permShape = operands.getValueAsShape(1)) { @@ -812,7 +815,7 @@ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; - outputShape.resize(3, ShapedType::kDynamic); + outputShape.resize(3, RankedShapedType::kDynamic); ShapeAdaptor valuesShape = operands.getShape(0); if (valuesShape.hasRank()) { @@ -822,9 +825,9 @@ ShapeAdaptor indicesShape = operands.getShape(1); if (indicesShape.hasRank()) { - if (outputShape[0] == ShapedType::kDynamic) + if (outputShape[0] == RankedShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); - if (outputShape[1] == ShapedType::kDynamic) + if (outputShape[1] == RankedShapedType::kDynamic) outputShape[1] = indicesShape.getDimSize(1); } @@ -838,7 +841,7 @@ SmallVectorImpl &inferredReturnShapes) { ResizeOpAdaptor adaptor(operands, attributes); llvm::SmallVector outputShape; - outputShape.resize(4, ShapedType::kDynamic); + outputShape.resize(4, RankedShapedType::kDynamic); ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); if (!inputShape.hasRank()) @@ -849,8 +852,8 @@ int64_t inputHeight = inputShape.getDimSize(1); int64_t inputWidth = inputShape.getDimSize(2); - if ((inputHeight == ShapedType::kDynamic) || - (inputWidth == ShapedType::kDynamic)) + if ((inputHeight == RankedShapedType::kDynamic) || + (inputWidth == RankedShapedType::kDynamic)) return failure(); 
llvm::ArrayRef scaleInt = adaptor.getScale(); @@ -877,7 +880,7 @@ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; - outputShape.resize(3, ShapedType::kDynamic); + outputShape.resize(3, RankedShapedType::kDynamic); ShapeAdaptor valuesInShape = operands.getShape(0); if (valuesInShape.hasRank()) { @@ -888,15 +891,15 @@ ShapeAdaptor indicesShape = operands.getShape(1); if (indicesShape.hasRank()) { - if (outputShape[0] == ShapedType::kDynamic) + if (outputShape[0] == RankedShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); } ShapeAdaptor inputShape = operands.getShape(2); if (inputShape.hasRank()) { - if (outputShape[0] == ShapedType::kDynamic) + if (outputShape[0] == RankedShapedType::kDynamic) outputShape[0] = inputShape.getDimSize(0); - if (outputShape[2] == ShapedType::kDynamic) + if (outputShape[2] == RankedShapedType::kDynamic) outputShape[2] = inputShape.getDimSize(2); } @@ -1018,7 +1021,7 @@ SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); llvm::SmallVector outputShape; - outputShape.resize(4, ShapedType::kDynamic); + outputShape.resize(4, RankedShapedType::kDynamic); // We only know the rank if the input type is unranked. 
if (!inputShape) { @@ -1037,12 +1040,12 @@ ArrayRef stride = attributes.get("stride").cast(); ArrayRef pad = attributes.get("pad").cast(); - if (!ShapedType::isDynamic(height)) { + if (!RankedShapedType::isDynamic(height)) { int64_t padded = height + pad[0] + pad[1] - kernel[0]; outputShape[1] = padded / stride[0] + 1; } - if (!ShapedType::isDynamic(width)) { + if (!RankedShapedType::isDynamic(width)) { int64_t padded = width + pad[2] + pad[3] - kernel[1]; outputShape[2] = padded / stride[1] + 1; } @@ -1055,13 +1058,13 @@ MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - llvm::SmallVector outputShape(4, ShapedType::kDynamic); + llvm::SmallVector outputShape(4, RankedShapedType::kDynamic); Conv2DOp::Adaptor adaptor(operands.getValues(), attributes); - int64_t inputWidth = ShapedType::kDynamic; - int64_t inputHeight = ShapedType::kDynamic; - int64_t weightWidth = ShapedType::kDynamic; - int64_t weightHeight = ShapedType::kDynamic; + int64_t inputWidth = RankedShapedType::kDynamic; + int64_t inputHeight = RankedShapedType::kDynamic; + int64_t weightWidth = RankedShapedType::kDynamic; + int64_t weightHeight = RankedShapedType::kDynamic; // Input shape describes input width/height and batch. @@ -1083,7 +1086,7 @@ // Bias shape can describe the output channels. ShapeAdaptor biasShape = operands.getShape(adaptor.getBias()); if (biasShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) + outputShape[3] = RankedShapedType::isDynamic(outputShape[3]) ? 
biasShape.getDimSize(0) : outputShape[3]; } @@ -1092,16 +1095,16 @@ llvm::ArrayRef stride = adaptor.getStride(); llvm::ArrayRef padding = adaptor.getPad(); - if (!ShapedType::isDynamic(inputHeight) && - !ShapedType::isDynamic(weightHeight)) { + if (!RankedShapedType::isDynamic(inputHeight) && + !RankedShapedType::isDynamic(weightHeight)) { int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(weightWidth)) { + if (!RankedShapedType::isDynamic(inputWidth) && + !RankedShapedType::isDynamic(weightWidth)) { int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t unstridedResult = inputSize - filterSize + 1; @@ -1118,16 +1121,16 @@ MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - llvm::SmallVector outputShape(5, ShapedType::kDynamic); + llvm::SmallVector outputShape(5, RankedShapedType::kDynamic); Conv3DOp::Adaptor adaptor(operands.getValues(), attributes); - int64_t inputWidth = ShapedType::kDynamic; - int64_t inputHeight = ShapedType::kDynamic; - int64_t inputDepth = ShapedType::kDynamic; + int64_t inputWidth = RankedShapedType::kDynamic; + int64_t inputHeight = RankedShapedType::kDynamic; + int64_t inputDepth = RankedShapedType::kDynamic; - int64_t weightWidth = ShapedType::kDynamic; - int64_t weightHeight = ShapedType::kDynamic; - int64_t weightDepth = ShapedType::kDynamic; + int64_t weightWidth = RankedShapedType::kDynamic; + int64_t weightHeight = RankedShapedType::kDynamic; + int64_t weightDepth = RankedShapedType::kDynamic; // Input shape describes input width/height and batch. 
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); @@ -1149,7 +1152,7 @@ // Bias shape can describe the output channels. ShapeAdaptor biasShape = operands.getShape(adaptor.getBias()); - if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) { + if (biasShape.hasRank() && RankedShapedType::isDynamic(outputShape[4])) { outputShape[4] = biasShape.getDimSize(0); } @@ -1157,24 +1160,24 @@ llvm::ArrayRef stride = adaptor.getStride(); llvm::ArrayRef pad = adaptor.getPad(); - if (!ShapedType::isDynamic(inputDepth) && - !ShapedType::isDynamic(weightDepth)) { + if (!RankedShapedType::isDynamic(inputDepth) && + !RankedShapedType::isDynamic(weightDepth)) { int32_t inputSize = inputDepth + pad[0] + pad[1]; int32_t filterSize = (weightDepth - 1) * dilation[0] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } - if (!ShapedType::isDynamic(inputHeight) && - !ShapedType::isDynamic(weightHeight)) { + if (!RankedShapedType::isDynamic(inputHeight) && + !RankedShapedType::isDynamic(weightHeight)) { int32_t inputSize = inputHeight + pad[2] + pad[3]; int32_t filterSize = (weightHeight - 1) * dilation[1] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(weightWidth)) { + if (!RankedShapedType::isDynamic(inputWidth) && + !RankedShapedType::isDynamic(weightWidth)) { int32_t inputSize = inputWidth + pad[4] + pad[5]; int32_t filterSize = (weightWidth - 1) * dilation[2] + 1; int32_t unstridedResult = inputSize - filterSize + 1; @@ -1205,16 +1208,16 @@ MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - llvm::SmallVector outputShape(4, ShapedType::kDynamic); + llvm::SmallVector outputShape(4, RankedShapedType::kDynamic); DepthwiseConv2DOp::Adaptor 
adaptor(operands.getValues(), attributes); - int64_t inputWidth = ShapedType::kDynamic; - int64_t inputHeight = ShapedType::kDynamic; - int64_t inputChannels = ShapedType::kDynamic; + int64_t inputWidth = RankedShapedType::kDynamic; + int64_t inputHeight = RankedShapedType::kDynamic; + int64_t inputChannels = RankedShapedType::kDynamic; - int64_t weightWidth = ShapedType::kDynamic; - int64_t weightHeight = ShapedType::kDynamic; - int64_t depthChannels = ShapedType::kDynamic; + int64_t weightWidth = RankedShapedType::kDynamic; + int64_t weightHeight = RankedShapedType::kDynamic; + int64_t depthChannels = RankedShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); @@ -1230,7 +1233,7 @@ if (weightShape.hasRank()) { weightHeight = weightShape.getDimSize(0); weightWidth = weightShape.getDimSize(1); - inputChannels = ShapedType::isDynamic(inputChannels) + inputChannels = RankedShapedType::isDynamic(inputChannels) ? weightShape.getDimSize(2) : inputChannels; depthChannels = weightShape.getDimSize(3); @@ -1238,15 +1241,15 @@ // If both inputChannels and depthChannels are available we can determine // the output channels. - if (!ShapedType::isDynamic(inputChannels) && - !ShapedType::isDynamic(depthChannels)) { + if (!RankedShapedType::isDynamic(inputChannels) && + !RankedShapedType::isDynamic(depthChannels)) { outputShape[3] = inputChannels * depthChannels; } // Bias shape can describe the output channels. ShapeAdaptor biasShape = operands.getShape(adaptor.getBias()); if (biasShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) + outputShape[3] = RankedShapedType::isDynamic(outputShape[3]) ? 
biasShape.getDimSize(0) : outputShape[3]; } @@ -1255,16 +1258,16 @@ llvm::ArrayRef padding = adaptor.getPad(); llvm::ArrayRef stride = adaptor.getStride(); - if (!ShapedType::isDynamic(inputHeight) && - !ShapedType::isDynamic(weightHeight)) { + if (!RankedShapedType::isDynamic(inputHeight) && + !RankedShapedType::isDynamic(weightHeight)) { int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(weightWidth)) { + if (!RankedShapedType::isDynamic(inputWidth) && + !RankedShapedType::isDynamic(weightWidth)) { int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t unstridedResult = inputSize - filterSize + 1; @@ -1286,15 +1289,15 @@ llvm::SmallVector outputShape = convertToMlirShape(adaptor.getOutShape()); - int64_t inputWidth = ShapedType::kDynamic; - int64_t inputHeight = ShapedType::kDynamic; - int64_t weightWidth = ShapedType::kDynamic; - int64_t weightHeight = ShapedType::kDynamic; + int64_t inputWidth = RankedShapedType::kDynamic; + int64_t inputHeight = RankedShapedType::kDynamic; + int64_t weightWidth = RankedShapedType::kDynamic; + int64_t weightHeight = RankedShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); if (inputShape.hasRank()) { - outputShape[0] = ShapedType::isDynamic(outputShape[0]) + outputShape[0] = RankedShapedType::isDynamic(outputShape[0]) ? inputShape.getDimSize(0) : outputShape[0]; inputHeight = inputShape.getDimSize(1); @@ -1304,7 +1307,7 @@ // Weight shapes describes the filter width/height and the output channels. 
ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter()); if (weightShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) + outputShape[3] = RankedShapedType::isDynamic(outputShape[3]) ? weightShape.getDimSize(0) : outputShape[3]; weightHeight = weightShape.getDimSize(1); @@ -1314,7 +1317,7 @@ // Bias shape can describe the output channels. ShapeAdaptor biasShape = operands.getShape(adaptor.getInput()); if (biasShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) + outputShape[3] = RankedShapedType::isDynamic(outputShape[3]) ? biasShape.getDimSize(0) : outputShape[3]; } @@ -1322,20 +1325,22 @@ llvm::ArrayRef padding = adaptor.getOutPad(); llvm::ArrayRef stride = adaptor.getStride(); - if (!ShapedType::isDynamic(inputHeight) && - !ShapedType::isDynamic(weightHeight)) { + if (!RankedShapedType::isDynamic(inputHeight) && + !RankedShapedType::isDynamic(weightHeight)) { int64_t calculateSize = (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight; - outputShape[1] = - ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1]; + outputShape[1] = RankedShapedType::isDynamic(outputShape[1]) + ? calculateSize + : outputShape[1]; } - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(weightWidth)) { + if (!RankedShapedType::isDynamic(inputWidth) && + !RankedShapedType::isDynamic(weightWidth)) { int64_t calculateSize = (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth; - outputShape[2] = - ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2]; + outputShape[2] = RankedShapedType::isDynamic(outputShape[2]) + ? 
calculateSize + : outputShape[2]; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -22,7 +22,7 @@ SmallVector convertFromMlirShape(ArrayRef shape) { return to_vector(llvm::map_range(shape, [](int64_t dim) { - return ShapedType::isDynamic(dim) ? -1 : dim; + return RankedShapedType::isDynamic(dim) ? -1 : dim; })); } @@ -34,16 +34,16 @@ PatternRewriter &rewriter) const override { Value input = op.getInput(); Value weight = op.getWeight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.getType().cast(); + auto inputType = input.getType().cast(); + auto weightType = weight.getType().dyn_cast(); + auto resultType = op.getType().cast(); auto numDynamic = - llvm::count_if(inputType.getShape(), ShapedType::isDynamic); + llvm::count_if(inputType.getShape(), RankedShapedType::isDynamic); if (numDynamic > 1) return rewriter.notifyMatchFailure( op, "at most one dim in input may be dynamic"); - if (!weightType.hasRank()) + if (!weightType) return rewriter.notifyMatchFailure(op, "unranked weight input"); if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; })) @@ -76,7 +76,7 @@ llvm::SmallVector newShape(inputType.getShape()); for (int i = 0, s = newShape.size(); i < s; ++i) { - if (newShape[i] != ShapedType::kDynamic) { + if (newShape[i] != RankedShapedType::kDynamic) { newShape[i] += pad[i * 2] + pad[i * 2 + 1]; } } @@ -98,7 +98,7 @@ // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. 
ArrayRef inputShape = inputType.getShape(); - int64_t combined = ShapedType::kDynamic; + int64_t combined = RankedShapedType::kDynamic; if (numDynamic == 0) combined = inputShape[0] * inputShape[1] * inputShape[2]; llvm::SmallVector revisedInputShape{combined, inputShape[3]}; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -28,9 +28,9 @@ PatternRewriter &rewriter) const override { Value input = op.getInput(); Value weight = op.getWeight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.getOutput().getType().cast(); + auto inputType = input.getType().cast(); + auto weightType = weight.getType().cast(); + auto resultType = op.getOutput().getType().cast(); if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && resultType.hasStaticShape())) { @@ -100,7 +100,7 @@ llvm::SmallVector newShape(inputType.getShape()); for (int i = 0, s = pad.size(); i < s; ++i) { - if (newShape[i / 2] != ShapedType::kDynamic) { + if (newShape[i / 2] != RankedShapedType::kDynamic) { newShape[i / 2] += pad[i]; } } @@ -133,7 +133,8 @@ .getResult(); // Reshape output to [N, H, W, C * M]. 
- auto outputShape = op.getOutput().getType().cast().getShape(); + auto outputShape = + op.getOutput().getType().cast().getShape(); auto outputShapeType = RankedTensorType::get( outputShape, input.getType().dyn_cast().getElementType()); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -82,10 +82,13 @@ Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + auto inputTy = input.getType().dyn_cast(); + auto weightTy = weight.getType().dyn_cast(); + auto biasTy = bias.getType().dyn_cast(); + auto resultTy = op->getResult(0).getType().dyn_cast(); + + if (!inputTy || !weightTy || !biasTy || !resultTy) + return failure(); llvm::ArrayRef stride = op.getStride(); llvm::ArrayRef pad = op.getOutPad(); @@ -145,10 +148,13 @@ Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + auto inputTy = input.getType().dyn_cast(); + auto weightTy = weight.getType().dyn_cast(); + auto biasTy = bias.getType().dyn_cast(); + auto resultTy = op->getResult(0).getType().dyn_cast(); + + if (!inputTy || !weightTy || !biasTy || !resultTy) + return failure(); Type inputETy = inputTy.getElementType(); Type weightETy = weightTy.getElementType(); @@ -230,7 +236,7 @@ weight = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(weightETy), weight, rewriter.getDenseI64ArrayAttr(weightReshapeDims1)); - ShapedType restridedWeightTy = 
weight.getType().cast(); + auto restridedWeightTy = weight.getType().cast(); weight = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(weightETy), weight, @@ -296,7 +302,7 @@ } // Factor the resulting width / height. - ShapedType convTy = conv2d.getType().cast(); + auto convTy = conv2d.getType().cast(); Type convETy = convTy.getElementType(); int64_t convHeight = convTy.getDimSize(1); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp @@ -22,8 +22,8 @@ namespace { template -DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType, - ShapedType outputType, +DenseElementsAttr transposeType(ElementsAttr attr, RankedShapedType inputType, + RankedShapedType outputType, llvm::ArrayRef permValues) { if (inputType.getNumElements() == 0) return DenseElementsAttr::get(outputType, llvm::ArrayRef{}); diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -309,7 +309,7 @@ builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. - auto shape = inputDType.dyn_cast(); + auto shape = inputDType.dyn_cast(); if (!shape) return {}; if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -41,7 +41,7 @@ // Dimensions are compatible when //. 1. 
One is dynamic, the rest are 1 - if (ShapedType::isDynamic(dim)) { + if (RankedShapedType::isDynamic(dim)) { if (seenDynamic || nonOneDim) return false; seenDynamic = true; @@ -81,7 +81,7 @@ // Check each dimension is consistent. for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) { - if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) { + if (RankedShapedType::isDynamic(*i1) || RankedShapedType::isDynamic(*i2)) { // One or both dimensions is unknown. Follow TensorFlow behavior: // - If either dimension is greater than 1, we assume that the program is // correct, and the other dimension will be broadcast to match it. @@ -95,7 +95,7 @@ } else if (*i2 == 1) { *iR = *i1; } else { - *iR = ShapedType::kDynamic; + *iR = RankedShapedType::kDynamic; } } else { if (*i1 == *i2 || *i2 == 1) { @@ -116,7 +116,7 @@ /// Returns the shape of the given type. Scalars will be considered as having a /// shape with zero dimensions. static ArrayRef getShape(Type type) { - if (auto sType = type.dyn_cast()) + if (auto sType = type.dyn_cast()) return sType.getShape(); return {}; } @@ -200,8 +200,8 @@ // then it is compatible, else if the inferred dim is 1 then it is also // compatible. But if the existing dim is 1 and the inferred is greater than // 1 then flag. 
- return dim1 == dim2 || ShapedType::isDynamic(dim1) || - ShapedType::isDynamic(dim2) || dim1 == 1; + return dim1 == dim2 || RankedShapedType::isDynamic(dim1) || + RankedShapedType::isDynamic(dim2) || dim1 == 1; }; if (inferred.size() != existing.size()) return false; @@ -220,7 +220,7 @@ llvm::interleave( shape, ss, [&](int64_t dim) { - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) ss << '?'; else ss << dim; diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -17,8 +17,8 @@ using namespace mlir; std::optional> -mlir::getReassociationIndicesForReshape(ShapedType sourceType, - ShapedType targetType) { +mlir::getReassociationIndicesForReshape(RankedShapedType sourceType, + RankedShapedType targetType) { if (sourceType.getRank() > targetType.getRank()) return getReassociationIndicesForCollapse(sourceType.getShape(), targetType.getShape()); @@ -48,7 +48,7 @@ int64_t currTargetShape = targetShape[targetDim]; while (sourceDim < sourceShape.size() && - sourceShape[sourceDim] != ShapedType::kDynamic && + sourceShape[sourceDim] != RankedShapedType::kDynamic && prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) { prodOfCollapsedDims *= sourceShape[sourceDim]; currIndices.push_back(sourceDim++); @@ -57,14 +57,15 @@ // If the current expanded dimension is dynamic, then the collapsed // dimensions should also be dynamic and product of all previous unprocessed // dimensions of the expanded shape should be 1. - if (sourceShape[sourceDim] == ShapedType::kDynamic && - (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1)) + if (sourceShape[sourceDim] == RankedShapedType::kDynamic && + (currTargetShape != RankedShapedType::kDynamic || + prodOfCollapsedDims != 1)) return std::nullopt; // If the collapsed dim is dynamic, the current expanded dim should also // be dynamic. 
- if (currTargetShape == ShapedType::kDynamic && - sourceShape[sourceDim] != ShapedType::kDynamic) + if (currTargetShape == RankedShapedType::kDynamic && + sourceShape[sourceDim] != RankedShapedType::kDynamic) return std::nullopt; // For static shapes, if the product of dimensions of the expanded shape @@ -83,7 +84,7 @@ // Process any remaining entries in the source shape. They all need to be // 1 or dynamic. for (; sourceDim < sourceShape.size(); sourceDim++) { - if (sourceShape[sourceDim] != ShapedType::kDynamic && + if (sourceShape[sourceDim] != RankedShapedType::kDynamic && sourceShape[sourceDim] != 1) return std::nullopt; // The map is empty when the target type is a scalar. @@ -234,7 +235,7 @@ int64_t linearizedStaticShape = 1; for (const auto &dim : llvm::enumerate( expandedShape.slice(expandedDimStart, map.value().size()))) { - if (ShapedType::isDynamic(dim.value())) { + if (RankedShapedType::isDynamic(dim.value())) { if (isExpandingReshape && dynamicShape) { return emitError("invalid to have a single dimension (" + Twine(map.index()) + @@ -248,7 +249,7 @@ } } if (dynamicShape) { - if (!ShapedType::isDynamic(collapsedShape[map.index()])) { + if (!RankedShapedType::isDynamic(collapsedShape[map.index()])) { return emitError( "expected dimension " + Twine(map.index()) + " of collapsed type to be dynamic since one or more of the " diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -58,7 +58,7 @@ return; } dynamicVec.push_back(v); - staticVec.push_back(ShapedType::kDynamic); + staticVec.push_back(RankedShapedType::kDynamic); } void dispatchIndexOpFoldResults(ArrayRef ofrs, @@ -158,7 +158,7 @@ } /// Return a vector of OpFoldResults with the same size a staticValues, but all -/// elements for which ShapedType::isDynamic is true, will be replaced by +/// elements for which RankedShapedType::isDynamic is true, 
will be replaced by /// dynamicValues. SmallVector getMixedValues(ArrayRef staticValues, ValueRange dynamicValues, Builder &b) { @@ -168,7 +168,7 @@ unsigned count = static_cast(staticValues.size()); for (unsigned idx = 0; idx < count; ++idx) { int64_t value = staticValues[idx]; - res.push_back(ShapedType::isDynamic(value) + res.push_back(RankedShapedType::isDynamic(value) ? OpFoldResult{dynamicValues[numDynamic++]} : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])}); } @@ -186,7 +186,7 @@ if (it.is()) { staticValues.push_back(it.get().cast().getInt()); } else { - staticValues.push_back(ShapedType::kDynamic); + staticValues.push_back(RankedShapedType::kDynamic); dynamicValues.push_back(it.get()); } } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -139,7 +139,7 @@ return succeeded(successStrides) && (strides.empty() || strides.back() == 1); } -AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, +AffineMap mlir::vector::getTransferMinorIdentityMap(RankedShapedType shapedType, VectorType vectorType) { int64_t elementVectorRank = 0; VectorType elementVectorType = @@ -916,7 +916,7 @@ VectorType rhsType = this->getRhsType(); unsigned numVecDims = lhsIdxMap.getNumDims(); - SmallVector maskShape(numVecDims, ShapedType::kDynamic); + SmallVector maskShape(numVecDims, RankedShapedType::kDynamic); // Using the information in the indexing maps, extract the size of each // dimension in the vector.contract operation from the two input operands. 
@@ -925,7 +925,7 @@ for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize; - assert(!ShapedType::isDynamicShape(maskShape) && + assert(!RankedShapedType::isDynamicShape(maskShape) && "Mask shape couldn't be computed"); return VectorType::get(maskShape, @@ -3347,7 +3347,7 @@ } static LogicalResult -verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, +verifyTransferOp(VectorTransferOpInterface op, RankedShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds) { @@ -3535,7 +3535,7 @@ LogicalResult TransferReadOp::verify() { // Consistency of elemental types in source and vector. - ShapedType shapedType = getShapedType(); + RankedShapedType shapedType = getShapedType(); VectorType vectorType = getVectorType(); VectorType maskType = getMaskType(); auto paddingType = getPadding().getType(); @@ -3994,7 +3994,7 @@ LogicalResult TransferWriteOp::verify() { // Consistency of elemental types in shape and vector. 
- ShapedType shapedType = getShapedType(); + RankedShapedType shapedType = getShapedType(); VectorType vectorType = getVectorType(); VectorType maskType = getMaskType(); auto permutationMap = getPermutationMap(); @@ -4602,7 +4602,7 @@ VectorType indVType = getIndexVectorType(); VectorType maskVType = getMaskVectorType(); VectorType resVType = getVectorType(); - ShapedType baseType = getBaseType(); + RankedShapedType baseType = getBaseType(); if (!baseType.isa()) return emitOpError("requires base to be a memref or ranked tensor type"); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -427,7 +427,7 @@ loc, accType, multiReductionOp.getAcc()); Value castMask; if (maskableOp.isMasked()) { - auto maskType = mask.getType().cast(); + auto maskType = mask.getType().cast(); auto castMaskType = VectorType::get(ArrayRef{1, maskType.getShape().back()}, maskType.getElementType()); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -123,7 +123,8 @@ // appropriate indices for the extract/insert operations. 
Value result = rewriter.create( loc, resType, rewriter.getZeroAttr(resType)); - int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); + int64_t numTransposedElements = + RankedShapedType::getNumElements(prunedInShape); for (int64_t linearIdx = 0; linearIdx < numTransposedElements; ++linearIdx) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -406,7 +406,7 @@ /// input starting at `firstDimToCollapse`. static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse) { - ShapedType inputType = input.getType().cast(); + auto inputType = input.getType().cast(); if (inputType.getRank() == 1) return input; SmallVector reassociation; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -171,11 +171,12 @@ resStrides(bT.getRank(), 0); for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) { resShape[idx] = - (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic; - resStrides[idx] = - (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic; + (aShape[idx] == bShape[idx]) ? aShape[idx] : RankedShapedType::kDynamic; + resStrides[idx] = (aStrides[idx] == bStrides[idx]) + ? aStrides[idx] + : RankedShapedType::kDynamic; } - resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic; + resOffset = (aOffset == bOffset) ? 
aOffset : RankedShapedType::kDynamic; return MemRefType::get( resShape, aT.getElementType(), StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides)); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2249,7 +2249,7 @@ } static void -printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, +printDenseElementsAttrImpl(bool isSplat, RankedShapedType type, raw_ostream &os, function_ref printEltFn) { // Special case for 0-d and splat tensors. if (isSplat) @@ -2476,7 +2476,7 @@ .Case([&](RankedTensorType tensorTy) { os << "tensor<"; for (int64_t dim : tensorTy.getShape()) { - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) os << '?'; else os << dim; @@ -2498,7 +2498,7 @@ .Case([&](MemRefType memrefTy) { os << "memref<"; for (int64_t dim : memrefTy.getShape()) { - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) os << '?'; else os << dim; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -83,7 +83,8 @@ /// signals if the data is already known to be a splat. Callers to this /// function are expected to tag preknown splat values when possible, e.g. one /// element shapes. - static KeyTy getKey(ShapedType ty, ArrayRef data, bool isKnownSplat) { + static KeyTy getKey(RankedShapedType ty, ArrayRef data, + bool isKnownSplat) { // Handle an empty storage instance. if (data.empty()) return KeyTy(ty, data, 0); @@ -234,7 +235,7 @@ /// signals if the data is already known to be a splat. Callers to this /// function are expected to tag preknown splat values when possible, e.g. one /// element shapes. - static KeyTy getKey(ShapedType ty, ArrayRef data, + static KeyTy getKey(RankedShapedType ty, ArrayRef data, bool isKnownSplat) { // Handle an empty storage instance. 
if (data.empty()) diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -32,7 +32,8 @@ return elementsAttr.getType().getNumElements(); } -bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef index) { +bool ElementsAttr::isValidIndex(RankedShapedType type, + ArrayRef index) { // Verify that the rank of the indices matches the held type. int64_t rank = type.getRank(); if (rank == 0 && index.size() == 1 && index[0] == 0) @@ -53,7 +54,7 @@ } uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { - ShapedType shapeType = type.cast(); + RankedShapedType shapeType = type.cast(); assert(isValidIndex(shapeType, index) && "expected valid multi-dimensional index"); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -208,7 +208,7 @@ /// Prints a strided layout attribute. void StridedLayoutAttr::print(llvm::raw_ostream &os) const { auto printIntOrQuestion = [&](int64_t value) { - if (ShapedType::isDynamic(value)) + if (RankedShapedType::isDynamic(value)) os << "?"; else os << value; @@ -581,7 +581,8 @@ /// Returns true if 'values' corresponds to a splat, i.e. one element, or has /// the same element count as 'type'. 
template -static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { +static bool hasSameElementsOrSplat(RankedShapedType type, + const Values &values) { return (values.size() == 1) || (type.getNumElements() == static_cast(values.size())); } @@ -887,7 +888,7 @@ return attr.isa(); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(hasSameElementsOrSplat(type, values)); @@ -972,7 +973,7 @@ return DenseIntOrFPElementsAttr::getRaw(type, data); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(hasSameElementsOrSplat(type, values)); assert(type.getElementType().isInteger(1)); @@ -997,7 +998,7 @@ return DenseIntOrFPElementsAttr::getRaw(type, buff); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(!type.getElementType().isIntOrFloat()); return DenseStringElementsAttr::get(type, values); @@ -1006,14 +1007,14 @@ /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. 
-DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(type.getElementType().isIntOrIndex()); assert(hasSameElementsOrSplat(type, values)); size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef> values) { ComplexType complex = type.getElementType().cast(); assert(complex.getElementType().isa()); @@ -1027,7 +1028,7 @@ // Constructs a dense float elements attribute from an array of APFloat // values. Each APFloat value is expected to have the same bitwidth as the // element type of 'type'. -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(type.getElementType().isa()); assert(hasSameElementsOrSplat(type, values)); @@ -1035,7 +1036,7 @@ return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); } DenseElementsAttr -DenseElementsAttr::get(ShapedType type, +DenseElementsAttr::get(RankedShapedType type, ArrayRef> values) { ComplexType complex = type.getElementType().cast(); assert(complex.getElementType().isa()); @@ -1050,12 +1051,13 @@ /// data for this attribute. Users should generally not use this methods as /// the expected buffer format may not be a form the user expects. DenseElementsAttr -DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef rawBuffer) { +DenseElementsAttr::getFromRawBuffer(RankedShapedType type, + ArrayRef rawBuffer) { return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); } /// Returns true if the given buffer is a valid raw buffer for the given type. 
-bool DenseElementsAttr::isValidRawBuffer(ShapedType type, +bool DenseElementsAttr::isValidRawBuffer(RankedShapedType type, ArrayRef rawBuffer, bool &detectedSplat) { size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); @@ -1119,14 +1121,14 @@ } /// Defaults down the subclass implementation. -DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, +DenseElementsAttr DenseElementsAttr::getRawComplex(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned) { return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, isSigned); } -DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, +DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, @@ -1203,8 +1205,8 @@ /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. -DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { - ShapedType curType = getType(); +DenseElementsAttr DenseElementsAttr::reshape(RankedShapedType newType) { + RankedShapedType curType = getType(); if (curType == newType) return *this; @@ -1215,10 +1217,10 @@ return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); } -DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { +DenseElementsAttr DenseElementsAttr::resizeSplat(RankedShapedType newType) { assert(isSplat() && "expected a splat type"); - ShapedType curType = getType(); + RankedShapedType curType = getType(); if (curType == newType) return *this; @@ -1232,7 +1234,7 @@ /// type must have the same shape and element types of the same bitwidth as the /// current type. 
DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { - ShapedType curType = getType(); + RankedShapedType curType = getType(); Type curElType = curType.getElementType(); if (curElType == newElType) return *this; @@ -1255,7 +1257,7 @@ return cast().mapValues(newElementType, mapping); } -ShapedType DenseElementsAttr::getType() const { +RankedShapedType DenseElementsAttr::getType() const { return static_cast(impl)->type; } @@ -1292,7 +1294,7 @@ /// Constructs a dense elements attribute from an array of raw APFloat values. /// Each APFloat value is expected to have the same bitwidth as the element /// type of 'type'. 'type' must be a vector or tensor with static shape. -DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, +DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(RankedShapedType type, size_t storageWidth, ArrayRef values) { std::vector data; @@ -1304,7 +1306,7 @@ /// Constructs a dense elements attribute from an array of raw APInt values. /// Each APInt value is expected to have the same bitwidth as the element type /// of 'type'. -DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, +DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(RankedShapedType type, size_t storageWidth, ArrayRef values) { std::vector data; @@ -1312,7 +1314,7 @@ return DenseIntOrFPElementsAttr::getRaw(type, data); } -DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, +DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(RankedShapedType type, ArrayRef data) { assert(type.hasStaticShape() && "type must have static shape"); bool isSplat = false; @@ -1325,7 +1327,7 @@ /// Overload of the raw 'get' method that asserts that the given type is of /// complex type. This method is used to verify type invariants that the /// templatized 'get' method cannot. 
-DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, +DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, @@ -1343,10 +1345,9 @@ /// Overload of the 'getRaw' method that asserts that the given type is of /// integer type. This method is used to verify type invariants that the /// templatized 'get' method cannot. -DenseElementsAttr -DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, - int64_t dataEltSize, bool isInt, - bool isSigned) { +DenseElementsAttr DenseIntOrFPElementsAttr::getRawIntOrFloat( + RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, + bool isSigned) { assert( ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); @@ -1401,7 +1402,7 @@ void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( ArrayRef inRawData, MutableArrayRef outRawData, - ShapedType type) { + RankedShapedType type) { size_t numElements = type.getNumElements(); Type elementType = type.getElementType(); if (ComplexType complexTy = elementType.dyn_cast()) { @@ -1423,13 +1424,14 @@ //===----------------------------------------------------------------------===// template -static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, - Type newElementType, - llvm::SmallVectorImpl &data) { +static RankedShapedType +mappingHelper(Fn mapping, Attr &attr, RankedShapedType inType, + Type newElementType, llvm::SmallVectorImpl &data) { size_t bitWidth = getDenseElementBitWidth(newElementType); size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType); + RankedShapedType newArrayType = cast( + inType.cloneWith(inType.getShape(), newElementType)); size_t numRawElements = attr.isSplat() ? 
1 : newArrayType.getNumElements(); data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT)); @@ -1499,12 +1501,12 @@ //===----------------------------------------------------------------------===// DenseResourceElementsAttr -DenseResourceElementsAttr::get(ShapedType type, +DenseResourceElementsAttr::get(RankedShapedType type, DenseResourceElementsHandle handle) { return Base::get(type.getContext(), type, handle); } -DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type, +DenseResourceElementsAttr DenseResourceElementsAttr::get(RankedShapedType type, StringRef blobName, AsmResourceBlob blob) { // Extract the builtin dialect resource manager from context and construct a @@ -1573,7 +1575,7 @@ template DenseResourceElementsAttrBase -DenseResourceElementsAttrBase::get(ShapedType type, StringRef blobName, +DenseResourceElementsAttrBase::get(RankedShapedType type, StringRef blobName, AsmResourceBlob blob) { // Check that the blob is in the form we were expecting. assert(blob.getDataAlignment() == alignof(T) && @@ -1687,16 +1689,15 @@ return flatSparseIndices; } -LogicalResult -SparseElementsAttr::verify(function_ref emitError, - ShapedType type, DenseIntElementsAttr sparseIndices, - DenseElementsAttr values) { - ShapedType valuesType = values.getType(); +LogicalResult SparseElementsAttr::verify( + function_ref emitError, RankedShapedType type, + DenseIntElementsAttr sparseIndices, DenseElementsAttr values) { + RankedShapedType valuesType = values.getType(); if (valuesType.getRank() != 1) return emitError() << "expected 1-d tensor for sparse element values"; // Verify the indices and values shape. - ShapedType indicesType = sparseIndices.getType(); + RankedShapedType indicesType = sparseIndices.getType(); auto emitShapeError = [&]() { return emitError() << "expected shape ([" << type.getShape() << "]); inferred shape of indices literal ([" @@ -1757,7 +1758,7 @@ // AffineExpr for offset. // Static case. 
- if (!ShapedType::isDynamic(offset)) { + if (!RankedShapedType::isDynamic(offset)) { auto cst = getAffineConstantExpr(offset, context); expr = cst; } else { @@ -1774,7 +1775,7 @@ auto d = getAffineDimExpr(dim, context); AffineExpr mult; // Static case. - if (!ShapedType::isDynamic(stride)) + if (!RankedShapedType::isDynamic(stride)) mult = getAffineConstantExpr(stride, context); else // Dynamic case, new symbol for each new stride. diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -536,7 +536,7 @@ DenseStringElementsAttr BuiltinDialectBytecodeInterface::readDenseStringElementsAttr( DialectBytecodeReader &reader) const { - ShapedType type; + RankedShapedType type; uint64_t isSplat; if (failed(reader.readType(type)) || failed(reader.readVarInt(isSplat))) return DenseStringElementsAttr(); diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp --- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp @@ -20,14 +20,15 @@ #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc" //===----------------------------------------------------------------------===// -// ShapedType +// RankedShapedType //===----------------------------------------------------------------------===// -constexpr int64_t ShapedType::kDynamic; +constexpr int64_t RankedShapedType::kDynamic; -int64_t ShapedType::getNumElements(ArrayRef shape) { +int64_t RankedShapedType::getNumElements(ArrayRef shape) { int64_t num = 1; for (int64_t dim : shape) { + assert(!RankedShapedType::isDynamic(dim) && "expected only static dims"); num *= dim; assert(num >= 0 && "integer overflow in element count computation"); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -271,10 +271,6 @@ bool TensorType::hasRank() const { return !isa(); 
} -ArrayRef TensorType::getShape() const { - return cast().getShape(); -} - TensorType TensorType::cloneWith(std::optional> shape, Type elementType) const { if (auto unrankedTy = dyn_cast()) { @@ -319,7 +315,7 @@ ArrayRef shape, Type elementType, Attribute encoding) { for (int64_t s : shape) - if (s < 0 && !ShapedType::isDynamic(s)) + if (s < 0 && !RankedShapedType::isDynamic(s)) return emitError() << "invalid tensor dimension size"; if (auto v = encoding.dyn_cast_or_null()) if (failed(v.verifyEncoding(shape, elementType, emitError))) @@ -349,10 +345,6 @@ bool BaseMemRefType::hasRank() const { return !isa(); } -ArrayRef BaseMemRefType::getShape() const { - return cast().getShape(); -} - BaseMemRefType BaseMemRefType::cloneWith(std::optional> shape, Type elementType) const { if (auto unrankedTy = dyn_cast()) { @@ -426,9 +418,9 @@ if (originalType == candidateReducedType) return SliceVerificationResult::Success; - ShapedType originalShapedType = originalType.cast(); - ShapedType candidateReducedShapedType = - candidateReducedType.cast(); + auto originalShapedType = originalType.cast(); + auto candidateReducedShapedType = + candidateReducedType.cast(); // Rank and size logic is valid for all ShapedTypes. ArrayRef originalShape = originalShapedType.getShape(); @@ -617,7 +609,7 @@ // Negative sizes are not allowed except for `kDynamic`. 
for (int64_t s : shape) - if (s < 0 && !ShapedType::isDynamic(s)) + if (s < 0 && !RankedShapedType::isDynamic(s)) return emitError() << "invalid memref size"; assert(layout && "missing layout specification"); @@ -801,12 +793,12 @@ if (auto cst = offsetExpr.dyn_cast()) offset = cst.getValue(); else - offset = ShapedType::kDynamic; + offset = RankedShapedType::kDynamic; for (auto e : strideExprs) { if (auto c = e.dyn_cast()) strides.push_back(c.getValue()); else - strides.push_back(ShapedType::kDynamic); + strides.push_back(RankedShapedType::kDynamic); } return success(); } diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -63,8 +63,8 @@ for (auto dims : llvm::zip(shape1, shape2)) { int64_t dim1 = std::get<0>(dims); int64_t dim2 = std::get<1>(dims); - if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && - dim1 != dim2) + if (!RankedShapedType::isDynamic(dim1) && + !RankedShapedType::isDynamic(dim2) && dim1 != dim2) return failure(); } return success(); @@ -85,10 +85,11 @@ if (!sType2) return failure(); - if (!sType1.hasRank() || !sType2.hasRank()) + if (!isa(sType1) || !isa(sType2)) return success(); - return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); + return verifyCompatibleShape(cast(sType1).getShape(), + cast(sType2).getShape()); } /// Returns success if the given two arrays have the same number of elements and @@ -107,10 +108,10 @@ return success(); auto staticDim = std::accumulate( dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) { - return ShapedType::isDynamic(dim) ? fold : dim; + return RankedShapedType::isDynamic(dim) ? 
fold : dim; }); return success(llvm::all_of(dims, [&](auto dim) { - return ShapedType::isDynamic(dim) || dim == staticDim; + return RankedShapedType::isDynamic(dim) || dim == staticDim; })); } @@ -142,23 +143,31 @@ } // Remove all unranked shapes - auto shapes = llvm::to_vector<8>(llvm::make_filter_range( - shapedTypes, [](auto shapedType) { return shapedType.hasRank(); })); + auto shapes = llvm::to_vector<8>( + llvm::make_filter_range(shapedTypes, [](auto shapedType) { + return isa(shapedType); + })); if (shapes.empty()) return success(); // All ranks should be equal - auto firstRank = shapes.front().getRank(); - if (llvm::any_of(shapes, - [&](auto shape) { return firstRank != shape.getRank(); })) + auto firstRank = cast(shapes.front()).getRank(); + if (llvm::any_of(shapes, [&](auto shape) { + return firstRank != cast(shape).getRank(); + })) return failure(); for (unsigned i = 0; i < firstRank; ++i) { // Retrieve all ranked dimensions auto dims = llvm::to_vector<8>(llvm::map_range( llvm::make_filter_range( - shapes, [&](auto shape) { return shape.getRank() >= i; }), - [&](auto shape) { return shape.getDimSize(i); })); + shapes, + [&](auto shape) { + return cast(shape).getRank() >= i; + }), + [&](auto shape) { + return cast(shape).getDimSize(i); + })); if (verifyCompatibleDims(dims).failed()) return failure(); } diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -39,20 +39,21 @@ auto shapedType = result.getType().dyn_cast(); if (!shapedType) continue; - if (!shapedType.hasRank()) { + if (!isa(shapedType)) { // Nothing to check for unranked shaped values. ++resultIdx; continue; } + auto rankedShapedType = cast(shapedType); // Assert one OpFoldResult per dimension. 
- assert(shapedType.getRank() == + assert(rankedShapedType.getRank() == static_cast(reifiedReturnShapes[resultIdx].size()) && "incorrect implementation of ReifyRankedShapedTypeOpInterface"); - for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) { + for (int64_t dim = 0; dim < rankedShapedType.getRank(); ++dim) { // reifyResultShapes must return: // * Attribute for static dimensions // * Value for dynamic dimensions - assert(shapedType.isDynamicDim(dim) == + assert(rankedShapedType.isDynamicDim(dim) == reifiedReturnShapes[resultIdx][dim].is() && "incorrect implementation of ReifyRankedShapedTypeOpInterface"); } @@ -69,7 +70,7 @@ if (val.isNull()) return false; if (auto t = val.dyn_cast()) - return t.cast().hasRank(); + return isa(t); if (val.is()) return true; return val.get()->hasRank(); @@ -88,7 +89,7 @@ void ShapeAdaptor::getDims(SmallVectorImpl &res) const { assert(hasRank()); if (auto t = val.dyn_cast()) { - ArrayRef vals = t.cast().getShape(); + ArrayRef vals = t.cast().getShape(); res.assign(vals.begin(), vals.end()); } else if (auto attr = val.dyn_cast()) { auto dattr = attr.cast(); @@ -111,7 +112,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const { assert(hasRank()); if (auto t = val.dyn_cast()) - return t.cast().getDimSize(index); + return t.cast().getDimSize(index); if (auto attr = val.dyn_cast()) return attr.cast() .getValues()[index] @@ -123,7 +124,7 @@ int64_t ShapeAdaptor::getRank() const { assert(hasRank()); if (auto t = val.dyn_cast()) - return t.cast().getRank(); + return t.cast().getRank(); if (auto attr = val.dyn_cast()) return attr.cast().size(); return val.get()->getDims().size(); @@ -134,23 +135,23 @@ return false; if (auto t = val.dyn_cast()) - return t.cast().hasStaticShape(); + return t.cast().hasStaticShape(); if (auto attr = val.dyn_cast()) { auto dattr = attr.cast(); for (auto index : dattr.getValues()) - if (ShapedType::isDynamic(index.getSExtValue())) + if (RankedShapedType::isDynamic(index.getSExtValue())) return false; 
return true; } auto *stc = val.get(); - return llvm::none_of(stc->getDims(), ShapedType::isDynamic); + return llvm::none_of(stc->getDims(), RankedShapedType::isDynamic); } int64_t ShapeAdaptor::getNumElements() const { assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); if (auto t = val.dyn_cast()) - return t.cast().getNumElements(); + return t.cast().getNumElements(); if (auto attr = val.dyn_cast()) { auto dattr = attr.cast(); @@ -180,7 +181,7 @@ SmallVector dims; getDims(dims); auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string { - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) return "?"; return llvm::formatv("{0}", dim).str(); }); diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -51,8 +51,8 @@ assert(!dim.has_value() && "invalid dim value"); } else if (auto shapedType = dyn_cast(value.getType())) { assert(*dim >= 0 && "invalid dim value"); - if (shapedType.hasRank()) - assert(*dim < shapedType.getRank() && "invalid dim value"); + if (auto rankedShapedType = dyn_cast(shapedType)) + assert(*dim < rankedShapedType.getRank() && "invalid dim value"); } else { llvm_unreachable("unsupported type"); } @@ -83,8 +83,9 @@ auto shapedType = dyn_cast(value.getType()); if (shapedType) { // Static dimension: return constant directly. - if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) - return builder.getAffineConstantExpr(shapedType.getDimSize(*dim)); + auto rankedShapedType = dyn_cast(shapedType); + if (rankedShapedType && !rankedShapedType.isDynamicDim(*dim)) + return builder.getAffineConstantExpr(rankedShapedType.getDimSize(*dim)); } else { // Constant index value: return directly. if (auto constInt = getConstantIntValue(value)) @@ -164,9 +165,9 @@ // Check for static dim size. 
if (dim != kIndexValue) { - auto shapedType = cast(value.getType()); - if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) { - bound(value)[dim] == getExpr(shapedType.getDimSize(dim)); + auto rankedShapedType = cast(value.getType()); + if (rankedShapedType && !rankedShapedType.isDynamicDim(dim)) { + bound(value)[dim] == getExpr(rankedShapedType.getDimSize(dim)); continue; } } @@ -346,7 +347,9 @@ continue; } - assert(cast(value.getType()).isDynamicDim(dim) && + auto rankedShapedType = cast(value.getType()); + (void)rankedShapedType; + assert((!rankedShapedType || rankedShapedType.isDynamicDim(dim)) && "expected dynamic dim"); mapOperands.push_back(std::make_pair(value, dim)); } diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -28,7 +28,7 @@ << " values, got " << staticVals.size(); unsigned expectedNumDynamicEntries = llvm::count_if(staticVals, [&](int64_t staticVal) { - return ShapedType::isDynamic(staticVal); + return RankedShapedType::isDynamic(staticVal); }); if (values.size() != expectedNumDynamicEntries) return op->emitError("expected ") @@ -112,7 +112,7 @@ } unsigned idx = 0; llvm::interleaveComma(integers, printer, [&](int64_t integer) { - if (ShapedType::isDynamic(integer)) + if (RankedShapedType::isDynamic(integer)) printer << values[idx++]; else printer << integer; @@ -131,7 +131,7 @@ auto res = parser.parseOptionalOperand(operand); if (res.has_value() && succeeded(res.value())) { values.push_back(operand); - integerVals.push_back(ShapedType::kDynamic); + integerVals.push_back(RankedShapedType::kDynamic); } else { int64_t integer; if (failed(parser.parseInteger(integer))) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -986,14 +986,15 @@ if (auto iType = 
type.dyn_cast()) return (os << "size_t"), success(); if (auto tType = type.dyn_cast()) { - if (!tType.hasRank()) + auto rtType = dyn_cast(tType); + if (!rtType) return emitError(loc, "cannot emit unranked tensor type"); - if (!tType.hasStaticShape()) + if (!rtType.hasStaticShape()) return emitError(loc, "cannot emit tensor type with non static shape"); os << "Tensor<"; - if (failed(emitType(loc, tType.getElementType()))) + if (failed(emitType(loc, rtType.getElementType()))) return failure(); - auto shape = tType.getShape(); + auto shape = rtType.getShape(); for (auto dimSize : shape) { os << ", "; os << dimSize; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -233,7 +233,7 @@ if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType)) return nullptr; - ShapedType type = denseElementsAttr.getType(); + RankedShapedType type = denseElementsAttr.getType(); if (type.getNumElements() == 0) return nullptr; diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -685,7 +685,7 @@ uint32_t resultID = 0; if (auto attr = valueAttr.dyn_cast()) { - int rank = attr.getType().dyn_cast().getRank(); + int rank = attr.getType().dyn_cast().getRank(); SmallVector index(rank); resultID = prepareDenseElementsConstant(loc, constType, attr, /*dim=*/0, index); @@ -732,7 +732,7 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, DenseElementsAttr valueAttr, int dim, MutableArrayRef index) { - auto shapedType = valueAttr.getType().dyn_cast(); + auto shapedType = valueAttr.getType().dyn_cast(); assert(dim <= shapedType.getRank()); if (shapedType.getRank() == dim) { if (auto attr = valueAttr.dyn_cast()) { diff --git 
a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -218,7 +218,7 @@ static_split_point = split_point dynamic_split_point = None else: - static_split_point = _get_int64_attr(ShapedType.get_dynamic_size()) + static_split_point = _get_int64_attr(RankedShapedType.get_dynamic_size()) dynamic_split_point = _get_op_result_or_value(split_point) target = _get_op_result_or_value(target) @@ -281,7 +281,7 @@ if isinstance(size, int): static_sizes.append(size) else: - static_sizes.append(ShapedType.get_dynamic_size()) + static_sizes.append(RankedShapedType.get_dynamic_size()) dynamic_sizes.append(_get_op_result_or_value(size)) sizes_attr = DenseI64ArrayAttr.get(static_sizes) diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py --- a/mlir/python/mlir/dialects/_tensor_ops_ext.py +++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py @@ -30,7 +30,7 @@ if isinstance(s, int): static_sizes.append(s) else: - static_sizes.append(ShapedType.get_dynamic_size()) + static_sizes.append(RankedShapedType.get_dynamic_size()) dynamic_sizes.append(s) result_type = RankedTensorType.get(static_sizes, element_type) op = self.build_generic( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -181,7 +181,7 @@ # indexing maps and iterators that match the rank of the first output tensor. # An operation is rank polymorphic if the iteration domain has rank zero. 
if not iterator_types_attr: - rank = ShapedType(outs[0].type).rank + rank = RankedShapedType(outs[0].type).rank iterator_types_attr = ArrayAttr.get( [Attribute.parse("#linalg.iterator_type")] * rank) scalar_map = AffineMap.get(rank, 0, []) @@ -192,7 +192,7 @@ indexing_maps.append(scalar_map) if arg_def.operand_def.is_tensor(): idx = arg_def.operand_def.registered_index - if idx < len(ins) and ShapedType(ins[idx].type).rank == 0: + if idx < len(ins) and RankedShapedType(ins[idx].type).rank == 0: indexing_maps.append(scalar_map) else: indexing_maps.append(tensor_map) diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -716,11 +716,12 @@ if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector)) return 14; if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) || - !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 || - mlirShapedTypeGetDimSize(vector, 0) != 2 || - mlirShapedTypeIsDynamicDim(vector, 0) || - mlirShapedTypeGetDimSize(vector, 1) != 3 || - !mlirShapedTypeHasStaticShape(vector)) + !mlirTypeIsARankedShaped(vector) || + mlirRankedShapedTypeGetRank(vector) != 2 || + mlirRankedShapedTypeGetDimSize(vector, 0) != 2 || + mlirRankedShapedTypeIsDynamicDim(vector, 0) || + mlirRankedShapedTypeGetDimSize(vector, 1) != 3 || + !mlirRankedShapedTypeHasStaticShape(vector)) return 15; mlirTypeDump(vector); fprintf(stderr, "\n"); @@ -741,7 +742,7 @@ MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32); if (!mlirTypeIsATensor(unrankedTensor) || !mlirTypeIsAUnrankedTensor(unrankedTensor) || - mlirShapedTypeHasRank(unrankedTensor)) + mlirTypeIsARankedShaped(unrankedTensor)) return 17; mlirTypeDump(unrankedTensor); fprintf(stderr, "\n"); diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -1,13 +1,13 @@ // RUN: mlir-opt 
-allow-unregistered-dialect %s -split-input-file -verify-diagnostics func.func @elementsattr_non_tensor_type() -> () { - "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal must be a shaped type}} + "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal type must be a ranked shaped type with static shape}} } // ----- func.func @elementsattr_non_ranked() -> () { - "foo"(){bar = dense<[4]> : tensor} : () -> () // expected-error {{elements literal type must have static shape}} + "foo"(){bar = dense<[4]> : tensor} : () -> () // expected-error {{elements literal type must be a ranked shaped type with static shape}} } // ----- diff --git a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp --- a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp +++ b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp @@ -41,13 +41,13 @@ return; } llvm::outs() << "MemRefType offset: "; - if (ShapedType::isDynamic(offset)) + if (RankedShapedType::isDynamic(offset)) llvm::outs() << "?"; else llvm::outs() << offset; llvm::outs() << " strides: "; llvm::interleaveComma(strides, llvm::outs(), [&](int64_t v) { - if (ShapedType::isDynamic(v)) + if (RankedShapedType::isDynamic(v)) llvm::outs() << "?"; else llvm::outs() << v; diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -113,7 +113,7 @@ if (!op.getSource().hasOneUse()) return false; - auto resultType = op.getResult().getType().cast(); + auto resultType = op.getResult().getType().cast(); constexpr int64_t kConstantFoldingMaxNumElements = 1024; return resultType.getNumElements() <= kConstantFoldingMaxNumElements; }; diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- 
a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -86,7 +86,7 @@ ]> { let mnemonic = "i64_elements"; let parameters = (ins - AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type, + AttributeSelfTypeParameter<"", "::mlir::RankedShapedType">:$type, ArrayRefParameter<"uint64_t">:$elements ); let extraClassDeclaration = [{ diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -86,7 +86,8 @@ LogicalResult TestI64ElementsAttr::verify(function_ref emitError, - ShapedType type, ArrayRef elements) { + RankedShapedType type, + ArrayRef elements) { if (type.getNumElements() != static_cast(elements.size())) { return emitError() << "number of elements does not match the provided shape type, got: " diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1348,7 +1348,8 @@ if (!sval) { return emitOptionalError(location, "only shaped type operands allowed"); } - int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; + auto rsval = sval.dyn_cast(); + int64_t dim = rsval ? 
rsval.getShape().front() : RankedShapedType::kDynamic; auto type = IntegerType::get(context, 17); Attribute encoding; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -260,7 +260,7 @@ DerivedTypeAttr element_dtype = DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">; DerivedAttr num_elements = DerivedAttr<"int", - "return getOutput().getType().cast().getNumElements();", + "return getOutput().getType().cast().getNumElements();", "$_builder.getI32IntegerAttr($_self)">; } diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -24,7 +24,8 @@ # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32> # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32> @func.FuncOp.from_py_func( - RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)) + RankedTensorType.get((12, RankedShapedType.get_dynamic_size()), f32) + ) def fill_tensor(out): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result return linalg.fill(zero, outs=[out]) @@ -35,7 +36,8 @@ # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>) # CHECK-NEXT: return @func.FuncOp.from_py_func( - MemRefType.get((12, ShapedType.get_dynamic_size()), f32)) + MemRefType.get((12, RankedShapedType.get_dynamic_size()), f32) + ) def fill_buffer(out): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result linalg.fill(zero, outs=[out]) diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py --- a/mlir/test/python/dialects/shape.py +++ b/mlir/test/python/dialects/shape.py @@ -19,8 +19,10 @@ module = Module.create() f32 = F32Type.get() with InsertionPoint(module.body): + @func.FuncOp.from_py_func( - 
RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)) + RankedTensorType.get((12, RankedShapedType.get_dynamic_size()), f32) + ) def const_shape_tensor(arg): shape.ConstWitnessOp(False) shape.ConstSizeOp(30) @@ -31,8 +33,6 @@ DenseElementsAttr.get( np.array([3, 4], dtype=np.int64), type=IndexType.get())) - - # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>) # CHECK-DAG: shape.const_witness false # CHECK-DAG: shape.const_size 30 diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py --- a/mlir/test/python/dialects/tensor.py +++ b/mlir/test/python/dialects/tensor.py @@ -23,8 +23,13 @@ @func.FuncOp.from_py_func( RankedTensorType.get( - (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()), - f32Type)) + ( + RankedShapedType.get_dynamic_size(), + RankedShapedType.get_dynamic_size(), + ), + f32Type, + ) + ) # CHECK: func @tensor_static_dim # CHECK-SAME: %[[ARG0:.+]]: tensor # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py --- a/mlir/test/python/dialects/vector.py +++ b/mlir/test/python/dialects/vector.py @@ -36,8 +36,12 @@ with InsertionPoint(module.body): vector_type = VectorType.get([2, 3], F32Type.get()) memref_type = MemRefType.get( - [ShapedType.get_dynamic_size(), - ShapedType.get_dynamic_size()], F32Type.get()) + [ + RankedShapedType.get_dynamic_size(), + RankedShapedType.get_dynamic_size(), + ], + F32Type.get(), + ) index_type = IndexType.get() mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1)) identity_map = AffineMap.get_identity(vector_type.rank) diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -74,13 +74,13 @@ try: attr = DenseElementsAttr.get_splat(non_shaped_type, element) except ValueError as e: - # CHECK: Expected a static ShapedType for 
the shaped_type parameter: Type(f32) + # CHECK: Expected a static RankedShapedType for the shaped_type parameter: Type(f32) print(e) try: attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element) except ValueError as e: - # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>) + # CHECK: Expected a static RankedShapedType for the shaped_type parameter: Type(tensor<*xf32>) print(e) try: @@ -364,4 +364,3 @@ # CHECK: {{\[}}[1 2 3] # CHECK: {{\[}}4 5 6]] print(np.array(attr)) - diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -548,7 +548,7 @@ print(attr.strides[2]) attr = StridedLayoutAttr.get_fully_dynamic(3) - dynamic = ShapedType.get_dynamic_stride_or_offset() + dynamic = RankedShapedType.get_dynamic_stride_or_offset() # CHECK: strided<[?, ?, ?], offset: ?> print(attr) # CHECK: offset is dynamic: True diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -254,8 +254,6 @@ vector = VectorType(Type.parse("vector<2x3xf32>")) # CHECK: element type: f32 print("element type:", vector.element_type) - # CHECK: whether the given shaped type is ranked: True - print("whether the given shaped type is ranked:", vector.has_rank) # CHECK: rank: 2 print("rank:", vector.rank) # CHECK: whether the shaped type has a static shape: True @@ -270,7 +268,10 @@ print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1)) # CHECK: isinstance(ShapedType): True print("isinstance(ShapedType):", isinstance(vector, ShapedType)) - + # CHECK: isinstance(RankedShapedType): True + print("isinstance(RankedShapedType):", isinstance(vector, RankedShapedType)) + # CHECK: has_rank: True + print("has_rank:", vector.has_rank) # CHECK-LABEL: TEST: testAbstractShapedType # Tests that ShapedType operates as an abstract 
base class of a concrete @@ -340,27 +341,6 @@ unranked_tensor = UnrankedTensorType.get(f32) # CHECK: unranked tensor type: tensor<*xf32> print("unranked tensor type:", unranked_tensor) - try: - invalid_rank = unranked_tensor.rank - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_get_dim_size = unranked_tensor.get_dim_size(1) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") none = NoneType.get() try: @@ -423,27 +403,6 @@ unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2")) # CHECK: unranked memref type: memref<*xf32, 2> print("unranked memref type:", unranked_memref) - try: - invalid_rank = unranked_memref.rank - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_get_dim_size = unranked_memref.get_dim_size(1) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") none = NoneType.get() try: @@ -501,11 +460,11 @@ print("data:", opaque.data) -# CHECK-LABEL: TEST: testShapedTypeConstants -# Tests that ShapedType exposes magic value constants. +# CHECK-LABEL: TEST: testRankedShapedTypeConstants +# Tests that RankedShapedType exposes magic value constants. 
@run -def testShapedTypeConstants(): +def testRankedShapedTypeConstants(): # CHECK: - print(type(ShapedType.get_dynamic_size())) + print(type(RankedShapedType.get_dynamic_size())) # CHECK: - print(type(ShapedType.get_dynamic_stride_or_offset())) + print(type(RankedShapedType.get_dynamic_stride_or_offset())) diff --git a/mlir/unittests/Dialect/BroadcastShapeTest.cpp b/mlir/unittests/Dialect/BroadcastShapeTest.cpp --- a/mlir/unittests/Dialect/BroadcastShapeTest.cpp +++ b/mlir/unittests/Dialect/BroadcastShapeTest.cpp @@ -47,7 +47,7 @@ TEST(BroadcastShapeTest, InterleavingUnknowns) { SmallVector result; - int64_t dyn = mlir::ShapedType::kDynamic; + int64_t dyn = mlir::RankedShapedType::kDynamic; ASSERT_TRUE(getBroadcastedShape({1, 2, dyn, dyn, dyn}, {dyn, dyn, dyn, 4, 1}, result)); EXPECT_THAT(result, ElementsAre(dyn, 2, dyn, 4, dyn));