diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -128,7 +128,9 @@ static constexpr LenType singleton() { return 1; } /// Character has a LEN value which is not a compile-time known constant. - static constexpr LenType unknownLen() { return mlir::ShapedType::kDynamic; } + static constexpr LenType unknownLen() { + return mlir::RankedShapedType::kDynamic; + } /// Character LEN is a runtime value. bool hasDynamicLen() { return getLen() == unknownLen(); } @@ -493,7 +495,7 @@ // The value `kDynamic` represents an unknown extent for a dimension static constexpr Extent getUnknownExtent() { - return mlir::ShapedType::kDynamic; + return mlir::RankedShapedType::kDynamic; } }]; } diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td --- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td @@ -76,7 +76,7 @@ isPolymorphic()); } static constexpr int64_t getUnknownExtent() { - return mlir::ShapedType::kDynamic; + return mlir::RankedShapedType::kDynamic; } }]; diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp --- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp @@ -410,8 +410,8 @@ auto affineApply = rewriter.create( acoOp.getLoc(), affineMap, indexArgs); auto arrayElementType = coordinateArrayElement(acoOp); - auto newType = - mlir::MemRefType::get({mlir::ShapedType::kDynamic}, arrayElementType); + auto newType = mlir::MemRefType::get({mlir::RankedShapedType::kDynamic}, + arrayElementType); auto arrayConvert = rewriter.create(acoOp.getLoc(), newType, acoOp.getMemref()); return std::make_pair(affineApply, arrayConvert); diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -159,44 +159,47 @@ // Shaped type. //===----------------------------------------------------------------------===// -/// Checks whether the given type is a Shaped type. +/// Checks whether the given type is a ShapedType. MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type); +/// Checks whether the given type is a RankedShapedType. +MLIR_CAPI_EXPORTED bool mlirTypeIsARankedShaped(MlirType type); + /// Returns the element type of the shaped type. MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type); -/// Checks whether the given shaped type is ranked. -MLIR_CAPI_EXPORTED bool mlirShapedTypeHasRank(MlirType type); - /// Returns the rank of the given ranked shaped type. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type); +MLIR_CAPI_EXPORTED int64_t mlirRankedShapedTypeGetRank(MlirType type); -/// Checks whether the given shaped type has a static shape. -MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type); +/// Checks whether the given ranked shaped type has a static shape. +MLIR_CAPI_EXPORTED bool mlirRankedShapedTypeHasStaticShape(MlirType type); -/// Checks wither the dim-th dimension of the given shaped type is dynamic. -MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim); +/// Checks whether the dim-th dimension of the given ranked shaped type is +/// dynamic.
+MLIR_CAPI_EXPORTED bool mlirRankedShapedTypeIsDynamicDim(MlirType type, + intptr_t dim); /// Returns the dim-th dimension of the given ranked shaped type. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, - intptr_t dim); +MLIR_CAPI_EXPORTED int64_t mlirRankedShapedTypeGetDimSize(MlirType type, + intptr_t dim); /// Checks whether the given value is used as a placeholder for dynamic sizes -/// in shaped types. -MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size); +/// in ranked shaped types. +MLIR_CAPI_EXPORTED bool mlirRankedShapedTypeIsDynamicSize(int64_t size); -/// Returns the value indicating a dynamic size in a shaped type. Prefer -/// mlirShapedTypeIsDynamicSize to direct comparisons with this value. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void); +/// Returns the value indicating a dynamic size in a ranked shaped type. Prefer +/// mlirRankedShapedTypeIsDynamicSize to direct comparisons with this value. +MLIR_CAPI_EXPORTED int64_t mlirRankedShapedTypeGetDynamicSize(void); /// Checks whether the given value is used as a placeholder for dynamic strides -/// and offsets in shaped types. -MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val); - -/// Returns the value indicating a dynamic stride or offset in a shaped type. -/// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with -/// this value. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void); +/// and offsets in ranked shaped types. +MLIR_CAPI_EXPORTED bool +mlirRankedShapedTypeIsDynamicStrideOrOffset(int64_t val); + +/// Returns the value indicating a dynamic stride or offset in a ranked shaped +/// type. Prefer mlirRankedShapedTypeIsDynamicStrideOrOffset to direct +/// comparisons with this value. +MLIR_CAPI_EXPORTED int64_t mlirRankedShapedTypeGetDynamicStrideOrOffset(void); //===----------------------------------------------------------------------===// // Vector type.
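For illustration only (not part of the patch): a minimal sketch of how C API callers would migrate rank and dimension queries to the functions declared above, with the old `mlirShapedTypeHasRank` guard replaced by `mlirTypeIsARankedShaped`. The function name `printDims` and the origin of `type` are hypothetical.

```cpp
#include "mlir-c/BuiltinTypes.h"
#include <cstdio>

// Print every dimension of `type`; does nothing for unranked or non-shaped
// types, which no longer answer rank/dimension queries.
static void printDims(MlirType type) {
  if (!mlirTypeIsARankedShaped(type))
    return;
  int64_t rank = mlirRankedShapedTypeGetRank(type);
  for (int64_t d = 0; d < rank; ++d) {
    if (mlirRankedShapedTypeIsDynamicDim(type, d))
      std::printf("dim %lld: dynamic\n", (long long)d);
    else
      std::printf("dim %lld: %lld\n", (long long)d,
                  (long long)mlirRankedShapedTypeGetDimSize(type, d));
  }
}
```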
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -125,7 +125,7 @@ ArrayRef shape = getType().getShape(); return std::count_if( shape.begin(), shape.begin() + idx, - [&](int64_t size) { return ShapedType::isDynamic(size); }); + [&](int64_t size) { return RankedShapedType::isDynamic(size); }); } // Return the Value of the dynamic size of the tensor at dimension diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -307,7 +307,7 @@ /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); if (auto shapedType = - opOperand->get().getType().template dyn_cast()) + opOperand->get().getType().template dyn_cast()) return shapedType.getRank(); return 0; }] @@ -359,7 +359,7 @@ /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); if (auto shapedType = - opOperand->get().getType().template dyn_cast()) + opOperand->get().getType().template dyn_cast()) return shapedType.getShape(); return {}; }] @@ -544,7 +544,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::any_of(getStaticShape(), ShapedType::isDynamic); + return llvm::any_of(getStaticShape(), RankedShapedType::isDynamic); }] >, InterfaceMethod< diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -234,7 +234,7 @@ //===----------------------------------------------------------------------===// def TensorOrMemref : - AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; + AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::RankedShapedType">; def MapOp : LinalgStructuredBase_Op<"map", [ DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1055,9 +1055,9 @@ static split point attribute when it is known at transform IR construction time or as the handle to an operation producing a single index-typed value when it is computed by payload IR. In the latter case, the static split - point must be set to `ShapedType::kDynamic` and the dynamic size handle - must point to as many value-producing operations as there are structured - operations pointed to by the target handle. + point must be set to `RankedShapedType::kDynamic` and the dynamic size + handle must point to as many value-producing operations as there are + structured operations pointed to by the target handle. The operation consumes the target handle, but preserves the split point handle if provided. It produces two new handles pointing to the two parts @@ -1477,9 +1477,9 @@ case the tile value must be computed by the payload IR and the handle to the operation computing it must be provided through `dynamic_sizes`. 
When the sizes are not known statically, the corresponding entry in the - `static_sizes` attribute must be set to `ShapedType::kDynamic`. Only + `static_sizes` attribute must be set to `RankedShapedType::kDynamic`. Only the dynamic sizes must be provided in `dynamic_sizes`, i.e., there should - be as many handles as `ShapedType::kDynamic` values in the + be as many handles as `RankedShapedType::kDynamic` values in the `static_sizes` attribute. A static size of `0` indicates that the dimension should not be tiled. No loop will be generated for such dimensions. If all tile sizes are `0`, this transform is effectively a no-op. @@ -1676,9 +1676,9 @@ case the tile value must be computed by the payload IR and the handle to the operation computing it must be provided through `dynamic_sizes`. When the sizes are not known statically, the corresponding entry in the - `static_sizes` attribute must be set to `ShapedType::kDynamic`. Only + `static_sizes` attribute must be set to `RankedShapedType::kDynamic`. Only the dynamic sizes must be provided in `dynamic_sizes`, i.e., there should - be as many handles as `ShapedType::kDynamic` values in the + be as many handles as `RankedShapedType::kDynamic` values in the `static_sizes` attribute. A static size of `0` indicates that the dimension should not be tiled. No loop will be generated for such dimensions. If all tile sizes are `0`, this transform is effectively a no-op. diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1828,8 +1828,8 @@ The representation based on offsets, sizes and strides support a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. A special - sentinel value ShapedType::kDynamic encodes that the corresponding entry has - a dynamic value. + sentinel value RankedShapedType::kDynamic encodes that the corresponding + entry has a dynamic value. A subview operation may additionally reduce the rank of the resulting view by removing dimensions that are statically known to be of size 1. diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -35,7 +35,7 @@ /// Alias type for extent tensors. RankedTensorType getExtentTensorType(MLIRContext *ctx, - int64_t rank = ShapedType::kDynamic); + int64_t rank = RankedShapedType::kDynamic); // Check if a type is an extent tensor, e.g., tensor. 
bool isExtentTensorType(Type); diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -91,7 +91,7 @@ def Shape_ExtentTensorType : 1DTensorOf<[Index]>, - BuildableType<"::mlir::RankedTensorType::get({ShapedType::kDynamic}, " + BuildableType<"::mlir::RankedTensorType::get({RankedShapedType::kDynamic}, " "$_builder.getType<::mlir::IndexType>())"> { let description = [{ The extent tensor is a tensor of rank one with arbitrarily many index diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -47,11 +47,11 @@ /// The type for individual components of a compile-time shape. We avoid /// calling this "size" because we use the term "sizes" to indicate the /// actual run-time sizes, whereas this type also allows the value -/// `ShapedType::kDynamic`. +/// `RankedShapedType::kDynamic`. using DynSize = int64_t; /// The type for individual components of a compile-time shape which -/// are known not to be `ShapedType::kDynamic`. +/// are known not to be `RankedShapedType::kDynamic`. using StaticSize = int64_t; } // namespace sparse_tensor diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -54,7 +54,7 @@ assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch"); } - SparseTensorType(ShapedType stp, SparseTensorEncodingAttr enc) + SparseTensorType(RankedShapedType stp, SparseTensorEncodingAttr enc) : SparseTensorType( RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {} @@ -181,7 +181,7 @@ ArrayRef getDimShape() const { return rtp.getShape(); } /// Safely looks up the requested dimension-DynSize. If you intend - /// to check the result with `ShapedType::isDynamic`, then see the + /// to check the result with `RankedShapedType::isDynamic`, then see the /// `getStaticDimSize` method instead. DynSize getDynamicDimSize(Dimension d) const { assert(d < getDimRank() && "Dimension is out of bounds"); @@ -192,8 +192,8 @@ /// sizes to `std::nullopt`. std::optional getStaticDimSize(Dimension d) const { const DynSize sh = getDynamicDimSize(d); - return ShapedType::isDynamic(sh) ? std::nullopt - : std::optional(sh); + return RankedShapedType::isDynamic(sh) ? std::nullopt + : std::optional(sh); } /// Returns true if no dimension has dynamic size. @@ -208,7 +208,7 @@ bool isDynamicDim(Dimension d) const { // We don't use `rtp.isDynamicDim(d)` because we want the // OOB error message to be consistent with `getDynamicDimSize`. - return ShapedType::isDynamic(getDynamicDimSize(d)); + return RankedShapedType::isDynamic(getDynamicDimSize(d)); } /// Returns the number of dimensions which have dynamic sizes. 
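As a hedged aside (not from the diff), the `getStaticDimSize` contract documented in the hunk above means dynamic extents surface as `std::nullopt` rather than the `RankedShapedType::kDynamic` sentinel. The helper name `staticExtents` is hypothetical; `getDimRank`, `getStaticDimSize`, and `Dimension` are as used in that hunk.

```cpp
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "llvm/ADT/SmallVector.h"

// Sketch: collect only the statically known extents, skipping dynamic dims.
static llvm::SmallVector<int64_t>
staticExtents(mlir::sparse_tensor::SparseTensorType stt) {
  llvm::SmallVector<int64_t> extents;
  for (mlir::sparse_tensor::Dimension d = 0, rank = stt.getDimRank(); d < rank;
       ++d)
    if (std::optional<int64_t> sz = stt.getStaticDimSize(d))
      extents.push_back(*sz); // dynamic dims yield std::nullopt, never kDynamic
  return extents;
}
```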
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -281,8 +281,8 @@ The representation based on offsets, sizes and strides support a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. A special - sentinel value ShapedType::kDynamic encodes that the corresponding entry has - a dynamic value. + sentinel value RankedShapedType::kDynamic encodes that the corresponding + entry has a dynamic value. After buffer allocation, the "extract_slice" op is expected to lower into a memref.subview op. @@ -769,8 +769,8 @@ The representation based on offsets, sizes and strides support a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. A special - sentinel value ShapedType::kDynamic encodes that the corresponding entry has - a dynamic value. + sentinel value RankedShapedType::kDynamic encodes that the corresponding + entry has a dynamic value. After buffer allocation, the "insert_slice" op is expected to lower into a memref.subview op. @@ -1267,7 +1267,7 @@ unsigned numDynamic = 0; unsigned count = staticAttrs.size(); for (unsigned idx = 0; idx < count; ++idx) { - if (ShapedType::isDynamic(staticAttrs[idx])) + if (RankedShapedType::isDynamic(staticAttrs[idx])) res.push_back(values[numDynamic++]); else res.push_back(builder.getI64IntegerAttr(staticAttrs[idx])); @@ -1379,8 +1379,8 @@ The representation based on offsets, sizes and strides support a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. A special - sentinel value ShapedType::kDynamic encodes that the corresponding entry has - a dynamic value. + sentinel value RankedShapedType::kDynamic encodes that the corresponding + entry has a dynamic value. After buffer allocation, the "parallel_insert_slice" op is expected to lower into a memref.subview op. 
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -48,10 +48,10 @@ std::optional> checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef params) { - SmallVector dynTypes; + SmallVector dynTypes; SmallVector dynamicDims; for (const Value ¶m : params) { - auto paramTy = param.getType().cast(); + auto paramTy = param.getType().cast(); if (!paramTy.hasStaticShape()) dynTypes.push_back(paramTy); } @@ -59,8 +59,9 @@ if (dynTypes.empty()) return dynamicDims; - for (const ShapedType &dynTy : dynTypes) { - if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) { + for (const RankedShapedType &dynTy : dynTypes) { + if (llvm::any_of(dynTy.getShape().drop_front(), + RankedShapedType::isDynamic)) { (void)rewriter.notifyMatchFailure( op, "input can only be dynamic for batch size"); return std::nullopt; diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h @@ -45,10 +45,10 @@ static ValueKnowledge getKnowledgeFromType(Type type) { ValueKnowledge result = getPessimisticValueState(); if (auto shapedType = type.dyn_cast()) { - if (shapedType.hasRank()) { + if (auto rankedShapedType = dyn_cast(shapedType)) { result.hasRank = true; - result.sizes.reserve(shapedType.getRank()); - for (auto dim : shapedType.getShape()) + result.sizes.reserve(rankedShapedType.getRank()); + for (auto dim : rankedShapedType.getShape()) result.sizes.push_back(dim); } result.dtype = shapedType.getElementType(); @@ -111,14 +111,14 @@ return result; result.hasRank = true; - result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamic); + result.sizes.resize(lhs.sizes.size(), RankedShapedType::kDynamic); for (auto i : llvm::seq(0, result.sizes.size())) { int64_t lhsSize = lhs.sizes[i]; int64_t rhsSize = rhs.sizes[i]; int64_t &resultSize = result.sizes[i]; - if (lhsSize == ShapedType::kDynamic) { + if (lhsSize == RankedShapedType::kDynamic) { resultSize = rhsSize; - } else if (rhsSize == ShapedType::kDynamic) { + } else if (rhsSize == RankedShapedType::kDynamic) { resultSize = lhsSize; } else if (lhsSize == rhsSize) { resultSize = lhsSize; @@ -155,7 +155,7 @@ } result.hasRank = true; - result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamic); + result.sizes.resize(lhs.sizes.size(), RankedShapedType::kDynamic); for (int i = 0, e = lhs.sizes.size(); i < e; i++) { if (lhs.sizes[i] == rhs.sizes[i]) { result.sizes[i] = lhs.sizes[i]; @@ -170,7 +170,7 @@ // Whether the value has known rank. bool hasRank; // If `hasRank`, the sizes along each rank. Unknown sizes are represented as - // `ShapedType::kDynamic`. + // `RankedShapedType::kDynamic`. llvm::SmallVector sizes; // The dtype of a tensor. // This is equal to nullptr if we don't know that it is a specific concrete diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -68,7 +68,8 @@ /// the target type when possible. Return std::nullopt when this computation /// failed. 
std::optional> -getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType); +getReassociationIndicesForReshape(RankedShapedType sourceType, + RankedShapedType targetType); /// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if /// possible. @@ -155,8 +156,9 @@ ArrayRef reassociationMaps, bool isExpandingReshape); template -static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType, - ShapedType expandedType, +static LogicalResult verifyReshapeLikeShapes(OpTy op, + RankedShapedType collapsedType, + RankedShapedType expandedType, bool isExpandingReshape) { return reshapeLikeShapesAreCompatible( [&](const Twine &msg) { return op->emitOpError(msg); }, @@ -235,8 +237,8 @@ if (!expandOp) return failure(); - ShapedType srcType = expandOp.getSrcType(); - ShapedType resultType = collapseOp.getResultType(); + RankedShapedType srcType = expandOp.getSrcType(); + RankedShapedType resultType = collapseOp.getResultType(); if (hasNonIdentityLayout(collapseOp.getSrc().getType()) || hasNonIdentityLayout(expandOp.getSrc().getType()) || @@ -301,8 +303,8 @@ if (!collapseOp) return failure(); - ShapedType srcType = collapseOp.getSrcType(); - ShapedType resultType = expandOp.getResultType(); + RankedShapedType srcType = collapseOp.getSrcType(); + RankedShapedType resultType = expandOp.getResultType(); if (hasNonIdentityLayout(expandOp.getSrc().getType()) || hasNonIdentityLayout(collapseOp.getSrc().getType()) || diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -43,7 +43,7 @@ /// Helper function to dispatch an OpFoldResult into `staticVec` if: /// a) it is an IntegerAttr /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. -/// In such dynamic cases, ShapedType::kDynamic is also pushed to +/// In such dynamic cases, RankedShapedType::kDynamic is also pushed to /// `staticVec`. This is useful to extract mixed static and dynamic entries /// that come from an AttrSizedOperandSegments trait. void dispatchIndexOpFoldResult(OpFoldResult ofr, @@ -105,8 +105,8 @@ // ArrayRef valueOrAttrVec); /// Return a vector of OpFoldResults with the same size a staticValues, but -/// all elements for which ShapedType::isDynamic is true, will be replaced by -/// dynamicValues. +/// all elements for which RankedShapedType::isDynamic is true, will be replaced +/// by dynamicValues. SmallVector getMixedValues(ArrayRef staticValues, ValueRange dynamicValues, Builder &b); diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -134,7 +134,7 @@ /// Build the default minor identity map suitable for a vector transfer. This /// also handles the case memref<... x vector<...>> -> vector<...> in which the /// rank of the identity map must take the vector element type into account. 
-AffineMap getTransferMinorIdentityMap(ShapedType shapedType, +AffineMap getTransferMinorIdentityMap(RankedShapedType shapedType, VectorType vectorType); /// Return true if the transfer_write fully writes the data accessed by the diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -160,7 +160,7 @@ }]>, InterfaceMethod<[{ Returns the shaped type of the elements attribute. - }], "::mlir::ShapedType", "getType"> + }], "::mlir::RankedShapedType", "getType"> ]; string ElementsAttrInterfaceAccessors = [{ @@ -312,7 +312,7 @@ bool isValidIndex(ArrayRef index) const { return isValidIndex(*this, index); } - static bool isValidIndex(ShapedType type, ArrayRef index); + static bool isValidIndex(RankedShapedType type, ArrayRef index); static bool isValidIndex(ElementsAttr elementsAttr, ArrayRef index); diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -103,7 +103,8 @@ /// Each element attribute value is expected to be an element of 'type'. /// 'type' must be a vector or tensor with static shape. If the element of /// `type` is non-integer/index/float it is assumed to be a string type. - static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, + ArrayRef values); /// Constructs a dense integer elements attribute from an array of integer /// or floating-point values. Each value is expected to be the same bitwidth @@ -112,7 +113,8 @@ template ::is_integer || is_valid_cpp_fp_type::value>> - static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { + static DenseElementsAttr get(const RankedShapedType &type, + ArrayRef values) { const char *data = reinterpret_cast(values.data()); return getRawIntOrFloat( type, ArrayRef(data, values.size() * sizeof(T)), sizeof(T), @@ -124,7 +126,7 @@ typename = std::enable_if_t::is_integer || is_valid_cpp_fp_type::value || detail::is_complex_t::value>> - static DenseElementsAttr get(const ShapedType &type, T value) { + static DenseElementsAttr get(const RankedShapedType &type, T value) { return get(type, llvm::ArrayRef(value)); } @@ -136,7 +138,8 @@ typename = std::enable_if_t::value && (std::numeric_limits::is_integer || is_valid_cpp_fp_type::value)>> - static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { + static DenseElementsAttr get(const RankedShapedType &type, + ArrayRef values) { const char *data = reinterpret_cast(values.data()); return getRawComplex(type, ArrayRef(data, values.size() * sizeof(T)), sizeof(T), std::numeric_limits::is_integer, @@ -144,43 +147,44 @@ } /// Overload of the above 'get' method that is specialized for boolean values. - static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, ArrayRef values); /// Overload of the above 'get' method that is specialized for StringRef /// values. - static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, + ArrayRef values); /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. 
- static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, ArrayRef values); /// Constructs a dense complex elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. - static DenseElementsAttr get(ShapedType type, + static DenseElementsAttr get(RankedShapedType type, ArrayRef> values); /// Constructs a dense float elements attribute from an array of APFloat /// values. Each APFloat value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. - static DenseElementsAttr get(ShapedType type, ArrayRef values); + static DenseElementsAttr get(RankedShapedType type, ArrayRef values); /// Constructs a dense complex elements attribute from an array of APFloat /// values. Each APFloat value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. - static DenseElementsAttr get(ShapedType type, + static DenseElementsAttr get(RankedShapedType type, ArrayRef> values); /// Construct a dense elements attribute for an initializer_list of values. /// Each value is expected to be the same bitwidth of the element type of /// 'type'. 'type' must be a vector or tensor with static shape. template - static DenseElementsAttr get(const ShapedType &type, + static DenseElementsAttr get(const RankedShapedType &type, const std::initializer_list &list) { return get(type, ArrayRef(list)); } @@ -198,7 +202,7 @@ /// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to /// the linear order of the shape type from MSB to LSB, padded to on the /// right. - static DenseElementsAttr getFromRawBuffer(ShapedType type, + static DenseElementsAttr getFromRawBuffer(RankedShapedType type, ArrayRef rawBuffer); /// Returns true if the given buffer is a valid raw buffer for the given type. @@ -210,7 +214,7 @@ /// /// User code should be prepared for additional, conformant patterns to be /// identified as splats in the future. - static bool isValidRawBuffer(ShapedType type, ArrayRef rawBuffer, + static bool isValidRawBuffer(RankedShapedType type, ArrayRef rawBuffer, bool &detectedSplat); //===--------------------------------------------------------------------===// @@ -590,7 +594,7 @@ /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor /// with static shape. - ShapedType getType() const; + RankedShapedType getType() const; /// Return the element type of this DenseElementsAttr. Type getElementType() const; @@ -611,12 +615,12 @@ /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. - DenseElementsAttr reshape(ShapedType newType); + DenseElementsAttr reshape(RankedShapedType newType); /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but with a different shape for a splat type. The new type must /// have the same element type. - DenseElementsAttr resizeSplat(ShapedType newType); + DenseElementsAttr resizeSplat(RankedShapedType newType); /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has bitcast elements to 'newElType'. 
The new type must have @@ -656,14 +660,15 @@ /// Overload of the raw 'get' method that asserts that the given type is of /// complex type. This method is used to verify type invariants that the /// templatized 'get' method cannot. - static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, + static DenseElementsAttr getRawComplex(RankedShapedType type, + ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); /// Overload of the raw 'get' method that asserts that the given type is of /// integer or floating-point type. This method is used to verify type /// invariants that the templatized 'get' method cannot. - static DenseElementsAttr getRawIntOrFloat(ShapedType type, + static DenseElementsAttr getRawIntOrFloat(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); @@ -781,7 +786,7 @@ /// resource, but may be changed if necessary to ensure uniqueness during /// insertion. static DenseResourceElementsAttrBase - get(ShapedType type, StringRef blobName, AsmResourceBlob blob); + get(RankedShapedType type, StringRef blobName, AsmResourceBlob blob); /// Return the data of this attribute as an ArrayRef if it is present, /// returns std::nullopt otherwise. @@ -910,12 +915,12 @@ /// Get an instance of a DenseFPElementsAttr with the given arguments. This /// simply wraps the DenseElementsAttr::get calls. template - static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) { + static DenseFPElementsAttr get(const RankedShapedType &type, Arg &&arg) { return DenseElementsAttr::get(type, llvm::ArrayRef(arg)) .template cast(); } template - static DenseFPElementsAttr get(const ShapedType &type, + static DenseFPElementsAttr get(const RankedShapedType &type, const std::initializer_list &list) { return DenseElementsAttr::get(type, list) .template cast(); @@ -952,12 +957,12 @@ /// Get an instance of a DenseIntElementsAttr with the given arguments. This /// simply wraps the DenseElementsAttr::get calls. template - static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) { + static DenseIntElementsAttr get(const RankedShapedType &type, Arg &&arg) { return DenseElementsAttr::get(type, llvm::ArrayRef(arg)) .template cast(); } template - static DenseIntElementsAttr get(const ShapedType &type, + static DenseIntElementsAttr get(const RankedShapedType &type, const std::initializer_list &list) { return DenseElementsAttr::get(type, list) .template cast(); @@ -1034,7 +1039,7 @@ namespace mlir { -/// Given a list of strides (in which ShapedType::kDynamic +/// Given a list of strides (in which RankedShapedType::kDynamic /// represents a dynamic value), return the single result AffineMap which /// represents the linearized strided layout map. Dimensions correspond to the /// offset followed by the strides in order. 
Symbols are inserted for each diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -246,8 +246,10 @@ dense<[10.0, 11.0]> : tensor<2xf32> ``` }]; - let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, - "ArrayRef":$rawData); + let parameters = (ins + AttributeSelfTypeParameter<"", "RankedShapedType">:$type, + "ArrayRef":$rawData + ); let extraClassDeclaration = [{ using DenseElementsAttr::empty; using DenseElementsAttr::getNumElements; @@ -292,7 +294,7 @@ static void convertEndianOfArrayRefForBEmachine(ArrayRef inRawData, MutableArrayRef outRawData, - ShapedType type); + RankedShapedType type); /// Convert endianess of input for big-endian(BE) machines. The number of /// elements of `inRawData` is `numElements`, and each element has @@ -314,7 +316,7 @@ /// /// If the `values` array only has a single element, then this constructs /// splat of that value. - static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, + static DenseElementsAttr getRaw(RankedShapedType type, size_t storageWidth, ArrayRef values); /// Constructs a dense elements attribute from an array of raw APInt values. @@ -323,7 +325,7 @@ /// /// If the `values` array only has a single element, then this constructs /// splat of that value. - static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, + static DenseElementsAttr getRaw(RankedShapedType type, size_t storageWidth, ArrayRef values); /// Get or create a new dense elements attribute instance with the given raw @@ -331,19 +333,19 @@ /// /// If the `values` array only has a single element, then this constructs /// splat of that value. - static DenseElementsAttr getRaw(ShapedType type, ArrayRef data); + static DenseElementsAttr getRaw(RankedShapedType type, ArrayRef data); /// Overload of the raw 'get' method that asserts that the given type is of /// complex type. This method is used to verify type invariants that the /// templatized 'get' method cannot. - static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, + static DenseElementsAttr getRawComplex(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); /// Overload of the raw 'get' method that asserts that the given type is of /// integer or floating-point type. This method is used to verify type /// invariants that the templatized 'get' method cannot. 
- static DenseElementsAttr getRawIntOrFloat(ShapedType type, + static DenseElementsAttr getRawIntOrFloat(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); @@ -386,10 +388,12 @@ dense<["example1", "example2"]> : tensor<2x!foo.string> ``` }]; - let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, - "ArrayRef":$value); + let parameters = (ins + AttributeSelfTypeParameter<"", "RankedShapedType">:$type, + "ArrayRef":$value + ); let builders = [ - AttrBuilderWithInferredContext<(ins "ShapedType":$type, + AttrBuilderWithInferredContext<(ins "RankedShapedType":$type, "ArrayRef":$values), [{ return $_get(type.getContext(), type, values, /* isSplat */(values.size() == 1)); @@ -460,12 +464,12 @@ ``` }]; let parameters = (ins - AttributeSelfTypeParameter<"", "ShapedType">:$type, + AttributeSelfTypeParameter<"", "RankedShapedType">:$type, ResourceHandleParameter<"DenseResourceElementsHandle">:$rawHandle ); let builders = [ AttrBuilderWithInferredContext<(ins - "ShapedType":$type, "DenseResourceElementsHandle":$handle + "RankedShapedType":$type, "DenseResourceElementsHandle":$handle )> ]; let extraClassDeclaration = [{ @@ -476,7 +480,7 @@ /// for the key of the new handle for the `blob` resource, but may be /// changed if necessary to ensure uniqueness during insertion. static DenseResourceElementsAttr get( - ShapedType type, StringRef blobName, AsmResourceBlob blob + RankedShapedType type, StringRef blobName, AsmResourceBlob blob ); public: @@ -839,11 +843,13 @@ ``` }]; - let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, - "DenseIntElementsAttr":$indices, - "DenseElementsAttr":$values); + let parameters = (ins + AttributeSelfTypeParameter<"", "RankedShapedType">:$type, + "DenseIntElementsAttr":$indices, + "DenseElementsAttr":$values + ); let builders = [ - AttrBuilderWithInferredContext<(ins "ShapedType":$type, + AttrBuilderWithInferredContext<(ins "RankedShapedType":$type, "DenseElementsAttr":$indices, "DenseElementsAttr":$values), [{ assert(indices.getType().getElementType().isInteger(64) && @@ -986,7 +992,7 @@ Strides must be positive and the offset must be non-negative. Both the strides and the offset may be _dynamic_, i.e. their value may not be known at compile time. This is expressed as a `?` in the assembly syntax and as - `ShapedType::kDynamic` in the code. Stride and offset values + `RankedShapedType::kDynamic` in the code. Stride and offset values must satisfy the constraints above at runtime, the behavior is undefined otherwise. diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -48,14 +48,12 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> { let cppNamespace = "::mlir"; let description = [{ - This interface provides a common API for interacting with multi-dimensional - container types. These types contain a shape and an element type. + This interface provides a common API for interacting with ranked or unranked + container types. These types have an element type and may have a shape. - A shape is a list of sizes corresponding to the dimensions of the container. - If the number of dimensions in the shape is unknown, the shape is "unranked". - If the number of dimensions is known, the shape "ranked". 
The sizes of the - dimensions of the shape must be positive, or kDynamic (in which case the - size of the dimension is dynamic, or not statically known). + If the number of dimensions in the shape is unknown, the shape is + "unranked". If the number of dimensions is known, the shape "ranked". Ranked + types should implement `RankedShapedType`, a sub-interface of `ShapedType`. }]; let methods = [ InterfaceMethod<[{ @@ -71,108 +69,121 @@ Returns the element type of this shaped type. }], "::mlir::Type", "getElementType">, + ]; + let extraSharedClassDeclaration = [{ + /// Return a clone of this type with the given new shape and element type. + auto clone(::llvm::ArrayRef shape, Type elementType) { + return $_type.cloneWith(shape, elementType); + } + /// Return a clone of this type with the given new shape. + auto clone(::llvm::ArrayRef shape) { + return $_type.cloneWith(shape, $_type.getElementType()); + } + /// Return a clone of this type with the given new element type. + auto clone(::mlir::Type elementType) { + return $_type.cloneWith(/*shape=*/std::nullopt, elementType); + } - InterfaceMethod<[{ - Returns if this type is ranked, i.e. it has a known number of dimensions. - }], - "bool", "hasRank">, + /// If an element type is an integer or a float, return its width. + /// Otherwise, abort. + unsigned getElementTypeBitWidth() const { + return $_type.getElementType().getIntOrFloatBitWidth(); + } + }]; +} + +//===----------------------------------------------------------------------===// +// RankedShapedType +//===----------------------------------------------------------------------===// +def RankedShapedTypeInterface + : TypeInterface<"RankedShapedType", [ShapedTypeInterface]> { + let cppNamespace = "::mlir"; + let description = [{ + This interface provides a common API for interacting with multi-dimensional + container types that have a known rank. These types contain a shape and an + element type. + + A shape is a list of sizes corresponding to the dimensions of the container. + The sizes of the dimensions of the shape must be positive, or kDynamic (in + which case the size of the dimension is dynamic). + }]; + let methods = [ InterfaceMethod<[{ - Returns the shape of this type if it is ranked, otherwise asserts. - }], - "::llvm::ArrayRef", "getShape">, + Returns the shape of this type. + }], "::llvm::ArrayRef", "getShape">, ]; let extraClassDeclaration = [{ static constexpr int64_t kDynamic = std::numeric_limits::min(); - /// Whether the given dimension size indicates a dynamic dimension. + /// Return "true" if the given dimension size indicates a dynamic dimension. static constexpr bool isDynamic(int64_t dValue) { - return dValue == kDynamic; + return dValue == kDynamic; } - /// Whether the given shape has any size that indicates a dynamic dimension. + /// Return "true" if the given shape has any size that indicates a dynamic + /// dimension. static bool isDynamicShape(ArrayRef dSizes) { return any_of(dSizes, [](int64_t dSize) { return isDynamic(dSize); }); } - /// Return the number of elements present in the given shape. + /// Return the number of elements present in the given shape. Asserts that + /// all dimensions are static. static int64_t getNumElements(ArrayRef shape); }]; let extraSharedClassDeclaration = [{ - /// Return a clone of this type with the given new shape and element type. - auto clone(::llvm::ArrayRef shape, Type elementType) { - return $_type.cloneWith(shape, elementType); - } - /// Return a clone of this type with the given new shape. 
- auto clone(::llvm::ArrayRef shape) { - return $_type.cloneWith(shape, $_type.getElementType()); - } - /// Return a clone of this type with the given new element type. - auto clone(::mlir::Type elementType) { - return $_type.cloneWith(/*shape=*/std::nullopt, elementType); - } - - /// If an element type is an integer or a float, return its width. Otherwise, - /// abort. - unsigned getElementTypeBitWidth() const { - return $_type.getElementType().getIntOrFloatBitWidth(); - } - - /// If this is a ranked type, return the rank. Otherwise, abort. + /// Return the rank of this type. int64_t getRank() const { - assert($_type.hasRank() && "cannot query rank of unranked shaped type"); return $_type.getShape().size(); } /// If it has static shape, return the number of elements. Otherwise, abort. int64_t getNumElements() const { - assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); - return ::mlir::ShapedType::getNumElements($_type.getShape()); + assert(hasStaticShape() + && "cannot get element count of dynamically shaped type"); + return ::mlir::RankedShapedType::getNumElements($_type.getShape()); } - /// Returns true if this dimension has a dynamic size (for ranked types); - /// aborts for unranked types. + /// Return "true" if this dimension has a dynamic size. bool isDynamicDim(unsigned idx) const { assert(idx < getRank() && "invalid index for shaped type"); - return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]); + return ::mlir::RankedShapedType::isDynamic($_type.getShape()[idx]); } - /// Returns if this type has a static shape, i.e. if the type is ranked and - /// all dimensions have known size (>= 0). + /// Return "true" if this type has a static shape, i.e., if all dimensions + /// have known size (>= 0). bool hasStaticShape() const { - return $_type.hasRank() && - !::mlir::ShapedType::isDynamicShape($_type.getShape()); + return !::mlir::RankedShapedType::isDynamicShape($_type.getShape()); } - /// Returns if this type has a static shape and the shape is equal to - /// `shape` return true. + /// Return "true" if this type has a static shape and the shape is equal to + /// `shape`. bool hasStaticShape(::llvm::ArrayRef shape) const { return hasStaticShape() && $_type.getShape() == shape; } - /// If this is a ranked type, return the number of dimensions with dynamic - /// size. Otherwise, abort. + /// Return the number of dimensions with dynamic size. int64_t getNumDynamicDims() const { - return llvm::count_if($_type.getShape(), ::mlir::ShapedType::isDynamic); + return llvm::count_if($_type.getShape(), ::mlir::RankedShapedType::isDynamic); } - /// If this is ranked type, return the size of the specified dimension. - /// Otherwise, abort. + /// Return the size of the specified dimension. int64_t getDimSize(unsigned idx) const { assert(idx < getRank() && "invalid index for shaped type"); return $_type.getShape()[idx]; } - /// Returns the position of the dynamic dimension relative to just the dynamic - /// dimensions, given its `index` within the shape. + /// Return the position of the dynamic dimension relative to just the + /// dynamic dimensions, given its `index` within the shape. 
unsigned getDynamicDimIndex(unsigned index) const { assert(index < getRank() && "invalid index"); - assert(::mlir::ShapedType::isDynamic(getDimSize(index)) && "invalid index"); + assert(::mlir::RankedShapedType::isDynamic(getDimSize(index)) + && "invalid index"); return llvm::count_if($_type.getShape().take_front(index), - ::mlir::ShapedType::isDynamic); + ::mlir::RankedShapedType::isDynamic); } }]; } diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -87,9 +87,6 @@ /// Returns if this type is ranked, i.e. it has a known number of dimensions. bool hasRank() const; - /// Returns the shape of this tensor type. - ArrayRef getShape() const; - /// Clone this type with the given shape and element type. If the /// provided shape is `None`, the current shape of the type is used. TensorType cloneWith(std::optional> shape, @@ -123,9 +120,6 @@ /// Returns if this type is ranked, i.e. it has a known number of dimensions. bool hasRank() const; - /// Returns the shape of this memref type. - ArrayRef getShape() const; - /// Clone this type with the given shape and element type. If the /// provided shape is `None`, the current shape of the type is used. BaseMemRefType cloneWith(std::optional> shape, @@ -445,7 +439,7 @@ /// symbols. /// /// A stride specification is a list of integer values that are either static -/// or dynamic (encoded with ShapedType::kDynamic). Strides encode +/// or dynamic (encoded with RankedShapedType::kDynamic). Strides encode /// the distance in the number of elements between successive entries along a /// particular dimension. LogicalResult getStridesAndOffset(MemRefType t, diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -376,9 +376,8 @@ // MemRefType //===----------------------------------------------------------------------===// -def Builtin_MemRef : Builtin_Type<"MemRef", [ - ShapedTypeInterface - ], "BaseMemRefType"> { +def Builtin_MemRef + : Builtin_Type<"MemRef", [RankedShapedTypeInterface], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; let description = [{ Syntax: @@ -629,15 +628,15 @@ "unsigned":$memorySpaceInd)> ]; let extraClassDeclaration = [{ - using ShapedType::Trait::clone; - using ShapedType::Trait::getElementTypeBitWidth; - using ShapedType::Trait::getRank; - using ShapedType::Trait::getNumElements; - using ShapedType::Trait::isDynamicDim; - using ShapedType::Trait::hasStaticShape; - using ShapedType::Trait::getNumDynamicDims; - using ShapedType::Trait::getDimSize; - using ShapedType::Trait::getDynamicDimIndex; + using ShapedType::Trait::clone; + using ShapedType::Trait::getElementTypeBitWidth; + using RankedShapedType::Trait::getRank; + using RankedShapedType::Trait::getNumElements; + using RankedShapedType::Trait::isDynamicDim; + using RankedShapedType::Trait::hasStaticShape; + using RankedShapedType::Trait::getNumDynamicDims; + using RankedShapedType::Trait::getDimSize; + using RankedShapedType::Trait::getDynamicDimIndex; /// This is a builder type that keeps local references to arguments. /// Arguments that are passed into the builder must outlive the builder. 
@@ -711,9 +710,8 @@ // RankedTensorType //===----------------------------------------------------------------------===// -def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [ - ShapedTypeInterface - ], "TensorType"> { +def Builtin_RankedTensor + : Builtin_Type<"RankedTensor", [RankedShapedTypeInterface], "TensorType"> { let summary = "Multi-dimensional array with a fixed number of dimensions"; let description = [{ Syntax: @@ -794,15 +792,15 @@ }]> ]; let extraClassDeclaration = [{ - using ShapedType::Trait::clone; - using ShapedType::Trait::getElementTypeBitWidth; - using ShapedType::Trait::getRank; - using ShapedType::Trait::getNumElements; - using ShapedType::Trait::isDynamicDim; - using ShapedType::Trait::hasStaticShape; - using ShapedType::Trait::getNumDynamicDims; - using ShapedType::Trait::getDimSize; - using ShapedType::Trait::getDynamicDimIndex; + using ShapedType::Trait::clone; + using ShapedType::Trait::getElementTypeBitWidth; + using RankedShapedType::Trait::getRank; + using RankedShapedType::Trait::getNumElements; + using RankedShapedType::Trait::isDynamicDim; + using RankedShapedType::Trait::hasStaticShape; + using RankedShapedType::Trait::getNumDynamicDims; + using RankedShapedType::Trait::getDimSize; + using RankedShapedType::Trait::getDynamicDimIndex; /// This is a builder type that keeps local references to arguments. /// Arguments that are passed into the builder must outlive the builder. @@ -883,9 +881,8 @@ // UnrankedMemRefType //===----------------------------------------------------------------------===// -def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [ - ShapedTypeInterface - ], "BaseMemRefType"> { +def Builtin_UnrankedMemRef + : Builtin_Type<"UnrankedMemRef", [ShapedTypeInterface], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; let description = [{ Syntax: @@ -933,15 +930,6 @@ let extraClassDeclaration = [{ using ShapedType::Trait::clone; using ShapedType::Trait::getElementTypeBitWidth; - using ShapedType::Trait::getRank; - using ShapedType::Trait::getNumElements; - using ShapedType::Trait::isDynamicDim; - using ShapedType::Trait::hasStaticShape; - using ShapedType::Trait::getNumDynamicDims; - using ShapedType::Trait::getDimSize; - using ShapedType::Trait::getDynamicDimIndex; - - ArrayRef getShape() const { return std::nullopt; } /// [deprecated] Returns the memory space in old raw integer representation. /// New `Attribute getMemorySpace()` method should be used instead. 
@@ -955,9 +943,8 @@ // UnrankedTensorType //===----------------------------------------------------------------------===// -def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [ - ShapedTypeInterface - ], "TensorType"> { +def Builtin_UnrankedTensor + : Builtin_Type<"UnrankedTensor", [ShapedTypeInterface], "TensorType"> { let summary = "Multi-dimensional array with unknown dimensions"; let description = [{ Syntax: @@ -986,15 +973,6 @@ let extraClassDeclaration = [{ using ShapedType::Trait::clone; using ShapedType::Trait::getElementTypeBitWidth; - using ShapedType::Trait::getRank; - using ShapedType::Trait::getNumElements; - using ShapedType::Trait::isDynamicDim; - using ShapedType::Trait::hasStaticShape; - using ShapedType::Trait::getNumDynamicDims; - using ShapedType::Trait::getDimSize; - using ShapedType::Trait::getDynamicDimIndex; - - ArrayRef getShape() const { return std::nullopt; } }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; @@ -1004,7 +982,8 @@ // VectorType //===----------------------------------------------------------------------===// -def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> { +def Builtin_Vector + : Builtin_Type<"Vector", [RankedShapedTypeInterface], "Type"> { let summary = "Multi-dimensional SIMD vector type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -253,7 +253,8 @@ // For a ShapedType, verify that it has a static shape. def HasStaticShapePred : - CPred<"$_self.cast<::mlir::ShapedType>().hasStaticShape()">; + And<[CPred<"$_self.isa<::mlir::RankedShapedType>()">, + CPred<"$_self.cast<::mlir::RankedShapedType>().hasStaticShape()">]>; // Whether a type is a TupleType. def IsTupleTypePred : CPred<"$_self.isa<::mlir::TupleType>()">; @@ -548,20 +549,20 @@ descr # " of " # AnyTypeOf.summary # " values", cppClassName>; // Whether a shaped type is ranked. -def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">; +def HasRankPred : CPred<"$_self.isa<::mlir::RankedShapedType>()">; // Whether a shaped type has one of the specified ranks. class HasAnyRankOfPred ranks> : And<[ HasRankPred, Or().getRank() + CPred<[{$_self.cast<::mlir::RankedShapedType>().getRank() == }] # rank>)>]>; // Whether a shaped type has a rank greater than or equal of the specified rank. class HasRankGreaterOrEqualPred : And<[ HasRankPred, - CPred<[{$_self.cast<::mlir::ShapedType>().getRank() >= }] # rank> + CPred<[{$_self.cast<::mlir::RankedShapedType>().getRank() >= }] # rank> ]>; // Vector types. @@ -1464,8 +1465,7 @@ CPred<"$_self.isa<::mlir::DenseFPElementsAttr>() &&" "$_self.cast<::mlir::DenseFPElementsAttr>().getType()." "getElementType().isF" # width # "() && " - // Check that this is ranked and has the specified shape. - "$_self.cast<::mlir::DenseFPElementsAttr>().getType().hasRank() && " + // Check that this has the specified shape. "$_self.cast<::mlir::DenseFPElementsAttr>().getType().getShape() == " "::mlir::ArrayRef({" # !interleave(dims, ", ") # "})">, width # "-bit float elements attribute of shape [" # @@ -2432,13 +2432,15 @@ // TODO: Improve the autogenerated error messages. 
class Rank : - StrFunc<"$" # name # ".getType().cast<::mlir::ShapedType>().getRank()">; + StrFunc<"$" # name # ".getType().cast<::mlir::RankedShapedType>()" + ".getRank()">; class Shape : - StrFunc<"$" # name # ".getType().cast<::mlir::ShapedType>().getShape()">; + StrFunc<"$" # name # ".getType().cast<::mlir::RankedShapedType>()" + ".getShape()">; class ElementCount : - StrFunc<"$" # name # ".getType().cast<::mlir::ShapedType>()" + StrFunc<"$" # name # ".getType().cast<::mlir::RankedShapedType>()" ".getNumElements()">; class ElementType : StrFunc<"getElementTypeOrSelf($" # name # ")">; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1236,7 +1236,7 @@ } /// Parse a dimension list of a tensor or memref type. This populates the - /// dimension list, using ShapedType::kDynamic for the `?` dimensions if + /// dimension list, using RankedShapedType::kDynamic for the `?` dimensions if /// `allowDynamic` is set and errors out on `?` otherwise. Parsing the /// trailing `x` is configurable. /// diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -70,7 +70,7 @@ /// Returns whether the index'th dimension is dynamic. /// Requires: shape is ranked. bool isDynamicDim(int index) const { - return ShapedType::isDynamic(getDimSize(index)); + return RankedShapedType::isDynamic(getDimSize(index)); } /// Returns whether the shape is fully static. @@ -100,7 +100,7 @@ /// The components consist of /// - A ranked or unranked shape with the dimension specification match those /// of ShapeType's getShape() (e.g., dynamic dimension represented using -/// ShapedType::kDynamic) +/// RankedShapedType::kDynamic) /// - A element type, may be unset (nullptr) /// - A attribute, may be unset (nullptr) /// Used by ShapedType type inferences. 
@@ -114,10 +114,10 @@ ShapedTypeComponents(Type elementType) : elementType(elementType), attr(nullptr), ranked(false) {} ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) { - ranked = shapedType.hasRank(); + ranked = isa(shapedType); elementType = shapedType.getElementType(); if (ranked) - dims = llvm::to_vector<4>(shapedType.getShape()); + dims = llvm::to_vector<4>(cast(shapedType).getShape()); } ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) { ranked = adaptor.hasRank(); diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -150,13 +150,13 @@ /*defaultImplementation=*/ >, InterfaceMethod< - /*desc=*/"Return the ShapedType.", - /*retTy=*/"::mlir::ShapedType", + /*desc=*/"Return the RankedShapedType.", + /*retTy=*/"::mlir::RankedShapedType", /*methodName=*/"getShapedType", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/ - "return $_op.getSource().getType().template cast<::mlir::ShapedType>();" + "return $_op.getSource().getType().template cast<::mlir::RankedShapedType>();" >, InterfaceMethod< /*desc=*/"Return the VectorType.", diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -50,7 +50,7 @@ `getArrayAttrMaxRanks()`[0] (resp. [1], [2]). 3. if an entry of `static_offsets` (resp. `static_sizes`, `static_strides`) is equal to a special sentinel value, namely - `ShapedType::kDynamic`, then the corresponding entry is a dynamic + `RankedShapedType::kDynamic`, then the corresponding entry is a dynamic offset (resp. size, stride). 4. a variadic `offset` (resp. `sizes`, `strides`) operand must be present for each dynamic offset (resp. size, stride). 
@@ -204,7 +204,7 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::ShapedType::isDynamic(static_offsets()[idx]); + return ::mlir::RankedShapedType::isDynamic(static_offsets()[idx]); }] >, InterfaceMethod< @@ -214,7 +214,7 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::ShapedType::isDynamic(static_sizes()[idx]); + return ::mlir::RankedShapedType::isDynamic(static_sizes()[idx]); }] >, InterfaceMethod< @@ -224,7 +224,7 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::ShapedType::isDynamic(static_strides()[idx]); + return ::mlir::RankedShapedType::isDynamic(static_strides()[idx]); }] >, InterfaceMethod< @@ -280,7 +280,7 @@ assert($_op.isDynamicOffset(idx) && "expected dynamic offset"); auto numDynamic = getNumDynamicEntriesUpToIdx( static_offsets(), - ::mlir::ShapedType::isDynamic, + ::mlir::RankedShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + numDynamic; }] @@ -297,7 +297,7 @@ /*defaultImplementation=*/[{ assert($_op.isDynamicSize(idx) && "expected dynamic size"); auto numDynamic = getNumDynamicEntriesUpToIdx( - static_sizes(), ::mlir::ShapedType::isDynamic, idx); + static_sizes(), ::mlir::RankedShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + offsets().size() + numDynamic; }] @@ -315,7 +315,7 @@ assert($_op.isDynamicStride(idx) && "expected dynamic stride"); auto numDynamic = getNumDynamicEntriesUpToIdx( static_strides(), - ::mlir::ShapedType::isDynamic, + ::mlir::RankedShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + offsets().size() + sizes().size() + numDynamic; diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -472,7 +472,7 @@ /// Build a dense attribute instance with the parsed elements and the given /// shaped type. - DenseElementsAttr getAttr(SMLoc loc, ShapedType type); + DenseElementsAttr getAttr(SMLoc loc, RankedShapedType type); ArrayRef getShape() const { return shape; } @@ -489,7 +489,7 @@ DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); /// Build a Dense attribute with hex data for the given type. - DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); + DenseElementsAttr getHexAttr(SMLoc loc, RankedShapedType type); /// Parse a single element, returning failure if it isn't a valid element /// literal. For example: @@ -538,7 +538,8 @@ /// Build a dense attribute instance with the parsed elements and the given /// shaped type. -DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { +DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, + RankedShapedType type) { Type eltType = type.getElementType(); // Check to see if we parse the literal from a hex string. @@ -709,7 +710,8 @@ } /// Build a Dense attribute with hex data for the given type. -DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) { +DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, + RankedShapedType type) { Type elementType = type.getElementType(); if (!elementType.isIntOrIndexOrFloat() && !elementType.isa()) { p.emitError(loc) @@ -1039,7 +1041,7 @@ /// elements-literal-type ::= vector-type | ranked-tensor-type /// /// This method also checks the type has static shape. 
-ShapedType Parser::parseElementsLiteralType(Type type) { +RankedShapedType Parser::parseElementsLiteralType(Type type) { // If the user didn't provide a type, parse the colon type for the literal. if (!type) { if (parseToken(Token::colon, "expected ':'")) @@ -1048,15 +1050,13 @@ return nullptr; } - auto sType = type.dyn_cast(); - if (!sType) { - emitError("elements literal must be a shaped type"); + auto sType = type.dyn_cast(); + if (!sType || !sType.hasStaticShape()) { + emitError( + "elements literal type must be a ranked shaped type with static shape"); return nullptr; } - if (!sType.hasStaticShape()) - return (emitError("elements literal type must have static shape"), nullptr); - return sType; } @@ -1072,7 +1072,7 @@ // of the type. Type indiceEltType = builder.getIntegerType(64); if (consumeIf(Token::greater)) { - ShapedType type = parseElementsLiteralType(attrType); + RankedShapedType type = parseElementsLiteralType(attrType); if (!type) return nullptr; @@ -1114,7 +1114,7 @@ // 2-dimensional shape where the second dimension is the rank of the type. // Given that the parsed indices is a splat, we know that we only have one // indice and thus one for the first dimension. - ShapedType indicesType; + RankedShapedType indicesType; if (indiceParser.getShape().empty()) { indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); } else { @@ -1152,7 +1152,7 @@ // must fit into int64_t limits. auto parseStrideOrOffset = [&]() -> std::optional { if (consumeIf(Token::question)) - return ShapedType::kDynamic; + return RankedShapedType::kDynamic; SMLoc loc = getToken().getLoc(); auto emitWrongTokenError = [&] { diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -262,7 +262,7 @@ /// Parse a dense elements attribute. Attribute parseDenseElementsAttr(Type attrType); - ShapedType parseElementsLiteralType(Type type); + RankedShapedType parseElementsLiteralType(Type type); /// Parse a dense resource elements attribute. 
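The parser above now only accepts ranked, statically shaped types for elements literals. A minimal sketch of building a conforming attribute programmatically, assuming an mlir::Builder `builder` is in scope (illustrative, not part of the patch):

    // tensor<2x3xf32> qualifies; tensor<?xf32> and tensor<*xf32> are rejected
    // by parseElementsLiteralType above.
    auto type = mlir::RankedTensorType::get({2, 3}, builder.getF32Type());
    auto attr = mlir::DenseElementsAttr::get(
        type, llvm::ArrayRef<float>{0.f, 1.f, 2.f, 3.f, 4.f, 5.f});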
Attribute parseDenseResourceElementsAttr(Type attrType); diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -528,7 +528,7 @@ if (consumeIf(Token::question)) { if (!allowDynamic) return emitError(loc, "expected static shape"); - dimensions.push_back(ShapedType::kDynamic); + dimensions.push_back(RankedShapedType::kDynamic); } else { int64_t value; if (failed(parseIntegerInDimensionList(value))) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -668,10 +668,10 @@ message.append(py::repr(py::cast(elementAttr))); throw SetPyError(PyExc_ValueError, message); } - if (!mlirTypeIsAShaped(shapedType) || - !mlirShapedTypeHasStaticShape(shapedType)) { + if (!mlirTypeIsARankedShaped(shapedType) || + !mlirRankedShapedTypeHasStaticShape(shapedType)) { std::string message = - "Expected a static ShapedType for the shaped_type parameter: "; + "Expected a static RankedShapedType for the shaped_type parameter: "; message.append(py::repr(py::cast(shapedType))); throw SetPyError(PyExc_ValueError, message); } @@ -809,7 +809,7 @@ template py::buffer_info bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { - intptr_t rank = mlirShapedTypeGetRank(shapedType); + intptr_t rank = mlirRankedShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. // Buffer is configured for read-only access below. Type *data = static_cast( @@ -817,7 +817,7 @@ // Prepare the shape for the buffer_info. SmallVector shape; for (intptr_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); + shape.push_back(mlirRankedShapedTypeGetDimSize(shapedType, i)); // Prepare the strides for the buffer_info. 
SmallVector strides; if (mlirDenseElementsAttrIsSplat(*this)) { @@ -827,7 +827,7 @@ for (intptr_t i = 1; i < rank; ++i) { intptr_t strideFactor = 1; for (intptr_t j = i; j < rank; ++j) - strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + strideFactor *= mlirRankedShapedTypeGetDimSize(shapedType, j); strides.push_back(sizeof(Type) * strideFactor); } strides.push_back(sizeof(Type)); @@ -1070,7 +1070,7 @@ c.def_static( "get_fully_dynamic", [](int64_t rank, DefaultingPyMlirContext ctx) { - auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); + auto dynamic = mlirRankedShapedTypeGetDynamicStrideOrOffset(); std::vector strides(rank); std::fill(strides.begin(), strides.end(), dynamic); MlirAttribute attr = mlirStridedLayoutAttrGet( diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -323,6 +323,12 @@ using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { + c.def_property_readonly( + "has_rank", + [](PyShapedType &self) -> bool { + return mlirTypeIsARankedShaped(self); + }, + "Returns whether the given shaped type is ranked."); c.def_property_readonly( "element_type", [](PyShapedType &self) { @@ -330,91 +336,84 @@ return PyType(self.getContext(), t); }, "Returns the element type of the shaped type."); - c.def_property_readonly( - "has_rank", - [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, - "Returns whether the given shaped type is ranked."); + } +}; + +class PyRankedShapedType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedShaped; + static constexpr const char *pyClassName = "RankedShapedType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { c.def_property_readonly( "rank", - [](PyShapedType &self) { - self.requireHasRank(); - return mlirShapedTypeGetRank(self); - }, + [](PyShapedType &self) { return mlirRankedShapedTypeGetRank(self); }, "Returns the rank of the given ranked shaped type."); c.def_property_readonly( "has_static_shape", [](PyShapedType &self) -> bool { - return mlirShapedTypeHasStaticShape(self); + return mlirRankedShapedTypeHasStaticShape(self); }, - "Returns whether the given shaped type has a static shape."); + "Returns whether the given ranked shaped type has a static shape."); c.def( "is_dynamic_dim", [](PyShapedType &self, intptr_t dim) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicDim(self, dim); + return mlirRankedShapedTypeIsDynamicDim(self, dim); }, py::arg("dim"), - "Returns whether the dim-th dimension of the given shaped type is " - "dynamic."); + "Returns whether the dim-th dimension of the given ranked shaped type " + "is dynamic."); c.def( "get_dim_size", [](PyShapedType &self, intptr_t dim) { - self.requireHasRank(); - return mlirShapedTypeGetDimSize(self, dim); + return mlirRankedShapedTypeGetDimSize(self, dim); }, py::arg("dim"), "Returns the dim-th dimension of the given ranked shaped type."); c.def_static( "is_dynamic_size", - [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + [](int64_t size) -> bool { + return mlirRankedShapedTypeIsDynamicSize(size); + }, py::arg("dim_size"), "Returns whether the given dimension size indicates a dynamic " "dimension."); c.def( "is_dynamic_stride_or_offset", [](PyShapedType &self, int64_t val) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicStrideOrOffset(val); + return 
mlirRankedShapedTypeIsDynamicStrideOrOffset(val); }, py::arg("dim_size"), "Returns whether the given value is used as a placeholder for dynamic " - "strides and offsets in shaped types."); + "strides and offsets in ranked shaped types."); c.def_property_readonly( "shape", [](PyShapedType &self) { - self.requireHasRank(); - std::vector shape; - int64_t rank = mlirShapedTypeGetRank(self); + int64_t rank = mlirRankedShapedTypeGetRank(self); shape.reserve(rank); for (int64_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(self, i)); + shape.push_back(mlirRankedShapedTypeGetDimSize(self, i)); return shape; }, "Returns the shape of the ranked shaped type as a list of integers."); c.def_static( - "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, - "Returns the value used to indicate dynamic dimensions in shaped " - "types."); + "get_dynamic_size", + []() { return mlirRankedShapedTypeGetDynamicSize(); }, + "Returns the value used to indicate dynamic dimensions in ranked " + "shaped types."); c.def_static( "get_dynamic_stride_or_offset", - []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, + []() { return mlirRankedShapedTypeGetDynamicStrideOrOffset(); }, "Returns the value used to indicate dynamic strides or offsets in " - "shaped types."); - } - -private: - void requireHasRank() { - if (!mlirShapedTypeHasRank(*this)) { - throw SetPyError( - PyExc_ValueError, - "calling this method requires that the type has a rank."); - } + "ranked shaped types."); } }; /// Vector Type subclass - VectorType. -class PyVectorType : public PyConcreteType { +class PyVectorType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; static constexpr const char *pyClassName = "VectorType"; @@ -439,7 +438,7 @@ /// Ranked Tensor Type subclass - RankedTensorType. class PyRankedTensorType - : public PyConcreteType { + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; static constexpr const char *pyClassName = "RankedTensorType"; @@ -495,7 +494,7 @@ }; /// Ranked MemRef Type subclass - MemRefType. 
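A minimal sketch of the renamed C API as exercised by the buffer-protocol and binding code above; `printDims` and its use are illustrative, and the calls compile as C or C++:

    #include "mlir-c/BuiltinTypes.h"

    static void printDims(MlirType type) {
      if (!mlirTypeIsARankedShaped(type))
        return;  // rank/dim queries are only valid on ranked shaped types now
      int64_t rank = mlirRankedShapedTypeGetRank(type);
      for (intptr_t i = 0; i < rank; ++i) {
        if (mlirRankedShapedTypeIsDynamicDim(type, i))
          continue;
        int64_t size = mlirRankedShapedTypeGetDimSize(type, i);
        (void)size;
      }
    }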
-class PyMemRefType : public PyConcreteType { +class PyMemRefType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; static constexpr const char *pyClassName = "MemRefType"; @@ -726,6 +725,7 @@ PyNoneType::bind(m); PyComplexType::bind(m); PyShapedType::bind(m); + PyRankedShapedType::bind(m); PyVectorType::bind(m); PyRankedTensorType::bind(m); PyUnrankedTensorType::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -164,43 +164,46 @@ bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsARankedShaped(MlirType type) { + return unwrap(type).isa(); +} + MlirType mlirShapedTypeGetElementType(MlirType type) { return wrap(unwrap(type).cast().getElementType()); } -bool mlirShapedTypeHasRank(MlirType type) { - return unwrap(type).cast().hasRank(); +int64_t mlirRankedShapedTypeGetRank(MlirType type) { + return unwrap(type).cast().getRank(); } -int64_t mlirShapedTypeGetRank(MlirType type) { - return unwrap(type).cast().getRank(); +bool mlirRankedShapedTypeHasStaticShape(MlirType type) { + return unwrap(type).cast().hasStaticShape(); } -bool mlirShapedTypeHasStaticShape(MlirType type) { - return unwrap(type).cast().hasStaticShape(); +bool mlirRankedShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { + return unwrap(type).cast().isDynamicDim( + static_cast(dim)); } -bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { - return unwrap(type).cast().isDynamicDim( +int64_t mlirRankedShapedTypeGetDimSize(MlirType type, intptr_t dim) { + return unwrap(type).cast().getDimSize( static_cast(dim)); } -int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { - return unwrap(type).cast().getDimSize(static_cast(dim)); +int64_t mlirRankedShapedTypeGetDynamicSize() { + return RankedShapedType::kDynamic; } -int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; } - -bool mlirShapedTypeIsDynamicSize(int64_t size) { - return ShapedType::isDynamic(size); +bool mlirRankedShapedTypeIsDynamicSize(int64_t size) { + return RankedShapedType::isDynamic(size); } -bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { - return ShapedType::isDynamic(val); +bool mlirRankedShapedTypeIsDynamicStrideOrOffset(int64_t val) { + return RankedShapedType::isDynamic(val); } -int64_t mlirShapedTypeGetDynamicStrideOrOffset() { - return ShapedType::kDynamic; +int64_t mlirRankedShapedTypeGetDynamicStrideOrOffset() { + return RankedShapedType::kDynamic; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -204,7 +204,7 @@ size_t i = pair.index(); Value index = pair.value(); Value strideOp; - if (ShapedType::isDynamic(strides[i])) { + if (RankedShapedType::isDynamic(strides[i])) { strideOp = rewriter.create( loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst); } else { @@ -226,7 +226,7 @@ Value sgprOffset = adaptor.getSgprOffset(); if (!sgprOffset) sgprOffset = createI32Constant(rewriter, loc, 0); - if (ShapedType::isDynamic(offset)) + if (RankedShapedType::isDynamic(offset)) sgprOffset = rewriter.create( loc, memrefDescriptor.offset(rewriter, loc), sgprOffset); else if (offset > 0) diff --git 
a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -369,7 +369,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite( arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto srcType = constOp.getType().dyn_cast(); + auto srcType = constOp.getType().dyn_cast(); if (!srcType || srcType.getNumElements() == 1) return failure(); @@ -435,10 +435,13 @@ // attributes; element attributes only works with builtin types. So we need // to prepare another converted builtin types for the destination elements // attribute. - if (dstAttrType.isa()) - dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); - else - dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); + if (auto rankedTensoType = dstAttrType.dyn_cast()) { + dstAttrType = + RankedTensorType::get(rankedTensoType.getShape(), dstElemType); + } else { + dstAttrType = VectorType::get( + dstAttrType.cast().getShape(), dstElemType); + } dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); } @@ -456,7 +459,7 @@ arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type srcType = constOp.getType(); - if (auto shapedType = srcType.dyn_cast()) { + if (auto shapedType = srcType.dyn_cast()) { if (shapedType.getNumElements() != 1) return failure(); srcType = shapedType.getElementType(); diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -54,9 +54,8 @@ // Extract all strides and offsets and verify they are static. auto [strides, offset] = getStridesAndOffset(type); - assert(!ShapedType::isDynamic(offset) && - "expected static offset"); - assert(!llvm::any_of(strides, ShapedType::isDynamic) && + assert(!RankedShapedType::isDynamic(offset) && "expected static offset"); + assert(!llvm::any_of(strides, RankedShapedType::isDynamic) && "expected static strides"); auto convertedType = typeConverter.convertType(type); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -76,14 +76,14 @@ Value index; if (offset != 0) // Skip if offset is zero. - index = ShapedType::isDynamic(offset) + index = RankedShapedType::isDynamic(offset) ? memRefDescriptor.offset(rewriter, loc) : createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { Value increment = indices[i]; if (strides[i] != 1) { // Skip if stride is 1. - Value stride = ShapedType::isDynamic(strides[i]) + Value stride = RankedShapedType::isDynamic(strides[i]) ? 
memRefDescriptor.stride(rewriter, loc, i) : createIndexConstant(rewriter, loc, strides[i]); increment = rewriter.create(loc, increment, stride); @@ -124,14 +124,14 @@ SmallVectorImpl &strides, Value &sizeBytes) const { assert(isConvertibleAndHasIdentityMaps(memRefType) && "layout maps must have been normalized away"); - assert(count(memRefType.getShape(), ShapedType::kDynamic) == + assert(count(memRefType.getShape(), RankedShapedType::kDynamic) == static_cast(dynamicSizes.size()) && "dynamicSizes size doesn't match dynamic sizes count in memref shape"); sizes.reserve(memRefType.getRank()); unsigned dynamicIndex = 0; for (int64_t size : memRefType.getShape()) { - sizes.push_back(size == ShapedType::kDynamic + sizes.push_back(size == RankedShapedType::kDynamic ? dynamicSizes[dynamicIndex++] : createIndexConstant(rewriter, loc, size)); } @@ -147,14 +147,14 @@ if (size == 0) continue; bool useSizeAsStride = stride == 1; - if (size == ShapedType::kDynamic) - stride = ShapedType::kDynamic; - if (stride != ShapedType::kDynamic) + if (size == RankedShapedType::kDynamic) + stride = RankedShapedType::kDynamic; + if (stride != RankedShapedType::kDynamic) stride *= size; if (useSizeAsStride) runningStride = sizes[i]; - else if (stride == ShapedType::kDynamic) + else if (stride == RankedShapedType::kDynamic) runningStride = rewriter.create(loc, runningStride, sizes[i]); else diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -428,10 +428,10 @@ return false; for (int64_t stride : strides) - if (ShapedType::isDynamic(stride)) + if (RankedShapedType::isDynamic(stride)) return false; - return !ShapedType::isDynamic(offset); + return !RankedShapedType::isDynamic(offset); } /// Convert a memref type to a bare pointer to the memref element type. diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -27,8 +27,8 @@ static MemRefType makeStridedLayoutDynamic(MemRefType type) { return MemRefType::Builder(type).setLayout(StridedLayoutAttr::get( - type.getContext(), ShapedType::kDynamic, - SmallVector(type.getRank(), ShapedType::kDynamic))); + type.getContext(), RankedShapedType::kDynamic, + SmallVector(type.getRank(), RankedShapedType::kDynamic))); } /// Helper function to extract the operand types that are passed to the diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -36,7 +36,7 @@ namespace { bool isStaticStrideOrOffset(int64_t strideOrOffset) { - return !ShapedType::isDynamic(strideOrOffset); + return !RankedShapedType::isDynamic(strideOrOffset); } LLVM::LLVMFuncOp getFreeFn(LLVMTypeConverter *typeConverter, ModuleOp module) { @@ -1380,7 +1380,7 @@ Value stride = nullptr; int64_t targetRank = targetMemRefType.getRank(); for (auto i : llvm::reverse(llvm::seq(0, targetRank))) { - if (!ShapedType::isDynamic(strides[i])) { + if (!RankedShapedType::isDynamic(strides[i])) { // If the stride for this dimension is dynamic, then use the product // of the sizes of the inner dimensions. 
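A minimal sketch of the static-layout check the converted patterns above rely on, assuming a mlir::MemRefType `memrefTy` (illustrative, not part of the patch):

    llvm::SmallVector<int64_t> strides;
    int64_t offset;
    if (succeeded(mlir::getStridesAndOffset(memrefTy, strides, offset))) {
      // With this patch the dynamic sentinel and the isDynamic() helper move
      // to RankedShapedType.
      bool fullyStatic =
          !mlir::RankedShapedType::isDynamic(offset) &&
          llvm::none_of(strides, mlir::RankedShapedType::isDynamic);
      (void)fullyStatic;
    }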
stride = createIndexConstant(rewriter, loc, strides[i]); @@ -1634,11 +1634,11 @@ ArrayRef shape, ValueRange dynamicSizes, unsigned idx) const { assert(idx < shape.size()); - if (!ShapedType::isDynamic(shape[idx])) + if (!RankedShapedType::isDynamic(shape[idx])) return createIndexConstant(rewriter, loc, shape[idx]); // Count the number of dynamic dims in range [0, idx] unsigned nDynamic = - llvm::count_if(shape.take_front(idx), ShapedType::isDynamic); + llvm::count_if(shape.take_front(idx), RankedShapedType::isDynamic); return dynamicSizes[nDynamic]; } @@ -1650,7 +1650,7 @@ ArrayRef strides, Value nextSize, Value runningStride, unsigned idx) const { assert(idx < strides.size()); - if (!ShapedType::isDynamic(strides[idx])) + if (!RankedShapedType::isDynamic(strides[idx])) return createIndexConstant(rewriter, loc, strides[idx]); if (nextSize) return runningStride diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -178,7 +178,7 @@ auto storageAttr = spirv::StorageClassAttr::get(memRefType.getContext(), *storage); if (auto rankedType = memRefType.dyn_cast()) { - return MemRefType::get(memRefType.getShape(), memRefType.getElementType(), + return MemRefType::get(rankedType.getShape(), memRefType.getElementType(), rankedType.getLayout(), storageAttr); } return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -520,7 +520,8 @@ "All TOSA elementwise ops should only return a single result."); auto results = operation->getResults(); - auto resultTy = operation->getResult(0).getType().dyn_cast(); + auto resultTy = + operation->getResult(0).getType().dyn_cast(); if (!resultTy) return rewriter.notifyMatchFailure(operation, @@ -538,10 +539,10 @@ SmallVector emptyTensors; SmallVector dynDims; - dynDims.resize(results.front().getType().cast().getRank()); + dynDims.resize(results.front().getType().cast().getRank()); for (auto arg : operation->getOperands()) { - auto operandTy = arg.getType().cast(); + auto operandTy = arg.getType().cast(); for (int i = 0; i < operandTy.getRank(); i++) { if (operandTy.isDynamicDim(i) && !dynDims[i]) dynDims[i] = rewriter.create(loc, arg, i); @@ -551,7 +552,7 @@ SmallVector filteredDims = condenseValues(dynDims); for (auto result : results) { - auto resultTy = result.getType().template cast(); + auto resultTy = result.getType().template cast(); emptyTensors.push_back(rewriter.create( loc, resultTy.getShape(), resultTy.getElementType(), filteredDims)); opResultTypes.push_back(result.getType()); @@ -566,7 +567,7 @@ // Input indexing maps may be broadcasted. 
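The TOSA patterns above share one idiom: when an operand is unranked or a dimension is dynamic, the extent is materialized with tensor.dim, otherwise it is a constant. A minimal sketch, assuming `rewriter`, `loc`, a tensor-typed mlir::Value `operand`, and an int64_t dimension index `i` (all illustrative, not part of the patch):

    auto rankedTy = operand.getType().dyn_cast<mlir::RankedShapedType>();
    mlir::Value extent;
    if (!rankedTy || rankedTy.isDynamicDim(i))
      extent = rewriter.create<mlir::tensor::DimOp>(loc, operand, i);
    else
      extent = rewriter.create<mlir::arith::ConstantIndexOp>(
          loc, rankedTy.getDimSize(i));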
for (Value operand : operation->getOperands()) { - ShapedType type = operand.getType().cast(); + auto type = operand.getType().cast(); if (type.getShape() == resultTy.getShape()) { operands.push_back(operand); @@ -733,7 +734,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter) { auto loc = op->getLoc(); - auto inputTy = op->getOperand(0).getType().template cast(); + auto inputTy = op->getOperand(0).getType().template cast(); auto resultTy = op->getResult(0).getType().template cast(); auto elementTy = resultTy.getElementType(); Value input = op->getOperand(0); @@ -799,7 +800,7 @@ SmallVector reassociationMap; uint64_t expandInputRank = - linalgOp.getResults()[0].getType().cast().getRank(); + linalgOp.getResults()[0].getType().cast().getRank(); reassociationMap.resize(expandInputRank); for (uint64_t i = 0; i < expandInputRank; i++) { @@ -848,18 +849,19 @@ auto loc = op.getLoc(); auto input = op->getOperand(0); - auto resultTy = op.getType().cast(); + auto resultTy = op.getType().cast(); SmallVector dynDims; - dynDims.resize(op->getResult(0).getType().cast().getRank()); + dynDims.resize( + op->getResult(0).getType().cast().getRank()); SmallVector inputExprs; inputExprs.resize(resultTy.getRank()); - auto operandTy = input.getType().cast(); + auto operandTy = input.getType().dyn_cast(); for (const auto &permutation : llvm::enumerate(perms.getValues())) { auto index = permutation.index(); auto value = permutation.value().getZExtValue(); - if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) { + if (!operandTy || operandTy.isDynamicDim(index)) { dynDims[value] = rewriter.create(loc, input, index); } inputExprs[value] = rewriter.getAffineDimExpr(index); @@ -893,8 +895,8 @@ PatternRewriter &rewriter) const final { auto loc = op.getLoc(); auto input = op.getInput(); - auto inputTy = op.getInput().getType().cast(); - auto outputTy = op.getOutput().getType().cast(); + auto inputTy = op.getInput().getType().cast(); + auto outputTy = op.getOutput().getType().cast(); unsigned rank = inputTy.getRank(); // This is an illegal configuration. terminate and log an error @@ -1135,7 +1137,8 @@ outputDynSize.push_back(builder.create(input, 3)); // Generate the elementwise operation for casting scaling the input value. 
- auto genericTy = collapseTy.clone(resultTy.getElementType()); + auto genericTy = + collapseTy.clone(resultTy.getElementType()).cast(); Value empty = builder.create( genericTy.getShape(), resultTy.getElementType(), outputDynSize); auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank()); @@ -1282,8 +1285,8 @@ Location loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); auto input = op.getInput(); - auto inputTy = input.getType().cast(); - auto resultTy = op.getType().cast(); + auto inputTy = input.getType().cast(); + auto resultTy = op.getType().cast(); auto resultETy = resultTy.getElementType(); auto imageH = inputTy.getShape()[1]; @@ -1573,8 +1576,8 @@ PatternRewriter &rewriter) const final { auto loc = op.getLoc(); Value input = op.getInput(); - auto inputTy = input.getType().template cast(); - auto resultTy = op.getType().template cast(); + auto inputTy = input.getType().template cast(); + auto resultTy = op.getType().template cast(); auto axis = op.getAxis(); SmallVector dynDims; @@ -1635,9 +1638,9 @@ ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto input = op.getInput1(); - auto inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); auto inputShape = inputTy.getShape(); - auto resultTy = op.getType().cast(); + auto resultTy = op.getType().cast(); auto elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); @@ -1647,7 +1650,7 @@ SmallVector genericShape; for (int i = 0; i < rank; i++) { int64_t dim = multiples[i]; - genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim); + genericShape.push_back(dim == -1 ? RankedShapedType::kDynamic : dim); genericShape.push_back(inputShape[i]); } @@ -1710,8 +1713,8 @@ PatternRewriter &rewriter) const final { auto loc = argmaxOp.getLoc(); Value input = argmaxOp.getInput(); - auto inputTy = input.getType().cast(); - auto resultTy = argmaxOp.getOutput().getType().cast(); + auto inputTy = input.getType().cast(); + auto resultTy = argmaxOp.getOutput().getType().cast(); auto inElementTy = inputTy.getElementType(); auto outElementTy = resultTy.getElementType(); int axis = argmaxOp.getAxis(); @@ -1831,7 +1834,7 @@ auto valuesTy = op.getValues().getType().dyn_cast_or_null(); - auto resultTy = op.getType().cast(); + auto resultTy = op.getType().cast(); if (!valuesTy) return rewriter.notifyMatchFailure(op, "unranked tensors not supported"); @@ -1904,9 +1907,9 @@ auto loc = op.getLoc(); Value input = op.getInput(); Value table = op.getTable(); - auto inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); auto tableTy = table.getType().cast(); - auto resultTy = op.getType().cast(); + auto resultTy = op.getType().cast(); auto inputElementTy = inputTy.getElementType(); auto tableElementTy = tableTy.getElementType(); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -36,7 +36,7 @@ if (llvm::all_of(pad, [](int64_t p) { return p == 0; })) return input; - ShapedType inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); Type inputETy = inputTy.getElementType(); auto inputShape = inputTy.getShape(); @@ -48,7 +48,7 @@ for (int i = 0, s = inputShape.size(); i < s; i++) { auto lowPad = pad[i * 2]; auto highPad = pad[i * 2 + 1]; - if (ShapedType::isDynamic(inputShape[i])) + if (RankedShapedType::isDynamic(inputShape[i])) 
paddedShape.push_back(inputShape[i]); else paddedShape.push_back(inputShape[i] + highPad + lowPad); @@ -67,7 +67,7 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value conv, Value result, ArrayRef indexingMaps) { - ShapedType resultTy = conv.getType().cast(); + auto resultTy = conv.getType().cast(); return rewriter .create( loc, resultTy, ValueRange({bias, conv}), result, indexingMaps, @@ -121,11 +121,11 @@ // Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D static SmallVector inferDynamicDimsForConv( - Location loc, Value input, Value weight, ShapedType resultTy, + Location loc, Value input, Value weight, RankedShapedType resultTy, ArrayRef padAttr, ArrayRef strideAttr, ArrayRef dilationAttr, ArrayRef inputSizeDims, ArrayRef kernelSizeDims, OpBuilder &rewriter) { - ShapedType inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); Type inputETy = inputTy.getElementType(); int64_t inputRank = inputTy.getRank(); @@ -187,11 +187,11 @@ Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().template cast(); - ShapedType weightTy = weight.getType().template cast(); - ShapedType biasTy = bias.getType().template cast(); - ShapedType resultTy = - op->getResult(0).getType().template cast(); + auto inputTy = input.getType().template cast(); + auto weightTy = weight.getType().template cast(); + auto biasTy = bias.getType().template cast(); + auto resultTy = + op->getResult(0).getType().template cast(); Type inputETy = inputTy.getElementType(); Type resultETy = resultTy.getElementType(); @@ -353,10 +353,10 @@ Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + auto inputTy = input.getType().cast(); + auto weightTy = weight.getType().cast(); + auto biasTy = bias.getType().cast(); + auto resultTy = op->getResult(0).getType().cast(); int64_t resultRank = resultTy.getRank(); Type inputETy = inputTy.getElementType(); @@ -426,7 +426,7 @@ // Create the convolution op. 
auto strideAttr = rewriter.getI64TensorAttr(stride); auto dilationAttr = rewriter.getI64TensorAttr(dilation); - ShapedType linalgConvTy = + RankedShapedType linalgConvTy = RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2], weightShape[2], weightShape[3]}, resultETy); @@ -505,24 +505,27 @@ ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); - auto outputTy = op.getType().cast(); + auto outputTy = op.getType().cast(); auto outputElementTy = outputTy.getElementType(); - auto firstOperandTy = op->getOperand(0).getType().cast(); - auto secondOperandTy = op->getOperand(1).getType().cast(); + auto firstOperandTy = + op->getOperand(0).getType().dyn_cast(); + auto secondOperandTy = + op->getOperand(1).getType().dyn_cast(); SmallVector dynDims; - dynDims.resize(op->getResult(0).getType().cast().getRank()); + dynDims.resize( + op->getResult(0).getType().cast().getRank()); - if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) { + if (!firstOperandTy || firstOperandTy.isDynamicDim(0)) { dynDims[0] = rewriter.create(loc, op->getOperand(0), 0); } - if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) { + if (!firstOperandTy || firstOperandTy.isDynamicDim(1)) { dynDims[1] = rewriter.create(loc, op->getOperand(0), 1); } - if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) { + if (!secondOperandTy || secondOperandTy.isDynamicDim(2)) { dynDims[2] = rewriter.create(loc, op->getOperand(1), 2); } @@ -564,26 +567,27 @@ matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); - auto outputTy = op.getType().cast(); + auto outputTy = op.getType().cast(); auto input = op.getInput(); - auto inputTy = input.getType().cast(); + auto inputTy = input.getType().dyn_cast(); auto bias = op.getBias(); auto weight = op.getWeight(); - auto weightTy = weight.getType().cast(); + auto weightTy = weight.getType().cast(); auto weightShape = weightTy.getShape(); auto outputETy = outputTy.getElementType(); SmallVector dynDims; - dynDims.resize(op->getResult(0).getType().cast().getRank()); + dynDims.resize( + op->getResult(0).getType().cast().getRank()); - if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) { + if (!inputTy || inputTy.isDynamicDim(0)) { dynDims[0] = rewriter.create(loc, input, 0); } - if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) { + if (weightTy.isDynamicDim(0)) { dynDims[1] = rewriter.create(loc, weight, 0); } @@ -678,7 +682,7 @@ Value input = op.getInput(); ShapedType inputTy = input.getType().cast(); - ShapedType resultTy = op.getType().template cast(); + auto resultTy = op.getType().template cast(); Type resultETy = inputTy.getElementType(); auto dynamicDimsOr = @@ -750,12 +754,12 @@ ShapedType inputTy = input.getType().cast(); Type inElementTy = inputTy.getElementType(); - ShapedType resultTy = op.getType().template cast(); + auto resultTy = op.getType().template cast(); Type resultETy = op.getType().cast().getElementType(); Type accETy = inElementTy.isa() ? 
rewriter.getI32Type() : inElementTy; - ShapedType accTy = resultTy.clone(accETy); + auto accTy = resultTy.clone(accETy).cast(); auto dynamicDimsOr = checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()}); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -26,7 +26,7 @@ bool isDynamic) { if (isDynamic) { // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1 - intermediateShape = {ShapedType::kDynamic}; + intermediateShape = {RankedShapedType::kDynamic}; return true; } @@ -134,8 +134,8 @@ LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); + auto operandTy = adaptor.getInput1().getType().cast(); + auto resultTy = reshape.getType().template cast(); bool isDynamic = !operandTy.hasStaticShape(); if (isDynamic && resultTy.getRank() != 1) { @@ -172,8 +172,8 @@ LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); + auto operandTy = adaptor.getInput1().getType().cast(); + auto resultTy = reshape.getType().template cast(); bool isDynamic = !operandTy.hasStaticShape(); if (isDynamic && operandTy.getRank() != 1) { @@ -211,8 +211,8 @@ LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); + auto operandTy = adaptor.getInput1().getType().cast(); + auto resultTy = reshape.getType().template cast(); bool isDynamic = !operandTy.hasStaticShape(); SmallVector intermediateShape; @@ -247,14 +247,15 @@ Value input = adaptor.getInput(); SmallVector strides, sizes; ArrayRef starts = sliceOp.getStart(); - strides.resize(sliceOp.getType().template cast().getRank(), 1); + strides.resize( + sliceOp.getType().template cast().getRank(), 1); SmallVector dynSizes; for (const auto &i : llvm::enumerate(sliceOp.getSize())) { int64_t size = i.value(); size_t index = i.index(); - sizes.push_back(size == -1 ? ShapedType::kDynamic : size); - if (!ShapedType::isDynamic(sizes.back())) + sizes.push_back(size == -1 ? 
RankedShapedType::kDynamic : size); + if (!RankedShapedType::isDynamic(sizes.back())) continue; auto dim = rewriter.create(loc, input, index); @@ -284,7 +285,7 @@ auto input = padOp.getInput1(); auto padding = padOp.getPadding(); - ShapedType inputTy = input.getType().cast(); + auto inputTy = input.getType().cast(); Type elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); @@ -355,7 +356,8 @@ LogicalResult matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto inputType = op.getOperand(0).getType().template cast(); + auto inputType = + op.getOperand(0).getType().template cast(); auto resultType = op.getType().dyn_cast(); Location loc = op.getLoc(); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -135,7 +135,7 @@ strides.back() != 1) return std::nullopt; int64_t stride = strides[strides.size() - 2]; - if (stride == ShapedType::kDynamic) + if (stride == RankedShapedType::kDynamic) return std::nullopt; return stride; } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1313,9 +1313,9 @@ // layout. auto sizes = memRefType.getShape(); for (int index = 0, e = strides.size() - 1; index < e; ++index) { - if (ShapedType::isDynamic(sizes[index + 1]) || - ShapedType::isDynamic(strides[index]) || - ShapedType::isDynamic(strides[index + 1])) + if (RankedShapedType::isDynamic(sizes[index + 1]) || + RankedShapedType::isDynamic(strides[index]) || + RankedShapedType::isDynamic(strides[index + 1])) return std::nullopt; if (strides[index] != strides[index + 1] * sizes[index + 1]) return std::nullopt; @@ -1360,7 +1360,7 @@ if (!targetStrides) return failure(); // Only support static strides for now, regardless of contiguity. - if (llvm::any_of(*targetStrides, ShapedType::isDynamic)) + if (llvm::any_of(*targetStrides, RankedShapedType::isDynamic)) return failure(); auto int64Ty = IntegerType::get(rewriter.getContext(), 64); diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -830,7 +830,7 @@ for (unsigned r = 0; r < rank; r++) { cstWithShapeBounds.addBound(BoundType::LB, r, 0); int64_t dimSize = memRefType.getDimSize(r); - if (ShapedType::isDynamic(dimSize)) + if (RankedShapedType::isDynamic(dimSize)) continue; cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1); } @@ -852,7 +852,7 @@ // If no constant bound is found, then it can always be bound by the // memref's dim size if the latter has a constant size along this dim. auto dimSize = memRefType.getDimSize(d); - if (dimSize == ShapedType::kDynamic) + if (dimSize == RankedShapedType::kDynamic) return std::nullopt; diffConstant = dimSize; // Lower bound becomes 0. diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -96,7 +96,7 @@ // Put together alloc operands for any dynamic dimensions of the memref. 
SmallVector allocOperands; for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) { - if (dim.value() == ShapedType::kDynamic) + if (dim.value() == RankedShapedType::kDynamic) allocOperands.push_back(bOuter.createOrFold( forOp.getLoc(), oldMemRef, dim.index())); } diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -41,7 +41,7 @@ continue; } - assert(cast(value.getType()).isDynamicDim(*dim) && + assert(cast(value.getType()).isDynamicDim(*dim) && "expected dynamic dim"); if (isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1824,7 +1824,7 @@ bool isDynDim = isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context); if (isDynDim) { - newShape[d] = ShapedType::kDynamic; + newShape[d] = RankedShapedType::kDynamic; } else { // The lower bound for the shape is always zero. std::optional ubConst = fac.getConstantBound64(BoundType::UB, d); diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -86,7 +86,7 @@ continue; } - assert(value.getType().cast().isDynamicDim(*dim) && + assert(value.getType().cast().isDynamicDim(*dim) && "expected dynamic dim"); if (value.getType().isa()) { // A tensor dimension is used: generate a tensor.dim. diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -145,7 +145,7 @@ auto &shape = resultDims[shapedValue.cast().getResultNumber()]; for (const auto &dim : enumerate(tensorType.getShape())) - if (ShapedType::isDynamic(dim.value())) + if (RankedShapedType::isDynamic(dim.value())) dynamicSizes.push_back(shape[dim.index()].get()); } } @@ -838,9 +838,9 @@ // Case 2: Ranked memref type. 
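The bufferization change just below derives a memref type whose offset and strides are all dynamic; a minimal sketch of that construction, assuming a mlir::RankedTensorType `tensorTy` (illustrative, not part of the patch):

    int64_t dyn = mlir::RankedShapedType::kDynamic;
    auto layout = mlir::StridedLayoutAttr::get(
        tensorTy.getContext(), /*offset=*/dyn,
        llvm::SmallVector<int64_t>(tensorTy.getRank(), dyn));
    auto memrefTy = mlir::MemRefType::get(
        tensorTy.getShape(), tensorTy.getElementType(), layout);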
auto rankedTensorType = tensorType.cast(); - int64_t dynamicOffset = ShapedType::kDynamic; + int64_t dynamicOffset = RankedShapedType::kDynamic; SmallVector dynamicStrides(rankedTensorType.getRank(), - ShapedType::kDynamic); + RankedShapedType::kDynamic); auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(), dynamicOffset, dynamicStrides); return MemRefType::get(rankedTensorType.getShape(), diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -47,7 +47,7 @@ failed(getStridesAndOffset(target, targetStrides, targetOffset))) return false; auto dynamicToStatic = [](int64_t a, int64_t b) { - return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b); + return RankedShapedType::isDynamic(a) && !RankedShapedType::isDynamic(b); }; if (dynamicToStatic(sourceOffset, targetOffset)) return false; @@ -69,7 +69,7 @@ auto loc = value.getLoc(); SmallVector dynamicOperands; for (int i = 0; i < destType.getRank(); ++i) { - if (destType.getShape()[i] != ShapedType::kDynamic) + if (destType.getShape()[i] != RankedShapedType::kDynamic) continue; auto index = b.createOrFold(loc, i); Value size = b.create(loc, value, index); @@ -132,7 +132,7 @@ void mlir::bufferization::populateDynamicDimSizes( OpBuilder &b, Location loc, Value shapedValue, SmallVector &dynamicDims) { - auto shapedType = shapedValue.getType().cast(); + auto shapedType = shapedValue.getType().cast(); for (int64_t i = 0; i < shapedType.getRank(); ++i) { if (shapedType.isDynamicDim(i)) { if (shapedType.isa()) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp @@ -43,7 +43,7 @@ /// exceed the stack space. static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, unsigned maxRankOfAllocatedMemRef) { - auto type = alloc.getType().dyn_cast(); + auto type = alloc.getType().dyn_cast(); if (!type || !alloc.getDefiningOp()) return false; if (!type.hasStaticShape()) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -28,9 +28,9 @@ SmallVector strides; if (failed(getStridesAndOffset(type, strides, offset))) return false; - if (!llvm::all_of(strides, ShapedType::isDynamic)) + if (!llvm::all_of(strides, RankedShapedType::isDynamic)) return false; - if (!ShapedType::isDynamic(offset)) + if (!RankedShapedType::isDynamic(offset)) return false; return true; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -147,12 +147,12 @@ // We parsed a generic dimension list, but vectors only support two forms: // - single non-dynamic entry in the list (fixed vector); - // - two elements, the first dynamic (indicated by ShapedType::kDynamic) - // and the second - // non-dynamic (scalable vector). 
+ // - two elements, the first dynamic (indicated by + // RankedShapedType::kDynamic) and the second non-dynamic (scalable + // vector). if (dims.empty() || dims.size() > 2 || - ((dims.size() == 2) ^ (ShapedType::isDynamic(dims[0]))) || - (dims.size() == 2 && ShapedType::isDynamic(dims[1]))) { + ((dims.size() == 2) ^ (RankedShapedType::isDynamic(dims[0]))) || + (dims.size() == 2 && RankedShapedType::isDynamic(dims[1]))) { parser.emitError(dimPos) << "expected '? x x ' or ' x '"; return Type(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -514,8 +514,8 @@ } static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { - auto shapedType = source.getType().cast(); - if (!shapedType.hasRank() || shapedType.isDynamicDim(dim)) + auto shapedType = source.getType().dyn_cast(); + if (!shapedType || shapedType.isDynamicDim(dim)) return createOrFoldDimOp(b, loc, source, dim); return b.getIndexAttr(shapedType.getDimSize(dim)); } @@ -644,7 +644,7 @@ for (OpOperand *opOperand : getDpsInitOperands()) { SmallVector shapes; for (int64_t dim : llvm::seq(0, getRank(opOperand))) { - auto shapedType = opOperand->get().getType().cast(); + auto shapedType = opOperand->get().getType().cast(); if (!shapedType.isDynamicDim(dim)) { // Static dim: Return IntegerAttr. shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim))); @@ -731,7 +731,7 @@ // Verify only static cases since we can't get exact dimension sizes and loop // ranges for dynamic cases in this stage. - if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) { + if (llvm::none_of(endLoopRangeValues, RankedShapedType::isDynamic)) { for (int64_t &range : endLoopRangeValues) range -= 1; for (OpOperand &opOperand : linalgOp->getOpOperands()) { @@ -743,7 +743,7 @@ ArrayRef shape = linalgOp.getShape(&opOperand); for (auto dim : llvm::seq(0, shape.size())) { // Ignore dynamic dimension or the case that the dimension size is 0 - if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0) + if (RankedShapedType::isDynamic(shape[dim]) || shape[dim] == 0) continue; // The first index or last index should be the maximum or the minimum in diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1180,7 +1180,7 @@ // The shape of each input must match the shape of the output. 
auto outputShape = getInit().getType().getShape(); for (Type inputArgType : TypeRange{getInputs()}) { - auto inputElemShape = inputArgType.cast().getShape(); + auto inputElemShape = inputArgType.cast().getShape(); if (inputElemShape != outputShape) { return emitOpError() << "expected shape of input (" << inputElemShape << ") to match shape of output (" << outputShape @@ -1250,7 +1250,8 @@ } SmallVector ReduceOp::getIteratorTypesArray() { - int64_t inputRank = getInputs()[0].getType().cast().getRank(); + int64_t inputRank = + getInputs()[0].getType().cast().getRank(); SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); for (int64_t reductionDim : getDimensions()) @@ -1259,7 +1260,8 @@ } ArrayAttr ReduceOp::getIndexingMaps() { - int64_t inputRank = getInputs()[0].getType().cast().getRank(); + int64_t inputRank = + getInputs()[0].getType().cast().getRank(); SmallVector affineMaps( getNumDpsInputs(), AffineMap::getMultiDimIdentityMap(inputRank, getContext())); @@ -1360,8 +1362,8 @@ ArrayRef dimensionsRef = getDimensions(); for (int64_t i = 1; i < getNumDpsInputs(); ++i) { - if (getInputs()[i].getType().cast().getShape() != - getInputs()[0].getType().cast().getShape()) { + if (getInputs()[i].getType().cast().getShape() != + getInputs()[0].getType().cast().getShape()) { return emitOpError() << "expects all inputs to have the same shapes. " "Shape at input-index " << i @@ -1369,16 +1371,16 @@ } } for (int64_t i = 1; i < getNumDpsInits(); ++i) { - if (getInits()[i].getType().cast().getShape() != - getInits()[0].getType().cast().getShape()) { + if (getInits()[i].getType().cast().getShape() != + getInits()[0].getType().cast().getShape()) { return emitOpError() << "expects all outputs to have the same shapes. " "Shape at output-index " << i << " is not equal to the shape at output-index 0."; } } - auto inputType = getInputs()[0].getType().cast(); - auto initType = getInits()[0].getType().cast(); + auto inputType = getInputs()[0].getType().cast(); + auto initType = getInits()[0].getType().cast(); DenseSet dimensionsToReduce; for (int64_t dimension : dimensionsRef) { diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp @@ -37,7 +37,7 @@ int64_t flatDimCtr = 0; for (Value operand : linalgOp->getOperands()) { assert(flatDimPos >= flatDimCtr && "invalid pos"); - auto shapedType = operand.getType().cast(); + auto shapedType = operand.getType().cast(); if (flatDimPos < flatDimCtr + shapedType.getRank()) { cstr.bound(value) < cstr.getExpr(operand, flatDimPos - flatDimCtr); break; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2180,7 +2180,7 @@ } staticSplitPoint = - parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic); + parser.getBuilder().getI64IntegerAttr(RankedShapedType::kDynamic); } result.addAttribute( @@ -2193,7 +2193,7 @@ void SplitOp::print(OpAsmPrinter &printer) { printer << " " << getTarget() << " after "; int64_t staticSplitSize = static_cast(getStaticSplitPoint()); - if (staticSplitSize != ShapedType::kDynamic) + if (staticSplitSize != RankedShapedType::kDynamic) printer << staticSplitSize; else printer << getDynamicSplitPoint(); @@ -2201,12 
   printer.printOptionalAttrDict(getOperation()->getAttrs(),
                                 {getStaticSplitPointAttrName()});
   printer << " : " << getTarget().getType();
-  if (staticSplitSize == ShapedType::kDynamic)
+  if (staticSplitSize == RankedShapedType::kDynamic)
     printer << ", " << getDynamicSplitPoint().getType();
 }
 LogicalResult SplitOp::verify() {
-  if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamic) ^
+  if ((static_cast(getStaticSplitPoint()) !=
+       RankedShapedType::kDynamic) ^
       (getDynamicSplitPoint() == nullptr)) {
     return emitOpError() << "expects either a dynamic or a static split "
                             "point to be provided";
@@ -2559,7 +2560,7 @@
   unsigned dynamicPos = 0;
   Builder builder(getContext());
   for (int64_t size : tileSizes) {
-    if (size == ShapedType::kDynamic) {
+    if (size == RankedShapedType::kDynamic) {
       results.push_back(dynamic[dynamicPos++]);
     } else {
       results.push_back(builder.getIndexAttr(size));
@@ -2975,7 +2976,7 @@
   unsigned dynamicPos = 0;
   Builder builder(getContext());
   for (int64_t size : tileSizes) {
-    if (size == ShapedType::kDynamic) {
+    if (size == RankedShapedType::kDynamic) {
       results.push_back(dynamic[dynamicPos++]);
     } else {
       results.push_back(builder.getIndexAttr(size));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -64,7 +64,8 @@
   if (genericOp.getNumDpsInits() != 1)
     return failure();
-  auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
+  auto outputType =
+      genericOp.getResultTypes().front().dyn_cast<RankedShapedType>();
   // Require the output types to be static given that we are generating
   // constants.
   if (!outputType || !outputType.hasStaticShape())
@@ -174,7 +175,7 @@
   auto inputShapes = llvm::to_vector<4>(
       llvm::map_range(genericOp.getInputs(), [](Value value) {
-        return value.getType().cast<ShapedType>().getShape();
+        return value.getType().cast<RankedShapedType>().getShape();
       }));
   // Given a `linearIndex`, remap it to a linear index to access linalg op
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -74,15 +74,16 @@
 FailureOr>
 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
-  auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
-  auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
-  auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+  auto inputType = convOp.getInputs()[0].getType().dyn_cast<RankedShapedType>();
+  auto filterType =
+      convOp.getInputs()[1].getType().dyn_cast<RankedShapedType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<RankedShapedType>();
-  if (!filterType.hasStaticShape())
+  if (!filterType || !filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
-  if (!inputType.hasStaticShape())
+  if (!inputType || !inputType.hasStaticShape())
     return rewriter.notifyMatchFailure(convOp,
                                        "expected a static shape for the input");
@@ -360,15 +361,16 @@
 FailureOr>
 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
-  auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
-  auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
-  auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+  auto inputType = convOp.getInputs()[0].getType().dyn_cast<RankedShapedType>();
+  auto filterType =
+      convOp.getInputs()[1].getType().dyn_cast<RankedShapedType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<RankedShapedType>();
-  if (!filterType.hasStaticShape())
+  if (!filterType || !filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
-  if (!inputType.hasStaticShape())
+  if (!inputType || !inputType.hasStaticShape())
     return rewriter.notifyMatchFailure(convOp,
                                        "expected a static shape for the input");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -49,7 +49,8 @@
 ///
 /// Returns true if tensorType can be detensored.
 bool canBeDetensored(TensorType tensorType) {
-  return tensorType.hasRank() && tensorType.getRank() == 0;
+  auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
+  return rankedTensorType && rankedTensorType.getRank() == 0;
 }
 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -620,7 +620,7 @@
     if (expandedShape.size() == 1)
       continue;
     for (int64_t shape : expandedShape.drop_front()) {
-      if (ShapedType::isDynamic(shape)) {
+      if (RankedShapedType::isDynamic(shape)) {
         return rewriter.notifyMatchFailure(
             genericOp, "cannot expand due to index semantics and dynamic dims");
       }
@@ -716,7 +716,7 @@
       [&](int64_t dim) { return rewriter.create(loc, dim); });
   Value newIndex = rewriter.create(loc, expandedDims.front());
   for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
-    assert(!ShapedType::isDynamic(std::get<0>(it)));
+    assert(!RankedShapedType::isDynamic(std::get<0>(it)));
     AffineExpr idx, acc;
     bindDims(rewriter.getContext(), idx, acc);
     newIndex = rewriter.create(
@@ -1512,8 +1512,9 @@
     Value collapsedOpResult =
         collapsedGenericOp->getResult(originalResult.index());
     auto originalResultType =
-        originalResult.value().getType().cast<ShapedType>();
-    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
+        originalResult.value().getType().cast<RankedShapedType>();
+    auto collapsedOpResultType =
+        collapsedOpResult.getType().cast<RankedShapedType>();
     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
       AffineMap indexingMap =
           genericOp.getIndexingMapMatchingResult(originalResult.value());
@@ -1755,7 +1756,7 @@
       modifiedOutput = true;
       SmallVector dynamicDims;
       for (const auto &dim : llvm::enumerate(operandType.getShape())) {
-        if (dim.value() != ShapedType::kDynamic)
+        if (dim.value() != RankedShapedType::kDynamic)
           continue;
         dynamicDims.push_back(rewriter.createOrFold(
             loc, operandVal, dim.index()));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
@@ -92,7 +92,7 @@
     sizes.reserve(offsets.size());
     for (const auto &shape : llvm::enumerate(
              source.getType().cast().getShape())) {
-      if (ShapedType::isDynamic(shape.value())) {
+      if (RankedShapedType::isDynamic(shape.value())) {
         sizes.push_back(
             rewriter.create(loc, source, shape.index())
                 .getResult());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -155,11 +155,11 @@
     if (!tensorType)
       continue;
     unsigned rank = tensorType.getRank();
-    SmallVector staticOffsetsVector(
-        rank, ShapedType::kDynamic);
-    SmallVector staticSizesVector(rank, ShapedType::kDynamic);
-    SmallVector staticStridesVector(
-        rank, ShapedType::kDynamic);
+    SmallVector staticOffsetsVector(rank,
+                                    RankedShapedType::kDynamic);
+    SmallVector staticSizesVector(rank, RankedShapedType::kDynamic);
+    SmallVector staticStridesVector(rank,
+                                    RankedShapedType::kDynamic);
     resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
         tensorType, staticOffsetsVector, staticSizesVector,
         staticStridesVector));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -621,7 +621,8 @@
   sizes = SmallVector(nPackedLoops, rewriter.getIndexAttr(1));
   for (int64_t sz : transposedTensorType.getShape()) {
     // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
-    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
+    assert(!RankedShapedType::isDynamic(sz) &&
+           "padded tensor needs static sizes");
     sizes.push_back(rewriter.getIndexAttr(sz));
   }
   // strides = [1 .. 1].
@@ -692,7 +693,7 @@
   }
   // Create the packed tensor.
-  SmallVector packedShape(nPackedLoops, ShapedType::kDynamic);
+  SmallVector packedShape(nPackedLoops, RankedShapedType::kDynamic);
   // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
   llvm::append_range(packedShape, transposedTensorType->getShape());
   auto hoistedPackedTensorType = RankedTensorType::get(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -68,7 +68,7 @@
   // Fallback dynamic buffer.
   auto dynamicBufferType =
-      MemRefType::get(ShapedType::kDynamic, b.getIntegerType(8));
+      MemRefType::get(RankedShapedType::kDynamic, b.getIntegerType(8));
   Value mul = b.createOrFold(
       b.create(width), allocSize);
   if (options.useAlloca)
@@ -95,7 +95,7 @@
   Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize,
                              layout, alignment);
   SmallVector dynSizes(boundingSubViewSize.size(),
-                       ShapedType::kDynamic);
+                       RankedShapedType::kDynamic);
   Value view = b.createOrFold(
       MemRefType::get(dynSizes, viewType.getElementType()), buffer, zero,
       boundingSubViewSize);
@@ -247,7 +247,8 @@
     partialSizes.push_back(
         b.createOrFold(loc, subView, resultDimIdx++));
   }
-  SmallVector dynSizes(fullSizes.size(), ShapedType::kDynamic);
+  SmallVector dynSizes(fullSizes.size(),
+                       RankedShapedType::kDynamic);
   // If a callback is not specified, then use the default implementation for
   // allocating the promoted buffer.
   std::optional fullLocalView =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -51,7 +51,8 @@
   }
   SmallVector loopRanges = op.getStaticLoopRanges();
   int64_t reductionDimSize = loopRanges[reductionDim];
-  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
+  if (reductionDimSize == RankedShapedType::kDynamic ||
+      reductionDimSize % ratio != 0)
     return b.notifyMatchFailure(
         op, "Reduction dimension not divisible by split ratio");
   if (op.getNumDpsInits() != 1)
@@ -262,7 +263,7 @@
   unsigned reductionDimPos = dims[0];
   SmallVector loopRanges = op.getStaticLoopRanges();
   int64_t reductionDimSize = loopRanges[reductionDimPos];
-  if (reductionDimSize == ShapedType::kDynamic ||
+  if (reductionDimSize == RankedShapedType::kDynamic ||
       reductionDimSize % splitFactor != 0 ||
       insertSplitDimension >= loopRanges.size())
     return b.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -290,7 +290,7 @@
     int64_t oldIdx = idx < insertSplitDimension ? idx : idx - 1;
     int64_t dim = oldShape[oldIdx];
     newOutputShape.push_back(dim);
-    if (ShapedType::isDynamic(dim))
+    if (RankedShapedType::isDynamic(dim))
       dynamicDims.push_back(b.createOrFold(
           loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
   }
@@ -366,7 +366,7 @@
   // Then create a new reduction that only reduce the newly added dimension
   // from the previous op.
   int64_t intermRank =
-      partialReduce[0].getType().cast<ShapedType>().getRank();
+      partialReduce[0].getType().cast<RankedShapedType>().getRank();
   AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
   SmallVector reductionIteratorTypes;
   SmallVector exprs;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -73,7 +73,7 @@
                      ? packPaddings[opOperand->getOperandNumber()]
                      : false;
   bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
-    return ShapedType::isDynamic(shape[dim]);
+    return RankedShapedType::isDynamic(shape[dim]);
   });
   if (!nofold && hasStaticShape)
     return opOperand->get();
@@ -312,7 +312,8 @@
     }
     // Fail hoisting if the operand shape is not fully static.
-    if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) {
+    if (llvm::any_of(paddedOp.getShape(&opOperand),
+                     RankedShapedType::isDynamic)) {
       (void)rewriter.notifyMatchFailure(linalgOp,
                                         "non static padding shape -- skip");
       continue;
@@ -810,8 +811,8 @@
 PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                             PatternRewriter &rewriter) const {
-  auto inputShapedType = padOp.getSource().getType().cast<ShapedType>();
-  auto resultShapedType = padOp.getResult().getType().cast<ShapedType>();
+  auto inputShapedType = padOp.getSource().getType();
+  auto resultShapedType = padOp.getResult().getType();
   // Bail on non-static shapes.
   if (!inputShapedType.hasStaticShape())
@@ -986,7 +987,7 @@
   }
   Location loc = packOp.getLoc();
-  ShapedType inputType = packOp.getSourceType();
+  auto inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
   assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
                       [](int64_t val) { return val == 1; }));
@@ -1058,7 +1059,7 @@
     }
     readSizes.push_back(dimAndTileMapping[i]);
     readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
-                            .value_or(ShapedType::kDynamic));
+                            .value_or(RankedShapedType::kDynamic));
   }
   Type elemType = packOp.getSourceType().getElementType();
   auto readType = RankedTensorType::get(readShape, elemType);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -212,7 +212,7 @@
                 std::optional maybeMaskingMap);
   // Holds the compile-time static sizes of the iteration space to vectorize.
-  // Dynamic dimensions are represented using ShapedType::kDynamic.
+  // Dynamic dimensions are represented using RankedShapedType::kDynamic.
   SmallVector iterSpaceStaticSizes;
   /// Holds the value sizes of the iteration space to vectorize. Static
@@ -237,7 +237,7 @@
                                               LinalgOp linalgOp) {
   // TODO: Support 0-d vectors.
   for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
-    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
+    if (!RankedShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
       // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
@@ -287,7 +287,7 @@
   LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");
-  if (ShapedType::isDynamicShape(canonicalVecShape))
+  if (RankedShapedType::isDynamicShape(canonicalVecShape))
     return failure();
   // Initialize iteration space static sizes.
@@ -864,7 +864,7 @@
       targetShape.back() == 1)
     return VectorMemoryAccessKind::Gather;
-  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+  auto inputShape = extractOp.getTensor().getType();
   // 2. Assume that it's a gather load when reading _from_ a tensor for which
   // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
@@ -1302,7 +1302,7 @@
   if (!inputVectorSizes.empty()) {
     assert(inputVectorSizes.size() == linalgOp.getNumLoops() &&
            "Input vector sizes don't match the number of loops");
-    assert(!ShapedType::isDynamicShape(inputVectorSizes) &&
+    assert(!RankedShapedType::isDynamicShape(inputVectorSizes) &&
            "Input vector sizes can't have dynamic dimensions");
     assert(
         llvm::all_of(
@@ -1310,7 +1310,7 @@
             [](std::tuple sizePair) {
               int64_t staticSize = std::get<0>(sizePair);
               int64_t inputSize = std::get<1>(sizePair);
-              return ShapedType::isDynamic(staticSize) ||
+              return RankedShapedType::isDynamic(staticSize) ||
                      staticSize <= inputSize;
             }) &&
         "Input vector sizes must be greater than or equal to iteration space "
@@ -1923,7 +1923,7 @@
   if (!padValue)
     return failure();
   // Dynamic shapes not supported.
-  if (!padOp.getResult().getType().cast<ShapedType>().hasStaticShape())
+  if (!padOp.getResult().getType().hasStaticShape())
     return failure();
   // Pad result not used as destination.
   if (insertOp.getDest() == padOp.getResult())
@@ -2163,18 +2163,18 @@
 // Convolution vectorization patterns
 //===----------------------------------------------------------------------===//
-template
-static void bindShapeDims(ShapedType shapedType) {}
+template static void bindShapeDims(RankedShapedType shapedType) {}
 template
-static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
+static void bindShapeDims(RankedShapedType shapedType, IntTy &val,
+                          IntTy2 &...vals) {
   val = shapedType.getShape()[N];
   bindShapeDims(shapedType, vals...);
 }
 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
 template
-static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
+static void bindShapeDims(RankedShapedType shapedType, IntTy &...vals) {
   bindShapeDims<0>(shapedType, vals...);
 }
@@ -2245,9 +2245,9 @@
     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
     resShaped = linalgOp.getDpsInitOperand(0)->get();
-    lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
-    rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
-    resShapedType = resShaped.getType().dyn_cast<ShapedType>();
+    lhsShapedType = lhsShaped.getType().dyn_cast<RankedShapedType>();
+    rhsShapedType = rhsShaped.getType().dyn_cast<RankedShapedType>();
+    resShapedType = resShaped.getType().dyn_cast<RankedShapedType>();
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
     // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
@@ -2829,7 +2829,7 @@
   bool isPoolExt = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
-  ShapedType lhsShapedType, rhsShapedType, resShapedType;
+  RankedShapedType lhsShapedType, rhsShapedType, resShapedType;
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
   // Returns true iff it is a valid conv/pooling op.
diff --git a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
@@ -52,20 +52,20 @@
 OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
                                int64_t dim) {
-  auto shapedType = val.getType().cast<ShapedType>();
-  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
+  auto shapedType = val.getType().dyn_cast<RankedShapedType>();
+  if (!shapedType || shapedType.isDynamicDim(dim))
     return createOrFoldDimOp(b, loc, val, dim);
   return b.getIndexAttr(shapedType.getDimSize(dim));
 }
 SmallVector createDynamicDimensions(OpBuilder &b, Location loc, Value val) {
-  auto shapedType = val.getType().cast<ShapedType>();
-  assert(shapedType.hasRank() && "`val` must have a static rank");
+  auto shapedType = val.getType().dyn_cast<RankedShapedType>();
+  assert(shapedType && "`val` must have a static rank");
   SmallVector res;
   res.reserve(shapedType.getRank());
   for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
-    if (dim.value() == ShapedType::kDynamic)
+    if (dim.value() == RankedShapedType::kDynamic)
       res.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
   }
   return res;
@@ -73,8 +73,8 @@
 SmallVector getMixedDimensions(OpBuilder &b, Location loc, Value val) {
-  auto shapedType = val.getType().cast<ShapedType>();
-  assert(shapedType.hasRank() && "`val` must have a static rank");
+  auto shapedType = val.getType().dyn_cast<RankedShapedType>();
+  assert(shapedType && "`val` must have a static rank");
   SmallVector dynamicDims = createDynamicDimensions(b, loc, val);
   return getMixedValues(shapedType.getShape(), dynamicDims, b);
 }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -685,7 +685,7 @@
                         ArrayRef lbs, ArrayRef ubs,
                         ArrayRef subShapeSizes, bool omitPartialTileCheck) {
-  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
+  auto shapedType = valueToTile.getType().dyn_cast<RankedShapedType>();
   assert(shapedType && "only shaped types can be tiled");
   ArrayRef shape = shapedType.getShape();
   int64_t rank = shapedType.getRank();
@@ -742,7 +742,7 @@
     int64_t shapeSize = shape[r];
     std::optional sizeCst = getConstantIntValue(size);
     auto hasTileSizeOne = sizeCst && *sizeCst == 1;
-    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
+    auto dividesEvenly = sizeCst && !RankedShapedType::isDynamic(shapeSize) &&
                          ((shapeSize % *sizeCst) == 0);
     if (!hasTileSizeOne && !dividesEvenly) {
       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -31,23 +31,20 @@
 namespace saturated_arith {
 struct Wrapper {
   static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (RankedShapedType::isDynamic(v)) ? Wrapper{true, 0}
+                                            : Wrapper{false, v};
   }
   static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (RankedShapedType::isDynamic(v)) ? Wrapper{true, 0}
+                                            : Wrapper{false, v};
  }
   static Wrapper size(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  int64_t asOffset() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
-  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() {
-    return saturated ? ShapedType::kDynamic : v;
+    return (RankedShapedType::isDynamic(v)) ? Wrapper{true, 0}
+                                            : Wrapper{false, v};
   }
+  int64_t asOffset() { return saturated ? RankedShapedType::kDynamic : v; }
+  int64_t asSize() { return saturated ? RankedShapedType::kDynamic : v; }
+  int64_t asStride() { return saturated ? RankedShapedType::kDynamic : v; }
   bool operator==(Wrapper other) {
     return (saturated && other.saturated) ||
            (!saturated && !other.saturated && v == other.v);
@@ -151,7 +148,7 @@
 /// - `memRefTy == memref>`
 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
 ///   `getStridesAndOffset`), and
-/// - `isDynamic == ShapedType::isDynamic`
+/// - `isDynamic == RankedShapedType::isDynamic`
 /// Will yield: `values == [2, 1]`
 static void constifyIndexValues(
     SmallVectorImpl &values, MemRefType memRefTy,
@@ -303,7 +300,7 @@
   for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
     int64_t dimSize = memrefType.getDimSize(dim);
     // If this is already static dimension, keep it.
-    if (!ShapedType::isDynamic(dimSize)) {
+    if (!RankedShapedType::isDynamic(dimSize)) {
       newShapeConstants.push_back(dimSize);
       continue;
     }
@@ -315,7 +312,7 @@
       newShapeConstants.push_back(constSizeArg.getZExtValue());
     } else {
       // Dynamic shape dimension not folded; copy dynamicSize from old memref.
-      newShapeConstants.push_back(ShapedType::kDynamic);
+      newShapeConstants.push_back(RankedShapedType::kDynamic);
       dynamicSizes.push_back(dynamicSize);
     }
     dynamicDimPos++;
@@ -718,22 +715,21 @@
   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
     auto ss = std::get<0>(it), st = std::get<1>(it);
     if (ss != st)
-      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
+      if (RankedShapedType::isDynamic(ss) && !RankedShapedType::isDynamic(st))
         return false;
   }
   // If cast is towards more static offset along any dimension, don't fold.
   if (sourceOffset != resultOffset)
-    if (ShapedType::isDynamic(sourceOffset) &&
-        !ShapedType::isDynamic(resultOffset))
+    if (RankedShapedType::isDynamic(sourceOffset) &&
+        !RankedShapedType::isDynamic(resultOffset))
       return false;
   // If cast is towards more static strides along any dimension, don't fold.
   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
     auto ss = std::get<0>(it), st = std::get<1>(it);
     if (ss != st)
-      if (ShapedType::isDynamic(ss) &&
-          !ShapedType::isDynamic(st))
+      if (RankedShapedType::isDynamic(ss) && !RankedShapedType::isDynamic(st))
        return false;
   }
@@ -766,8 +762,8 @@
   // same. They are also compatible if either one is dynamic (see
   // description of MemRefCastOp for details).
   auto checkCompatible = [](int64_t a, int64_t b) {
-    return (ShapedType::isDynamic(a) ||
-            ShapedType::isDynamic(b) || a == b);
+    return (RankedShapedType::isDynamic(a) ||
+            RankedShapedType::isDynamic(b) || a == b);
   };
   if (!checkCompatible(aOffset, bOffset))
     return false;
@@ -784,8 +780,8 @@
   for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
     int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
-    if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
-        aDim != bDim)
+    if (!RankedShapedType::isDynamic(aDim) &&
+        !RankedShapedType::isDynamic(bDim) && aDim != bDim)
       return false;
   }
   return true;
@@ -1441,7 +1437,7 @@
 SmallVector ExtractStridedMetadataOp::getConstifiedMixedSizes() {
   SmallVector values = getAsOpFoldResult(getSizes());
   constifyIndexValues(values, getSource().getType(), getContext(),
-                      getConstantSizes, ShapedType::isDynamic);
+                      getConstantSizes, RankedShapedType::isDynamic);
   return values;
 }
@@ -1449,7 +1445,7 @@
 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
   SmallVector values = getAsOpFoldResult(getStrides());
   constifyIndexValues(values, getSource().getType(), getContext(),
-                      getConstantStrides, ShapedType::isDynamic);
+                      getConstantStrides, RankedShapedType::isDynamic);
   return values;
 }
@@ -1457,7 +1453,7 @@
   OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
   SmallVector values(1, offsetOfr);
   constifyIndexValues(values, getSource().getType(), getContext(),
-                      getConstantOffset, ShapedType::isDynamic);
+                      getConstantOffset, RankedShapedType::isDynamic);
   return values[0];
 }
@@ -1794,8 +1790,8 @@
 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
   // Constant fold rank when the rank of the operand is known.
   auto type = getOperand().getType();
-  auto shapedType = type.dyn_cast<ShapedType>();
-  if (shapedType && shapedType.hasRank())
+  auto shapedType = type.dyn_cast<RankedShapedType>();
+  if (shapedType)
     return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
   return IntegerAttr();
 }
@@ -1873,8 +1869,9 @@
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
-    if (!ShapedType::isDynamic(resultSize) &&
-        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+    if (!RankedShapedType::isDynamic(resultSize) &&
+        !RankedShapedType::isDynamic(expectedSize) &&
+        resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << expectedSize << " instead of " << resultSize
             << " in dim = " << idx;
@@ -1891,8 +1888,8 @@
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
-  if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) &&
+  if (!RankedShapedType::isDynamic(resultOffset) &&
+      !RankedShapedType::isDynamic(expectedOffset) &&
       resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;
@@ -1900,8 +1897,8 @@
   // Match strides in result memref type and in static_strides attribute.
   for (auto [idx, resultStride, expectedStride] :
        llvm::enumerate(resultStrides, getStaticStrides())) {
-    if (!ShapedType::isDynamic(resultStride) &&
-        !ShapedType::isDynamic(expectedStride) &&
+    if (!RankedShapedType::isDynamic(resultStride) &&
+        !RankedShapedType::isDynamic(expectedStride) &&
        resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << expectedStride << " instead of " << resultStride
@@ -1944,14 +1941,14 @@
 SmallVector ReinterpretCastOp::getConstifiedMixedSizes() {
   SmallVector values = getMixedSizes();
   constifyIndexValues(values, getType(), getContext(), getConstantSizes,
-                      ShapedType::isDynamic);
+                      RankedShapedType::isDynamic);
   return values;
 }
 SmallVector ReinterpretCastOp::getConstifiedMixedStrides() {
   SmallVector values = getMixedStrides();
   constifyIndexValues(values, getType(), getContext(), getConstantStrides,
-                      ShapedType::isDynamic);
+                      RankedShapedType::isDynamic);
   return values;
 }
@@ -1960,7 +1957,7 @@
   assert(values.size() == 1 &&
          "reinterpret_cast must have one and only one offset");
   constifyIndexValues(values, getType(), getContext(), getConstantOffset,
-                      ShapedType::isDynamic);
+                      RankedShapedType::isDynamic);
   return values[0];
 }
@@ -2120,7 +2117,7 @@
              << expandedDim << " is out of bounds";
     // Check if there are multiple dynamic dims in a reassociation group.
-    if (ShapedType::isDynamic(expandedShape[expandedDim])) {
+    if (RankedShapedType::isDynamic(expandedShape[expandedDim])) {
       if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
         return op->emitOpError(
             "at most one dimension in a reassociation group may be dynamic");
@@ -2129,7 +2126,8 @@
     }
     // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
-    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
+    if (RankedShapedType::isDynamic(collapsedShape[collapsedDim]) !=
+        foundDynamic)
       return op->emitOpError("collapsed dim (")
              << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
@@ -2323,14 +2321,14 @@
     ArrayRef ref = llvm::ArrayRef(reassoc);
     while (srcShape[ref.back()] == 1 && ref.size() > 1)
       ref = ref.drop_back();
-    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
+    if (!RankedShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
       resultStrides.push_back(srcStrides[ref.back()]);
     } else {
       // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
       // the corresponding stride may have to be skipped. (See above comment.)
       // Therefore, the result stride cannot be statically determined and must
       // be dynamic.
-      resultStrides.push_back(ShapedType::kDynamic);
+      resultStrides.push_back(RankedShapedType::kDynamic);
     }
   }
@@ -2533,7 +2531,7 @@
   if (resultMemRefType) {
     if (!resultMemRefType.getLayout().isIdentity())
       return emitOpError("result memref type should have identity affine map");
-    if (shapeSize == ShapedType::kDynamic)
+    if (shapeSize == RankedShapedType::kDynamic)
       return emitOpError("cannot use shape operand with dynamic length to "
                          "reshape to statically-ranked memref type");
     if (shapeSize != resultMemRefType.getRank())
@@ -3134,8 +3132,8 @@
 }
 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
-  auto resultShapedType = getResult().getType().cast<ShapedType>();
-  auto sourceShapedType = getSource().getType().cast<ShapedType>();
+  auto resultShapedType = getResult().getType().cast<RankedShapedType>();
+  auto sourceShapedType = getSource().getType().cast<RankedShapedType>();
   if (resultShapedType.hasStaticShape() &&
       resultShapedType == sourceShapedType) {
@@ -3327,7 +3325,7 @@
   for (unsigned dim = 0, e = rank; dim < e; ++dim) {
     int64_t dimSize = memrefType.getDimSize(dim);
     // If this is already static dimension, keep it.
-    if (!ShapedType::isDynamic(dimSize)) {
+    if (!RankedShapedType::isDynamic(dimSize)) {
       newShapeConstants.push_back(dimSize);
       continue;
     }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -83,7 +83,7 @@
   bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
   AffineExpr expr = symbols.front();
-  values[0] = ShapedType::isDynamic(sourceOffset)
+  values[0] = RankedShapedType::isDynamic(sourceOffset)
                   ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                   : rewriter.getIndexAttr(sourceOffset);
   SmallVector subOffsets = subview.getMixedOffsets();
@@ -93,7 +93,7 @@
   for (unsigned i = 0; i < sourceRank; ++i) {
     // Compute the stride.
     OpFoldResult origStride =
-        ShapedType::isDynamic(sourceStrides[i])
+        RankedShapedType::isDynamic(sourceStrides[i])
             ? origStrides[i]
             : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
     strides.push_back(makeComposedFoldedAffineApply(
@@ -259,7 +259,7 @@
   // Fill up all the statically known sizes.
   for (unsigned i = 0; i < groupSize; ++i) {
     uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
+    if (RankedShapedType::isDynamic(dimSize)) {
       assert(!dynSizeIdx && "There must be at most one dynamic size per group");
       dynSizeIdx = i;
       continue;
@@ -327,7 +327,7 @@
   for (int i = groupSize - 1; i >= 0; --i) {
     expandedStrides[i] = builder.getIndexAttr(currentStride);
     uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
+    if (RankedShapedType::isDynamic(dimSize)) {
       assert(!dynSizeIdx && "There must be at most one dynamic size per group");
       dynSizeIdx = i;
       continue;
@@ -341,7 +341,7 @@
   auto sourceType = source.getType().cast();
   auto [strides, offset] = getStridesAndOffset(sourceType);
-  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
+  OpFoldResult origStride = RankedShapedType::isDynamic(strides[groupId])
                                 ? origStrides[groupId]
                                 : builder.getIndexAttr(strides[groupId]);
@@ -351,7 +351,7 @@
   // that dimension with the dynamic size.
   if (dynSizeIdx) {
     int64_t productOfAllStaticSizes = currentStride;
-    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
+    assert(RankedShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
            "We shouldn't be able to change dynamicity");
     OpFoldResult origSize = origSizes[groupId];
@@ -436,7 +436,7 @@
   MemRefType collapseShapeType = collapseShape.getResultType();
   uint64_t size = collapseShapeType.getDimSize(groupId);
-  if (!ShapedType::isDynamic(size)) {
+  if (!RankedShapedType::isDynamic(size)) {
     collapsedSize.push_back(builder.getIndexAttr(size));
     return collapsedSize;
   }
@@ -452,7 +452,7 @@
   collapsedSize.push_back(getProductOfValues(
       reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
-      origSizes, ShapedType::isDynamic));
+      origSizes, RankedShapedType::isDynamic));
   return collapsedSize;
 }
@@ -495,7 +495,7 @@
       continue;
     int64_t currentStride = strides[currentDim];
-    groupStrides.push_back(ShapedType::isDynamic(currentStride)
+    groupStrides.push_back(RankedShapedType::isDynamic(currentStride)
                                ? origStrides[currentDim]
                                : builder.getIndexAttr(currentStride));
   }
@@ -506,14 +506,14 @@
   auto [collapsedStrides, collapsedOffset] = getStridesAndOffset(collapsedType);
   int64_t finalStride = collapsedStrides[groupId];
-  if (ShapedType::isDynamic(finalStride)) {
+  if (RankedShapedType::isDynamic(finalStride)) {
     // Look for a dynamic stride. At this point we don't know which one is
     // desired, but they are all equally good/bad.
     for (int64_t currentDim : reassocGroup) {
       assert(srcShape[currentDim] == 1 &&
              "We should be dealing with 1x1x...x1");
-      if (ShapedType::isDynamic(strides[currentDim]))
+      if (RankedShapedType::isDynamic(strides[currentDim]))
         return {origStrides[currentDim]};
     }
     llvm_unreachable("We should have found a dynamic stride");
@@ -574,7 +574,7 @@
   unsigned reshapeRank = reshapeType.getRank();
   OpFoldResult offsetOfr =
-      ShapedType::isDynamic(offset)
+      RankedShapedType::isDynamic(offset)
           ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
           : rewriter.getIndexAttr(offset);
@@ -665,7 +665,7 @@
   sizes.reserve(rank);
   unsigned dynamicPos = 0;
   for (int64_t size : memRefType.getShape()) {
-    if (ShapedType::isDynamic(size))
+    if (RankedShapedType::isDynamic(size))
       sizes.push_back(dynamic[dynamicPos++]);
     else
       sizes.push_back(rewriter.getIndexAttr(size));
@@ -749,7 +749,7 @@
   // Collect the sizes.
   ArrayRef sizes = memRefType.getShape();
-  assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
+  assert(!llvm::any_of(sizes, RankedShapedType::isDynamic) &&
          "unexpected dynamic shape for result of `memref.get_global` op");
   // Strides (just creates identity strides).
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -60,9 +60,9 @@
   // strides from unranked memrefs, so cast the source to a type with fully
   // dynamic layout, from which we can then extract the offset and strides.
   // (Rank was already verified.)
-  int64_t dynamicOffset = ShapedType::kDynamic;
+  int64_t dynamicOffset = RankedShapedType::kDynamic;
   SmallVector dynamicShape(resultType.getRank(),
-                           ShapedType::kDynamic);
+                           RankedShapedType::kDynamic);
   auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
                                               dynamicOffset, dynamicShape);
   auto dynStridesType =
@@ -102,7 +102,7 @@
     return;
   // Check offset.
-  if (resultOffset != ShapedType::kDynamic) {
+  if (resultOffset != RankedShapedType::kDynamic) {
     // Static/dynamic offset -> dynamic offset does not need verification.
     Value srcOffset = metadataOp.getResult(1);
     Value resultOffsetVal =
@@ -116,7 +116,7 @@
   // Check strides.
   for (const auto &it : llvm::enumerate(resultStrides)) {
     // Static/dynamic stride -> dynamic stride does not need verification.
-    if (it.value() == ShapedType::kDynamic)
+    if (it.value() == RankedShapedType::kDynamic)
       continue;
     Value srcStride =
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -130,7 +130,7 @@
   }
   if (quantizedType.isa()) {
     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
-    ShapedType sType = quantizedType.cast<ShapedType>();
+    RankedShapedType sType = quantizedType.cast<RankedShapedType>();
     if (!sType.getElementType().isa()) {
       return nullptr;
     }
@@ -156,7 +156,8 @@
     return *this;
   }
   if (candidateType.isa()) {
-    ShapedType candidateShapedType = candidateType.cast<ShapedType>();
+    RankedShapedType candidateShapedType =
+        candidateType.cast<RankedShapedType>();
     if (candidateShapedType.getElementType() != getExpressedType()) {
       return nullptr;
     }
@@ -185,7 +186,7 @@
   }
   if (quantizedType.isa()) {
     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
-    ShapedType sType = quantizedType.cast<ShapedType>();
+    RankedShapedType sType = quantizedType.cast<RankedShapedType>();
     if (!sType.getElementType().isa()) {
       return nullptr;
     }
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -61,7 +61,7 @@
 UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
   // Creates the converter for each chunk. Normally the size of the
   // quantization dim is 3, so we can cache all the converters.
-  ShapedType type = attr.getType();
+  RankedShapedType type = attr.getType();
   size_t dimSize = type.getDimSize(quantizationDim);
   if (dimSize != scales.size()) {
     return {};
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1976,7 +1976,7 @@
   if (valueType == opType)
     return success();
   auto arrayType = opType.dyn_cast();
-  auto shapedType = valueType.dyn_cast<ShapedType>();
+  auto shapedType = valueType.dyn_cast<RankedShapedType>();
   if (!arrayType)
     return op.emitOpError("result or element type (")
            << opType << ") does not match value type (" << valueType
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -185,9 +185,9 @@
     return elementSize;
   auto dims = memRefType.getShape();
-  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
-      ShapedType::isDynamic(offset) ||
-      llvm::is_contained(strides, ShapedType::kDynamic))
+  if (llvm::is_contained(dims, RankedShapedType::kDynamic) ||
+      RankedShapedType::isDynamic(offset) ||
+      llvm::is_contained(strides, RankedShapedType::kDynamic))
     return std::nullopt;
   int64_t memrefSize = -1;
@@ -197,7 +197,7 @@
     return (offset + memrefSize) * *elementSize;
   }
-  if (auto tensorType = type.dyn_cast<TensorType>()) {
+  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
     if (!tensorType.hasStaticShape())
       return std::nullopt;
@@ -342,7 +342,8 @@
                                  const SPIRVConversionOptions &options,
                                  TensorType type) {
   // TODO: Handle dynamic shapes.
-  if (!type.hasStaticShape()) {
+  auto rankedTensorType = dyn_cast<RankedTensorType>(type);
+  if (!rankedTensorType || !rankedTensorType.hasStaticShape()) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: dynamic shape unimplemented\n");
     return nullptr;
@@ -817,8 +818,8 @@
   int64_t offset;
   SmallVector strides;
   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
-      llvm::is_contained(strides, ShapedType::kDynamic) ||
-      ShapedType::isDynamic(offset)) {
+      llvm::is_contained(strides, RankedShapedType::kDynamic) ||
+      RankedShapedType::isDynamic(offset)) {
    return nullptr;
   }
@@ -848,8 +849,8 @@
   int64_t offset;
   SmallVector strides;
   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
-      llvm::is_contained(strides, ShapedType::kDynamic) ||
-      ShapedType::isDynamic(offset)) {
+      llvm::is_contained(strides, RankedShapedType::kDynamic) ||
+      RankedShapedType::isDynamic(offset)) {
    return nullptr;
   }
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -48,8 +48,8 @@
 LogicalResult shape::getShapeVec(Value input,
                                  SmallVectorImpl &shapeValues) {
   if (auto inputOp = input.getDefiningOp()) {
-    auto type = inputOp.getArg().getType().cast<ShapedType>();
-    if (!type.hasRank())
+    auto type = inputOp.getArg().getType().dyn_cast<RankedShapedType>();
+    if (!type)
       return failure();
     llvm::append_range(shapeValues, type.getShape());
     return success();
@@ -1078,8 +1078,8 @@
 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   Type valType = getValue().getType();
-  auto valShapedType = valType.dyn_cast<ShapedType>();
-  if (!valShapedType || !valShapedType.hasRank())
+  auto valShapedType = valType.dyn_cast<RankedShapedType>();
+  if (!valShapedType)
     return nullptr;
   std::optional index = getConstantIndex();
   if (!index.has_value())
@@ -1087,7 +1087,7 @@
   if (index.value() >= valShapedType.getRank())
     return nullptr;
   auto extent = valShapedType.getDimSize(*index);
-  if (ShapedType::isDynamic(extent))
+  if (RankedShapedType::isDynamic(extent))
     return nullptr;
   return IntegerAttr::get(IndexType::get(getContext()), extent);
 }
@@ -1106,8 +1106,8 @@
 }
 LogicalResult mlir::shape::DimOp::verify() {
-  auto st = getValue().getType().cast<ShapedType>();
-  if (!st.hasRank())
+  auto st = getValue().getType().dyn_cast<RankedShapedType>();
+  if (!st)
     return success();
   if (auto index = getConstantIndex()) {
     if (*index < 0 || *index >= st.getRank())
@@ -1439,9 +1439,9 @@
   } else if (isExtentTensorType(l)) {
     auto rank1 = l.cast().getShape()[0];
     auto rank2 = r.cast().getShape()[0];
-    if (ShapedType::isDynamic(rank1))
+    if (RankedShapedType::isDynamic(rank1))
       acc = l;
-    else if (ShapedType::isDynamic(rank2))
+    else if (RankedShapedType::isDynamic(rank2))
       acc = r;
     else if (rank1 != rank2)
       return emitOptionalError(location, "unequal shape cardinality");
@@ -1696,7 +1696,7 @@
 //===----------------------------------------------------------------------===//
 OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
-  auto type = getOperand().getType().dyn_cast<ShapedType>();
+  auto type = getOperand().getType().dyn_cast<RankedShapedType>();
   if (!type || !type.hasStaticShape())
     return nullptr;
   Builder builder(getContext());
@@ -1766,9 +1766,8 @@
   if (operands[0].getType().isa())
     inferredReturnTypes.assign({ShapeType::get(context)});
   else {
-    auto shapedTy = operands[0].getType().cast<ShapedType>();
-    int64_t rank =
-        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
+    auto shapedTy = operands[0].getType().dyn_cast<RankedShapedType>();
+    int64_t rank = shapedTy ? shapedTy.getRank() : RankedShapedType::kDynamic;
     Type indexTy = IndexType::get(context);
     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
     inferredReturnTypes.assign({extentTensorTy});
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -11,7 +11,8 @@
 }]>>;
 def HasStaticShape : Constraint<CPred<[{
-  $0.getType().cast<ShapedType>().hasStaticShape()
+  $0.getType().isa<RankedShapedType>()
+  && $0.getType().cast<RankedShapedType>().hasStaticShape()
 }]>>;
 // Helper that takes the first element of a range.
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -677,8 +677,8 @@
   for (unsigned i = 0; i < nBatched; i++) {
     const auto valBatch = valuesTp.getShape()[i];
     const auto crdBatch = coordinatesTp.getShape()[i];
-    if (ShapedType::isDynamic(valBatch) || ShapedType::isDynamic(crdBatch) ||
-        crdBatch != valBatch) {
+    if (RankedShapedType::isDynamic(valBatch) ||
+        RankedShapedType::isDynamic(crdBatch) || crdBatch != valBatch) {
       return op->emitError(
           "values/coordinates batched level sizes don't match statically");
     }
@@ -686,8 +686,8 @@
   const auto valuesNSE = valuesTp.getShape()[nBatched];
   const auto coordsNSE = coordinatesTp.getShape()[nBatched];
-  if (!ShapedType::isDynamic(valuesNSE) && !ShapedType::isDynamic(coordsNSE) &&
-      valuesNSE != coordsNSE)
+  if (!RankedShapedType::isDynamic(valuesNSE) &&
+      !RankedShapedType::isDynamic(coordsNSE) && valuesNSE != coordsNSE)
     return op->emitError("values/coordinates number-of-elements don't match");
   // NOTE: We use `getLvlRank` because the `coordinatesTp` is for
@@ -695,7 +695,7 @@
   const DynSize coordsRank = coordinatesTp.getShape()[1 + nBatched];
   const Level tensorRank = tensorTp.getLvlRank();
   // FIXME: replace the `operator!=` with our backported `safelyNE`.
-  if (!ShapedType::isDynamic(coordsRank) &&
+  if (!RankedShapedType::isDynamic(coordsRank) &&
       coordsRank != static_cast(tensorRank) - nBatched)
     return op->emitError("input/output level-ranks don't match");
@@ -738,7 +738,7 @@
   // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
   // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
   for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
-    if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
+    if (shape1[d] != shape2[d] && shape2[d] != RankedShapedType::kDynamic)
       return emitError("unexpected conversion mismatch in dimension ") << d;
   return success();
 }
@@ -943,7 +943,7 @@
   for (Dimension d = 0; d < dimRank; d++) {
     const DynSize dstSh = dstTp.getDimShape()[d];
     if (d == concatDim) {
-      if (!ShapedType::isDynamic(dstSh)) {
+      if (!RankedShapedType::isDynamic(dstSh)) {
         // If we reach here, then all inputs have static shapes. So we
         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
         // to avoid redundant assertions in the loop.
@@ -961,7 +961,7 @@
       DynSize prev = dstSh;
       for (const auto src : getInputs()) {
         const auto sh = getSparseTensorType(src).getDimShape()[d];
-        if (!ShapedType::isDynamic(prev) && sh != prev)
+        if (!RankedShapedType::isDynamic(prev) && sh != prev)
           return emitError("All dimensions (expect for the concatenating one) "
                            "should be equal.");
         prev = sh;
@@ -1104,7 +1104,7 @@
   const DynSize sh = mtp.getShape()[0];
   // We can't check the size of dynamic dimension at compile-time, but all
   // xs and ys should have a dimension not less than n at runtime.
-  if (n && !ShapedType::isDynamic(sh) && sh < n.value())
+  if (n && !RankedShapedType::isDynamic(sh) && sh < n.value())
     return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
                                    ": {0} < {1}",
                                    sh, n.value()));
@@ -1142,7 +1142,7 @@
   // of this lambda.
   const auto checkDim = [&](Value v, StaticSize minSize, const char *message) {
     const DynSize sh = getMemRefType(v).getShape()[0];
-    if (!ShapedType::isDynamic(sh) && sh < minSize)
+    if (!RankedShapedType::isDynamic(sh) && sh < minSize)
       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -185,10 +185,10 @@
 /// a memref with a standard unit stride zero offset layout is returned.
 inline MemRefType get1DMemRefType(Type etp, bool withLayout) {
   auto layout = withLayout ? StridedLayoutAttr::StridedLayoutAttr::get(
-                                 etp.getContext(), ShapedType::kDynamic,
-                                 {ShapedType::kDynamic})
+                                 etp.getContext(), RankedShapedType::kDynamic,
+                                 {RankedShapedType::kDynamic})
                            : StridedLayoutAttr();
-  return MemRefType::get(ShapedType::kDynamic, etp, layout);
+  return MemRefType::get(RankedShapedType::kDynamic, etp, layout);
 }
 /// Scans to top of generated loop.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -275,12 +275,12 @@
     // expanded from the i-th dimension in srcShape.
     // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
    // but not <2x?x?>.
-    if (staticDstShape[j] == ShapedType::kDynamic) {
+    if (staticDstShape[j] == RankedShapedType::kDynamic) {
       // The expanded dimension has dynamic size. We compute the dimension
       // by dividing srcDim by the product of the static dimensions.
       StaticSize product = 1;
       for (unsigned k = start; k < start + map.size(); k++) {
-        if (staticDstShape[k] != ShapedType::kDynamic) {
+        if (staticDstShape[k] != RankedShapedType::kDynamic) {
           product *= staticDstShape[k];
         }
       }
@@ -391,7 +391,7 @@
 Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
                                      Type tp) {
-  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
+  auto memTp = MemRefType::get({RankedShapedType::kDynamic}, tp);
   return builder.create(loc, memTp, ValueRange{sz});
 }
@@ -420,7 +420,7 @@
   auto memTp = MemRefType::get(shape, elemTp);
   SmallVector dynamicSizes;
   for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
-    if (shape[i] == ShapedType::kDynamic)
+    if (shape[i] == RankedShapedType::kDynamic)
       dynamicSizes.push_back(sizes[i]);
   }
   Value mem = builder.create(loc, memTp, dynamicSizes);
@@ -584,7 +584,8 @@
   const auto memTp = mem.getType().cast();
   assert(memTp.getRank() == 1);
   const DynSize memSh = memTp.getDimSize(0);
-  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast(size));
+  assert(RankedShapedType::isDynamic(memSh) ||
+         memSh >= static_cast(size));
   assert(offsetIdx == 0 || offsetIdx < size);
 #endif // NDEBUG
   SmallVector vs;
@@ -606,7 +607,8 @@
   const auto memTp = mem.getType().cast();
   assert(memTp.getRank() == 1);
   const DynSize memSh = memTp.getDimSize(0);
-  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast(vsize));
+  assert(RankedShapedType::isDynamic(memSh) ||
+         memSh >= static_cast(vsize));
   assert(offsetIdx == 0 || offsetIdx < vsize);
 #endif // NDEBUG
   for (const auto &v : llvm::enumerate(vs)) {
@@ -639,7 +641,7 @@
   const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
   lvlCoords = builder.create(loc, lvlSizesTp, lvlCoords);
   // Finally, create the ReshapeOp.
-  const SmallVector resShape(lvlRank, ShapedType::kDynamic);
+  const SmallVector resShape(lvlRank, RankedShapedType::kDynamic);
   const Type elemTp = getMemRefType(valuesBuffer).getElementType();
   const auto resTp = MemRefType::get(resShape, elemTp);
   return builder.create(loc, resTp, valuesBuffer, lvlCoords);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -1217,7 +1217,7 @@
     auto mtp = getMemRefType(v);
     if (!mtp.isDynamicDim(0)) {
       auto newMtp =
-          MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
+          MemRefType::get({RankedShapedType::kDynamic}, mtp.getElementType());
       v = rewriter.create(loc, newMtp, v);
     }
     operands.push_back(v);
@@ -1313,7 +1313,7 @@
   Value c2 = constantIndex(rewriter, loc, 2);
   auto bufferType =
-      MemRefType::get({ShapedType::kDynamic}, value.getType());
+      MemRefType::get({RankedShapedType::kDynamic}, value.getType());
   scf::IfOp ifOp = rewriter.create(loc, bufferType, cond,
                                    /*else=*/true);
   // True branch.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -207,7 +207,7 @@
   dimSizes.reserve(dimRank);
   unsigned i = 0; // cumulative index into `dynSizes`.
   for (const DynSize sh : stt.getDimShape())
-    dimSizes.push_back(ShapedType::isDynamic(sh)
+    dimSizes.push_back(RankedShapedType::isDynamic(sh)
                            ? dynSizes[i++]
                            : constantIndex(builder, loc, sh));
@@ -873,7 +873,7 @@
     const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
     // Generate a memref for `sz` elements of type `t`.
     const auto genAlloc = [&](Type t) {
-      const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
+      const auto memTp = MemRefType::get({RankedShapedType::kDynamic}, t);
       return rewriter.create(loc, memTp, ValueRange{sz});
     };
     // Allocate temporary buffers for values/filled-switch and added.
@@ -1244,7 +1244,7 @@
     ubs.push_back(constantIndex(builder, loc, dimSz));
     steps.push_back(c1);
   }
-  auto tensorType = op.getValues().getType();
+  auto tensorType = op.getValues().getType().cast();
   auto memrefType = MemRefType::get(tensorType.getShape(),
                                     tensorType.getElementType());
   Value batV = builder.create(loc, memrefType,
@@ -1313,7 +1313,8 @@
   matchAndRewrite(PackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     const unsigned batchedLvls = op.getNumBatchedLvls();
-    unsigned nse = op.getValues().getType().getDimSize(batchedLvls);
+    unsigned nse = op.getValues().getType().cast().getDimSize(
+        batchedLvls);
     const auto stt = getSparseTensorType(op.getResult());
     assert(isCOOType(stt.getEncoding(), batchedLvls, true));
@@ -1322,7 +1323,7 @@
     batchDimSzs.reserve(batchedLvls);
     for (unsigned i = 0; i < batchedLvls; i++) {
       // Should already be guaranteed by verifier.
-      assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
+      assert(!RankedShapedType::isDynamic(stt.getDimShape()[i]));
       batchedCount *= stt.getDimShape()[i];
       batchDimSzs.push_back(stt.getDimShape()[i]);
     }
@@ -1372,7 +1373,8 @@
         break;
       }
       case SparseTensorFieldKind::CrdMemRef: {
-        auto tensorType = op.getCoordinates().getType();
+        auto tensorType =
+            op.getCoordinates().getType().cast();
         auto memrefType = MemRefType::get(tensorType.getShape(),
                                           tensorType.getElementType());
         field = rewriter.create(
@@ -1381,7 +1383,7 @@
         break;
       }
       case SparseTensorFieldKind::ValMemRef: {
-        auto tensorType = op.getValues().getType();
+        auto tensorType = op.getValues().getType().cast();
         auto memrefType = MemRefType::get(tensorType.getShape(),
                                           tensorType.getElementType());
         field = rewriter.create(
@@ -1414,7 +1416,7 @@
     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
       // FIXME: dim/lvl confusion!
       const auto sh = stt.getDimShape()[lvl];
-      assert(!ShapedType::isDynamic(sh));
+      assert(!RankedShapedType::isDynamic(sh));
       desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh));
       if (lvl == 0)
         desc.setPosMemSize(rewriter, loc, lvl, constantIndex(rewriter, loc, 2));
@@ -1524,7 +1526,7 @@
     createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {},
                    {reader, dimSizes}, EmitCInterface::On);
     for (const auto &d : llvm::enumerate(dstTp.getDimShape()))
-      if (ShapedType::isDynamic(d.value()))
+      if (RankedShapedType::isDynamic(d.value()))
         dynSizes.push_back(rewriter.create(
             loc, dimSizes, constantIndex(rewriter, loc, d.index())));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -145,7 +145,7 @@
   out.clear();
   out.reserve(stt.getDimRank());
   for (const DynSize sh : stt.getDimShape()) {
-    const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
+    const auto s = RankedShapedType::isDynamic(sh) ? 0 : sh;
     out.push_back(constantIndex(builder, loc, s));
   }
 }
@@ -196,7 +196,7 @@
 /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
 /// this buffer must be explicitly deallocated by client.
 static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
-  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
+  auto memTp = MemRefType::get({RankedShapedType::kDynamic}, tp);
   return rewriter.create(loc, memTp, ValueRange{sz});
 }
@@ -737,7 +737,7 @@
   Value dimSizesBuffer;
   if (stt.hasDynamicDimShape()) {
     Type indexTp = rewriter.getIndexType();
-    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
+    auto memTp = MemRefType::get({RankedShapedType::kDynamic}, indexTp);
     dimSizesBuffer =
         createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
                        reader, EmitCInterface::On)
@@ -1114,7 +1114,7 @@
   Location loc = op.getLoc();
   // Query values array size for the actually stored values size.
   Type eltType = op.getTensor().getType().cast().getElementType();
-  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
+  auto resTp = MemRefType::get({RankedShapedType::kDynamic}, eltType);
   Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
   rewriter.replaceOpWithNewOp(op, values,
                               constantIndex(rewriter, loc, 0));
@@ -1308,9 +1308,9 @@
     dstTensor = dst;
     // Get the values buffer for the sparse tensor and reshape it to the
    // corresponding dense tensor shape.
- dst = genValuesCall(rewriter, loc, - MemRefType::get({ShapedType::kDynamic}, elemTp), - {dst}); + dst = genValuesCall( + rewriter, loc, + MemRefType::get({RankedShapedType::kDynamic}, elemTp), {dst}); // Pass the `dstDimCoords` buffer for `reshapeValuesToLevels` // to reuse for storing level-sizes (yes, "level-sizes"). // This is safe to do because `dstTp` is a dense-tensor type, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -116,10 +116,10 @@ /// Populates given sizes array from type (for static sizes) and from /// the tensor (for dynamic sizes). static void sizesForTensor(OpBuilder &builder, SmallVectorImpl &sizes, - Location loc, ShapedType stp, Value tensor) { + Location loc, RankedShapedType stp, Value tensor) { for (const auto &d : enumerate(stp.getShape())) { Value dim; - if (d.value() == ShapedType::kDynamic) + if (d.value() == RankedShapedType::kDynamic) dim = builder.create(loc, tensor, d.index()); else dim = constantIndex(builder, loc, d.value()); @@ -145,7 +145,7 @@ const SmallVectorImpl &sizes, SmallVectorImpl &dynSizes) { for (const auto &d : enumerate(tp.getShape())) { - if (d.value() == ShapedType::kDynamic) + if (d.value() == RankedShapedType::kDynamic) dynSizes.push_back(sizes[d.index()]); } } @@ -183,13 +183,13 @@ /// sizes) and from the source tensors (for dynamic sizes). static void concatSizesFromInputs(OpBuilder &builder, SmallVectorImpl &sizes, Location loc, - ShapedType dstTp, ValueRange srcs, + RankedShapedType dstTp, ValueRange srcs, unsigned dim) { auto dstShape = dstTp.getShape(); sizesFromSrc(builder, sizes, loc, srcs[0]); // Sum up on the `dim` if the dimension is dynamic. - if (dstShape[dim] != ShapedType::kDynamic) { + if (dstShape[dim] != RankedShapedType::kDynamic) { // Faithfully take the static size. 
sizes[dim] = constantIndex(builder, loc, dstShape[dim]); } else { @@ -375,7 +375,7 @@ genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape, op.getReassociationIndices()); for (auto [idx, shape] : llvm::enumerate(dstShape)) { - if (shape == ShapedType::kDynamic) + if (shape == RankedShapedType::kDynamic) dstDynSizes.push_back(dstSizes[idx]); } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp @@ -178,11 +178,14 @@ const Type metaDataType = StorageSpecifierType::get(stt.getEncoding()); // memref positions - const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType); + const Type posMemType = + MemRefType::get({RankedShapedType::kDynamic}, posType); // memref coordinates - const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType); + const Type crdMemType = + MemRefType::get({RankedShapedType::kDynamic}, crdType); // memref values - const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); + const Type valMemType = + MemRefType::get({RankedShapedType::kDynamic}, eltType); foreachFieldInSparseTensor( stt.getEncoding(), diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1242,7 +1242,7 @@ Value tensor = lhs->get(); Location loc = op.getLoc(); if (atStart) { - auto dynShape = {ShapedType::kDynamic}; + auto dynShape = {RankedShapedType::kDynamic}; Type etp = tensor.getType().cast().getElementType(); Type t1 = MemRefType::get(dynShape, etp); Type t2 = MemRefType::get(dynShape, builder.getI1Type()); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -41,7 +41,7 @@ static OpFoldResult getCollapsedOutputDimFromInputShape( OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef dstStaticShape, ArrayRef reassociationMap) { - if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { + if (!RankedShapedType::isDynamic(dstStaticShape[dimIndex])) { return builder.getIndexAttr(dstStaticShape[dimIndex]); } AffineMap map = reassociationMap[dimIndex]; @@ -78,7 +78,7 @@ OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef dstStaticShape, ArrayRef reassociation, llvm::DenseMap &expandedDimToCollapsedDim) { - if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { + if (!RankedShapedType::isDynamic(dstStaticShape[dimIndex])) { return builder.getIndexAttr(dstStaticShape[dimIndex]); } unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex]; @@ -97,7 +97,7 @@ llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) { if (d.index() + startPos == static_cast(dimIndex)) continue; - assert(!ShapedType::isDynamic(d.value()) && + assert(!RankedShapedType::isDynamic(d.value()) && "single dimension cannot be expanded into multiple dynamic " "dimensions"); linearizedStaticDim *= d.value(); @@ -130,7 +130,8 @@ ArrayRef dstStaticShape, ArrayRef reassocation) { return dstStaticShape.size() > - 
static_cast(src.getType().cast().getRank()) + static_cast( + src.getType().cast().getRank()) ? getExpandedOutputShapeFromInputShape( builder, loc, src, dstStaticShape, reassocation) : getCollapsedOutputShapeFromInputShape( diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -64,6 +64,7 @@ FailureOr tensor::getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult) { auto tensorType = opResult.getType().dyn_cast(); + auto rankedTensorType = opResult.getType().dyn_cast(); assert(tensorType && "expected tensor type"); // If the op has a destination, it implements DestinationStyleOpInterface and @@ -78,7 +79,7 @@ // Compute sizes. SmallVector mixedSizes; - if (!tensorType.hasStaticShape()) { + if (!rankedTensorType || !rankedTensorType.hasStaticShape()) { // Dynamic shape: Query ReifyRankedShapedTypeOpInterface. ReifiedRankedShapedTypeDims reifiedShapes; if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes))) @@ -86,7 +87,7 @@ mixedSizes = reifiedShapes[opResult.getResultNumber()]; } else { // Static shape: Take static sizes directly. - for (int64_t sz : tensorType.getShape()) + for (int64_t sz : rankedTensorType.getShape()) mixedSizes.push_back(b.getIndexAttr(sz)); } @@ -180,8 +181,8 @@ // If cast is towards more static sizes along any dimension, don't fold. for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) { - if (!ShapedType::isDynamic(std::get<0>(t)) && - ShapedType::isDynamic(std::get<1>(t))) + if (!RankedShapedType::isDynamic(std::get<0>(t)) && + RankedShapedType::isDynamic(std::get<1>(t))) return false; } @@ -281,29 +282,31 @@ static TensorType joinShapes(TensorType one, TensorType two) { assert(one.getElementType() == two.getElementType()); - if (!one.hasRank()) + auto oneRanked = one.dyn_cast(); + if (!oneRanked) return two; - if (!two.hasRank()) + auto twoRanked = two.dyn_cast(); + if (!twoRanked) return one; - int64_t rank = one.getRank(); - if (rank != two.getRank()) + int64_t rank = oneRanked.getRank(); + if (rank != twoRanked.getRank()) return {}; SmallVector join; join.reserve(rank); for (int64_t i = 0; i < rank; ++i) { - if (one.isDynamicDim(i)) { - join.push_back(two.getDimSize(i)); + if (oneRanked.isDynamicDim(i)) { + join.push_back(twoRanked.getDimSize(i)); continue; } - if (two.isDynamicDim(i)) { - join.push_back(one.getDimSize(i)); + if (twoRanked.isDynamicDim(i)) { + join.push_back(oneRanked.getDimSize(i)); continue; } - if (one.getDimSize(i) != two.getDimSize(i)) + if (oneRanked.getDimSize(i) != twoRanked.getDimSize(i)) return {}; - join.push_back(one.getDimSize(i)); + join.push_back(oneRanked.getDimSize(i)); } return RankedTensorType::get(join, one.getElementType()); } @@ -389,7 +392,7 @@ if (dimMask && dimMask->count(i)) continue; int64_t dim = rankedResultType.getShape()[dimIndex++]; - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) continue; sizes[i] = rewriter.getIndexAttr(dim); } @@ -473,12 +476,12 @@ fromElements.getResult().getType().cast(); // The case where the type encodes the size of the dimension is handled // above. - assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()])); + assert(RankedShapedType::isDynamic(resultType.getShape()[index.getInt()])); // Find the operand of the fromElements that corresponds to this index. 
auto dynExtents = fromElements.getDynamicExtents().begin(); for (auto dim : resultType.getShape().take_front(index.getInt())) - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) dynExtents++; return Value{*dynExtents}; @@ -533,7 +536,7 @@ ArrayRef staticShape, Type elementType, Attribute encoding) { assert(all_of(staticShape, - [](int64_t sz) { return !ShapedType::isDynamic(sz); }) && + [](int64_t sz) { return !RankedShapedType::isDynamic(sz); }) && "expected only static sizes"); build(builder, result, staticShape, elementType, ValueRange{}, encoding); } @@ -707,7 +710,7 @@ // Case 1: The empty tensor dim is static. Check that the tensor cast // result dim matches. if (auto attr = currDim.dyn_cast()) { - if (ShapedType::isDynamic(newDim) || + if (RankedShapedType::isDynamic(newDim) || newDim != attr.cast().getInt()) { // Something is off, the cast result shape cannot be more dynamic // than the empty tensor result shape (enforced by @@ -722,7 +725,7 @@ // Case 2 : The tensor cast shape is static, but empty tensor result // shape is dynamic. - if (!ShapedType::isDynamic(newDim)) { + if (!RankedShapedType::isDynamic(newDim)) { newMixedSizes.push_back(rewriter.getIndexAttr(newDim)); continue; } @@ -1129,13 +1132,13 @@ auto operandsIt = tensorFromElements.getDynamicExtents().begin(); for (int64_t dim : resultType.getShape()) { - if (!ShapedType::isDynamic(dim)) { + if (!RankedShapedType::isDynamic(dim)) { newShape.push_back(dim); continue; } APInt index; if (!matchPattern(*operandsIt, m_ConstantInt(&index))) { - newShape.push_back(ShapedType::kDynamic); + newShape.push_back(RankedShapedType::kDynamic); newOperands.push_back(*operandsIt++); continue; } @@ -1210,8 +1213,8 @@ OpFoldResult RankOp::fold(FoldAdaptor adaptor) { // Constant fold rank when the rank of the operand is known. auto type = getOperand().getType(); - auto shapedType = type.dyn_cast(); - if (shapedType && shapedType.hasRank()) + auto shapedType = type.dyn_cast(); + if (shapedType) return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); return IntegerAttr(); } @@ -1225,7 +1228,7 @@ setNameFn(getResult(), "reshape"); } -static int64_t getNumElements(ShapedType type) { +static int64_t getNumElements(RankedShapedType type) { int64_t numElements = 1; for (auto dim : type.getShape()) numElements *= dim; @@ -1252,7 +1255,7 @@ return emitOpError("source and destination tensor should have the " "same number of elements"); } - if (ShapedType::isDynamic(shapeSize)) + if (RankedShapedType::isDynamic(shapeSize)) return emitOpError("cannot use shape operand with dynamic length to " "reshape to statically-ranked tensor type"); if (shapeSize != resultRankedType.getRank()) @@ -1325,8 +1328,8 @@ unsigned dim = m.getNumResults(); auto band = shape.slice(currentDim, dim); int64_t size = 1; - if (llvm::is_contained(band, ShapedType::kDynamic)) - size = ShapedType::kDynamic; + if (llvm::is_contained(band, RankedShapedType::kDynamic)) + size = RankedShapedType::kDynamic; else for (unsigned d = 0; d < dim; ++d) size *= shape[currentDim + d]; @@ -1449,9 +1452,9 @@ if (!fromElements) return failure(); - auto shapedTy = reshapeOp.getType().template cast(); + auto shapedTy = reshapeOp.getType().template dyn_cast(); - if (!shapedTy.hasStaticShape()) + if (!shapedTy || !shapedTy.hasStaticShape()) return failure(); rewriter.replaceOpWithNewOp(reshapeOp, reshapeOp.getType(), @@ -1929,8 +1932,8 @@ return failure(); // Dynamic result shape is not supported. 
- auto sourceType = op.getSource().getType().cast(); - auto resultType = op.getResult().getType().cast(); + auto sourceType = op.getSource().getType().cast(); + auto resultType = op.getResult().getType().cast(); if (!sourceType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); @@ -1944,13 +1947,13 @@ // Check if there are any dynamic parts, which are not supported. auto offsets = op.getStaticOffsets(); - if (llvm::is_contained(offsets, ShapedType::kDynamic)) + if (llvm::is_contained(offsets, RankedShapedType::kDynamic)) return failure(); auto sizes = op.getStaticSizes(); - if (llvm::is_contained(sizes, ShapedType::kDynamic)) + if (llvm::is_contained(sizes, RankedShapedType::kDynamic)) return failure(); auto strides = op.getStaticStrides(); - if (llvm::is_contained(strides, ShapedType::kDynamic)) + if (llvm::is_contained(strides, RankedShapedType::kDynamic)) return failure(); // Compute the stride for each dimension. @@ -2036,7 +2039,7 @@ // static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, - ShapedType shapedType) { + RankedShapedType shapedType) { OpBuilder b(op.getContext()); for (OpFoldResult ofr : op.getMixedOffsets()) if (getConstantIntValue(ofr) != static_cast(0)) @@ -2070,8 +2073,8 @@ OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) { if (auto splat = adaptor.getSource().dyn_cast_or_null()) { - auto resultType = getResult().getType().cast(); - if (resultType.hasStaticShape()) + auto resultType = getResult().getType().dyn_cast(); + if (resultType && resultType.hasStaticShape()) return splat.resizeSplat(resultType); } if (getSourceType() == getType() && @@ -2538,14 +2541,15 @@ SmallVector inferredShape; for (auto i : llvm::seq(0, rank)) { - if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic || - staticHigh[i] == ShapedType::kDynamic) { - inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic + if (sourceType.isDynamicDim(i) || + staticLow[i] == RankedShapedType::kDynamic || + staticHigh[i] == RankedShapedType::kDynamic) { + inferredShape.push_back(resultShape.empty() ? 
RankedShapedType::kDynamic : resultShape[i]); } else { int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i]; assert((resultShape.empty() || size == resultShape[i] || - resultShape[i] == ShapedType::kDynamic) && + resultShape[i] == RankedShapedType::kDynamic) && "mismatch between inferred shape and result shape"); inferredShape.push_back(size); } @@ -2572,7 +2576,7 @@ ArrayRef attrs) { auto sourceType = source.getType().cast(); unsigned rank = sourceType.getRank(); - SmallVector staticVector(rank, ShapedType::kDynamic); + SmallVector staticVector(rank, RankedShapedType::kDynamic); build(b, result, resultType, source, staticVector, staticVector, low, high, nofold, attrs); } @@ -2839,7 +2843,7 @@ continue; OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()]; int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()]; - assert(!ShapedType::isDynamic(sourceSize) && + assert(!RankedShapedType::isDynamic(sourceSize) && "expected padded dimension to have a static size"); if (getConstantIntValue(sliceSize) != sourceSize) { return rewriter.notifyMatchFailure( @@ -2897,7 +2901,7 @@ for (auto operand : padTensorOp.getLow()) { APSInt intOp; if (!matchPattern(operand, m_ConstantInt(&intOp))) { - constOperandsLow.push_back(ShapedType::kDynamic); + constOperandsLow.push_back(RankedShapedType::kDynamic); continue; } constOperandsLow.push_back(intOp.getExtValue()); @@ -2906,7 +2910,7 @@ for (auto operand : padTensorOp.getHigh()) { APSInt intOp; if (!matchPattern(operand, m_ConstantInt(&intOp))) { - constOperandsHigh.push_back(ShapedType::kDynamic); + constOperandsHigh.push_back(RankedShapedType::kDynamic); continue; } constOperandsHigh.push_back(intOp.getExtValue()); @@ -2924,9 +2928,9 @@ auto lowCount = 0; auto highCount = 0; for (size_t i = 0; i < inputRank; i++) { - if (constLow[i] == ShapedType::kDynamic) + if (constLow[i] == RankedShapedType::kDynamic) constLow[i] = constOperandsLow[lowCount++]; - if (constHigh[i] == ShapedType::kDynamic) + if (constHigh[i] == RankedShapedType::kDynamic) constHigh[i] = constOperandsHigh[highCount++]; } @@ -2936,12 +2940,12 @@ // Calculate the output sizes with the static information. SmallVector newOutDims; for (size_t i = 0; i < inputRank; i++) { - if (outputDims[i] == ShapedType::kDynamic) { + if (outputDims[i] == RankedShapedType::kDynamic) { newOutDims.push_back( - (staticLow[i] == ShapedType::kDynamic || - staticHigh[i] == ShapedType::kDynamic || - inputDims[i] == ShapedType::kDynamic - ? ShapedType::kDynamic + (staticLow[i] == RankedShapedType::kDynamic || + staticHigh[i] == RankedShapedType::kDynamic || + inputDims[i] == RankedShapedType::kDynamic + ? RankedShapedType::kDynamic : inputDims[i] + staticLow[i] + staticHigh[i])); } else { newOutDims.push_back(outputDims[i]); @@ -2949,8 +2953,9 @@ } if (SmallVector(outputDims) == newOutDims || - llvm::all_of(newOutDims, - [&](int64_t x) { return x == ShapedType::kDynamic; })) + llvm::all_of(newOutDims, [&](int64_t x) { + return x == RankedShapedType::kDynamic; + })) return failure(); // Rewrite the op using the new static type. 
@@ -3201,7 +3206,7 @@ "applies to only pack or unpack operations"); int64_t destRank = op.getDestRank(); reifiedReturnShapes.resize(1, SmallVector(destRank)); - ShapedType resultType = op.getResult().getType().template cast(); + auto resultType = op.getResult().getType().template cast(); for (auto dim : llvm::seq(0, destRank)) { if (resultType.isDynamicDim(dim)) { reifiedReturnShapes[0][dim] = @@ -3237,7 +3242,7 @@ SmallVector mixedInnerTiles; unsigned dynamicValIndex = 0; for (int64_t staticTile : op.getStaticInnerTiles()) { - if (!ShapedType::isDynamic(staticTile)) + if (!RankedShapedType::isDynamic(staticTile)) mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile)); else mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]); @@ -3285,8 +3290,8 @@ llvm::zip(sourceShape, limitShape), [](std::tuple it) { int64_t sourceExtent = std::get<0>(it); int64_t limit = std::get<1>(it); - return ShapedType::isDynamic(sourceExtent) || - ShapedType::isDynamic(limit) || sourceExtent <= limit; + return RankedShapedType::isDynamic(sourceExtent) || + RankedShapedType::isDynamic(limit) || sourceExtent <= limit; }); } @@ -3332,9 +3337,9 @@ "tiling factors must equal the number of dimensions to tile"); } - ShapedType packedType = (std::is_same::value) - ? packOrUnPack.getDestType() - : packOrUnPack.getSourceType(); + RankedShapedType packedType = (std::is_same::value) + ? packOrUnPack.getDestType() + : packOrUnPack.getSourceType(); size_t packedRank = packedType.getRank(); // Require output rank to match input rank + number of blocking factors. if (unpackedRank + mixedTiles.size() != packedRank) { @@ -3362,9 +3367,9 @@ if (!constTileSize) { // If specified tile size is dynamic, output shape should // be dynamic too. - return ShapedType::isDynamic(shape); + return RankedShapedType::isDynamic(shape); } - if (ShapedType::isDynamic(shape)) { + if (RankedShapedType::isDynamic(shape)) { // For the shape being dynamic when tile size is // specified, return true. In canonical form a constant // tile size should lead to constant shape of the tiled @@ -3479,7 +3484,7 @@ ArrayRef innerDimsPos, ArrayRef innerTiles) { for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) { - if (ShapedType::isDynamic(inputShape[pos])) + if (RankedShapedType::isDynamic(inputShape[pos])) continue; std::optional constantTile = getConstantIntValue(tileSize); if (!constantTile) @@ -3522,9 +3527,10 @@ for (auto o : ofrs) { // Have to do this first, as getConstantIntValue special-cases constants. 
if (o.dyn_cast()) - result.push_back(ShapedType::kDynamic); + result.push_back(RankedShapedType::kDynamic); else - result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic)); + result.push_back( + getConstantIntValue(o).value_or(RankedShapedType::kDynamic)); } return result; } @@ -3537,10 +3543,10 @@ ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { SmallVector resultShape = llvm::to_vector(sourceShape); for (auto tiledDim : llvm::enumerate(innerDimsPos)) { - if (ShapedType::isDynamic(resultShape[tiledDim.value()])) + if (RankedShapedType::isDynamic(resultShape[tiledDim.value()])) continue; - if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) { - resultShape[tiledDim.value()] = ShapedType::kDynamic; + if (RankedShapedType::isDynamic(innerTileSizes[tiledDim.index()])) { + resultShape[tiledDim.value()] = RankedShapedType::kDynamic; continue; } resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()], @@ -3584,7 +3590,7 @@ // use dispatchIndexOpFoldResults on the result, and rely on exact number of // dynamic dims returned by that. for (unsigned i = 0; i < resultDims.size(); ++i) { - if (!ShapedType::isDynamic(resultTypeShape[i])) + if (!RankedShapedType::isDynamic(resultTypeShape[i])) continue; resultDims[i] = getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]); @@ -3618,7 +3624,7 @@ SmallVector mixedSizes; for (auto [index, value] : llvm::enumerate(source.getType().cast().getShape())) { - if (ShapedType::isDynamic(value)) + if (RankedShapedType::isDynamic(value)) mixedSizes.push_back(b.create(loc, source, index).getResult()); else mixedSizes.push_back(b.getIndexAttr(value)); @@ -3654,14 +3660,14 @@ bool areTilesAndTiledDimsAllConstant(OpTy op) { static_assert(llvm::is_one_of::value, "applies to only pack or unpack operations"); - ShapedType packedType = (std::is_same::value) - ? op.getDestType() - : op.getSourceType(); + RankedShapedType packedType = (std::is_same::value) + ? op.getDestType() + : op.getSourceType(); SmallVector mixedTiles = op.getMixedTiles(); for (auto [dimDest, tile] : llvm::zip( packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) { std::optional constTileSize = getConstantIntValue(tile); - if (!constTileSize || ShapedType::isDynamic(dimDest)) + if (!constTileSize || RankedShapedType::isDynamic(dimDest)) return false; } return true; diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -28,7 +28,7 @@ SmallVector high(type.getRank(), zero); for (const auto &en : enumerate(type.getShape())) { // Pad only the static dimensions of the result tensor type. - if (ShapedType::isDynamic(en.value())) + if (RankedShapedType::isDynamic(en.value())) continue; // Compute the padding width. 
AffineExpr d0; @@ -47,7 +47,7 @@ auto tensorTy = rankedTensor.getType().cast(); SmallVector dynamicDims; for (const auto &en : llvm::enumerate(tensorTy.getShape())) { - if (en.value() == ShapedType::kDynamic) + if (en.value() == RankedShapedType::kDynamic) dynamicDims.push_back( b.create(loc, rankedTensor, en.index())); } @@ -63,7 +63,7 @@ auto shape = tensorTy.getShape(); if (dim >= static_cast(shape.size())) return failure(); - if (ShapedType::isDynamic(shape[dim])) + if (RankedShapedType::isDynamic(shape[dim])) return OpFoldResult(b.createOrFold(loc, rankedTensor, dim)); return OpFoldResult(b.getIndexAttr(shape[dim])); } @@ -73,7 +73,7 @@ auto tensorTy = rankedTensor.getType().cast(); SmallVector dims; for (const auto &en : llvm::enumerate(tensorTy.getShape())) { - if (ShapedType::isDynamic(en.value())) { + if (RankedShapedType::isDynamic(en.value())) { dims.push_back( b.createOrFold(loc, rankedTensor, en.index())); } else { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -200,8 +200,8 @@ } auto input = op.getInput1(); - auto inputTy = input.getType().cast(); - if (!inputTy.hasRank()) + auto inputTy = input.getType().dyn_cast(); + if (!inputTy) return rewriter.notifyMatchFailure(op, "Unranked input."); int64_t numDynDims = 0; @@ -300,8 +300,8 @@ PatternRewriter &rewriter) const override { Value input = op.getInput(); Value output = op.getOutput(); - ShapedType inputType = input.getType().cast(); - ShapedType outputType = output.getType().cast(); + auto inputType = input.getType().cast(); + auto outputType = output.getType().cast(); if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) { return failure(); @@ -827,8 +827,8 @@ #define REDUCE_FOLDER(OP) \ OpFoldResult OP::fold(FoldAdaptor adaptor) { \ - ShapedType inputTy = getInput().getType().cast(); \ - if (!inputTy.hasRank()) \ + auto inputTy = getInput().getType().dyn_cast(); \ + if (!inputTy) \ return {}; \ if (inputTy.getDimSize(getAxis()) == 1) \ return getInput(); \ @@ -906,14 +906,14 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { auto operand = getInput(); - auto operandTy = operand.getType().cast(); + auto operandTy = operand.getType().dyn_cast(); auto axis = getAxis(); auto operandAttr = adaptor.getInput().dyn_cast_or_null(); if (operandAttr) return operandAttr; // If the dim-length is 1, tosa.reverse is a no-op. - if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1) + if (operandTy && operandTy.getDimSize(axis) == 1) return operand; return {}; @@ -974,8 +974,8 @@ } OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { - auto inputTy = getInput1().getType().cast(); - auto resultTy = getType().cast(); + auto inputTy = getInput1().getType().cast(); + auto resultTy = getType().cast(); // Transposing splat values just means reshaping. 
if (auto input = adaptor.getInput1().dyn_cast_or_null()) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -397,14 +397,14 @@ return failure(); llvm::SmallVector outputShape; - outputShape.resize(3, ShapedType::kDynamic); + outputShape.resize(3, RankedShapedType::kDynamic); outputShape[0] = inputShape.getDimSize(0); outputShape[1] = inputShape.getDimSize(1); int64_t inWidth = inputShape.getDimSize(2); // Note that we can support this calculation symbolically // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1] - if (inWidth != ShapedType::kDynamic) + if (inWidth != RankedShapedType::kDynamic) outputShape[2] = inWidth / 2 + 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); @@ -438,13 +438,13 @@ // Copy the Operand's rank. if (!hasRankedInput) - outputShape.resize(operandShape.getRank(), ShapedType::kDynamic); + outputShape.resize(operandShape.getRank(), RankedShapedType::kDynamic); // Copy shapes until the dim is non-dynamic. for (int i = 0, s = operandShape.getRank(); i < s; i++) { if (i == axis || operandShape.isDynamicDim(i)) continue; - if (outputShape[i] == ShapedType::kDynamic) + if (outputShape[i] == RankedShapedType::kDynamic) outputShape[i] = operandShape.getDimSize(i); if (outputShape[i] != operandShape.getDimSize(i)) return emitOptionalError(location, @@ -469,7 +469,7 @@ // We need to know the length of the concatenation axis of all inputs to // determine the dimension size of the output shape. if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) { - concatDimSize = ShapedType::kDynamic; + concatDimSize = RankedShapedType::kDynamic; break; } @@ -513,7 +513,7 @@ // All shapes are dynamic. SmallVector outShape; - outShape.resize(2, ShapedType::kDynamic); + outShape.resize(2, RankedShapedType::kDynamic); if (inputShape.hasRank()) { outShape[0] = inputShape.getDimSize(0); @@ -524,8 +524,9 @@ } if (biasShape.hasRank()) { - outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0) - : outShape[1]; + outShape[1] = outShape[1] == RankedShapedType::kDynamic + ? biasShape.getDimSize(0) + : outShape[1]; } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); @@ -543,7 +544,7 @@ // All shapes are dynamic. SmallVector outShape; - outShape.resize(3, ShapedType::kDynamic); + outShape.resize(3, RankedShapedType::kDynamic); if (lhsShape.hasRank()) { outShape[0] = lhsShape.getDimSize(0); @@ -551,8 +552,9 @@ } if (rhsShape.hasRank()) { - outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0) - : outShape[0]; + outShape[0] = outShape[0] == RankedShapedType::kDynamic + ? rhsShape.getDimSize(0) + : outShape[0]; outShape[2] = rhsShape.getDimSize(2); } @@ -583,7 +585,7 @@ return success(); } - outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic); + outputShape.resize(paddingShape.getDimSize(0), RankedShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -591,7 +593,7 @@ DenseIntElementsAttr paddings; // If the paddings value is not a constant, all dimensions must be dynamic. 
if (!matchPattern(operands[1], m_Constant(&paddings))) { - outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); + outputShape.resize(inputShape.getRank(), RankedShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -604,7 +606,7 @@ outputShape.reserve(inputShape.getRank()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { if (inputShape.isDynamicDim(i)) { - outputShape.push_back(ShapedType::kDynamic); + outputShape.push_back(RankedShapedType::kDynamic); continue; } @@ -618,7 +620,7 @@ static SmallVector convertToMlirShape(ArrayRef shape) { return to_vector(llvm::map_range(shape, [](int64_t dim) { - return dim == -1 ? ShapedType::kDynamic : dim; + return dim == -1 ? RankedShapedType::kDynamic : dim; })); } @@ -656,7 +658,7 @@ ShapeAdaptor inputShape = operands.getShape(0); SmallVector outputShape; if (!inputShape.hasRank()) { - outputShape.resize(multiples.size(), ShapedType::kDynamic); + outputShape.resize(multiples.size(), RankedShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -665,7 +667,7 @@ outputShape.reserve(multiples.size()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { int64_t dim = inputShape.getDimSize(i); - if (dim != ShapedType::kDynamic) + if (dim != RankedShapedType::kDynamic) dim *= multiples[i]; outputShape.push_back(dim); } @@ -696,14 +698,14 @@ int64_t numElements = inputShape.getNumElements(); int64_t staticMul = 1; for (auto val : newShapeValue) { - if (!ShapedType::isDynamic(val)) { + if (!RankedShapedType::isDynamic(val)) { staticMul *= val; } } // Determine the length of the dynamic dimension. for (auto &val : newShapeValue) { - if (ShapedType::isDynamic(val)) + if (RankedShapedType::isDynamic(val)) val = numElements / staticMul; } @@ -712,10 +714,11 @@ } mlir::LogicalResult tosa::ReshapeOp::verify() { - ShapedType inputType = getInput1().getType().cast(); - ShapedType outputType = getType().cast(); + auto inputType = getInput1().getType().dyn_cast(); + auto outputType = getType().dyn_cast(); - if (inputType.hasStaticShape() && outputType.hasStaticShape()) { + if (inputType && inputType.hasStaticShape() && outputType && + outputType.hasStaticShape()) { int64_t inputElementsNum = inputType.getNumElements(); int64_t outputElementsNum = outputType.getNumElements(); if (inputElementsNum != outputElementsNum) { @@ -765,7 +768,7 @@ // can determine the output rank. SmallVector outputShape; if (!inputShape.hasRank()) { - outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic); + outputShape.resize(permsShape.getDimSize(0), RankedShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -793,7 +796,7 @@ return success(); } - outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); + outputShape.resize(inputShape.getRank(), RankedShapedType::kDynamic); // If the permuations are a constant we can directly determine the output // shape. 
if (ShapeAdaptor permShape = operands.getValueAsShape(1)) { @@ -812,7 +815,7 @@ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; - outputShape.resize(3, ShapedType::kDynamic); + outputShape.resize(3, RankedShapedType::kDynamic); ShapeAdaptor valuesShape = operands.getShape(0); if (valuesShape.hasRank()) { @@ -822,9 +825,9 @@ ShapeAdaptor indicesShape = operands.getShape(1); if (indicesShape.hasRank()) { - if (outputShape[0] == ShapedType::kDynamic) + if (outputShape[0] == RankedShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); - if (outputShape[1] == ShapedType::kDynamic) + if (outputShape[1] == RankedShapedType::kDynamic) outputShape[1] = indicesShape.getDimSize(1); } @@ -838,7 +841,7 @@ SmallVectorImpl &inferredReturnShapes) { ResizeOpAdaptor adaptor(operands, attributes); llvm::SmallVector outputShape; - outputShape.resize(4, ShapedType::kDynamic); + outputShape.resize(4, RankedShapedType::kDynamic); ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); if (!inputShape.hasRank()) @@ -849,8 +852,8 @@ int64_t inputHeight = inputShape.getDimSize(1); int64_t inputWidth = inputShape.getDimSize(2); - if ((inputHeight == ShapedType::kDynamic) || - (inputWidth == ShapedType::kDynamic)) + if ((inputHeight == RankedShapedType::kDynamic) || + (inputWidth == RankedShapedType::kDynamic)) return failure(); llvm::ArrayRef scaleInt = adaptor.getScale(); @@ -877,7 +880,7 @@ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; - outputShape.resize(3, ShapedType::kDynamic); + outputShape.resize(3, RankedShapedType::kDynamic); ShapeAdaptor valuesInShape = operands.getShape(0); if (valuesInShape.hasRank()) { @@ -888,15 +891,15 @@ ShapeAdaptor indicesShape = operands.getShape(1); if (indicesShape.hasRank()) { - if (outputShape[0] == ShapedType::kDynamic) + if (outputShape[0] == RankedShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); } ShapeAdaptor inputShape = operands.getShape(2); if (inputShape.hasRank()) { - if (outputShape[0] == ShapedType::kDynamic) + if (outputShape[0] == RankedShapedType::kDynamic) outputShape[0] = inputShape.getDimSize(0); - if (outputShape[2] == ShapedType::kDynamic) + if (outputShape[2] == RankedShapedType::kDynamic) outputShape[2] = inputShape.getDimSize(2); } @@ -1018,7 +1021,7 @@ SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); llvm::SmallVector outputShape; - outputShape.resize(4, ShapedType::kDynamic); + outputShape.resize(4, RankedShapedType::kDynamic); // We only know the rank if the input type is unranked. 
if (!inputShape) { @@ -1037,12 +1040,12 @@ ArrayRef stride = attributes.get("stride").cast(); ArrayRef pad = attributes.get("pad").cast(); - if (!ShapedType::isDynamic(height)) { + if (!RankedShapedType::isDynamic(height)) { int64_t padded = height + pad[0] + pad[1] - kernel[0]; outputShape[1] = padded / stride[0] + 1; } - if (!ShapedType::isDynamic(width)) { + if (!RankedShapedType::isDynamic(width)) { int64_t padded = width + pad[2] + pad[3] - kernel[1]; outputShape[2] = padded / stride[1] + 1; } @@ -1055,13 +1058,13 @@ MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - llvm::SmallVector outputShape(4, ShapedType::kDynamic); + llvm::SmallVector outputShape(4, RankedShapedType::kDynamic); Conv2DOp::Adaptor adaptor(operands.getValues(), attributes); - int64_t inputWidth = ShapedType::kDynamic; - int64_t inputHeight = ShapedType::kDynamic; - int64_t weightWidth = ShapedType::kDynamic; - int64_t weightHeight = ShapedType::kDynamic; + int64_t inputWidth = RankedShapedType::kDynamic; + int64_t inputHeight = RankedShapedType::kDynamic; + int64_t weightWidth = RankedShapedType::kDynamic; + int64_t weightHeight = RankedShapedType::kDynamic; // Input shape describes input width/height and batch. @@ -1083,7 +1086,7 @@ // Bias shape can describe the output channels. ShapeAdaptor biasShape = operands.getShape(adaptor.getBias()); if (biasShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) + outputShape[3] = RankedShapedType::isDynamic(outputShape[3]) ? biasShape.getDimSize(0) : outputShape[3]; } @@ -1092,16 +1095,16 @@ llvm::ArrayRef stride = adaptor.getStride(); llvm::ArrayRef padding = adaptor.getPad(); - if (!ShapedType::isDynamic(inputHeight) && - !ShapedType::isDynamic(weightHeight)) { + if (!RankedShapedType::isDynamic(inputHeight) && + !RankedShapedType::isDynamic(weightHeight)) { int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(weightWidth)) { + if (!RankedShapedType::isDynamic(inputWidth) && + !RankedShapedType::isDynamic(weightWidth)) { int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t unstridedResult = inputSize - filterSize + 1; @@ -1118,16 +1121,16 @@ MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - llvm::SmallVector outputShape(5, ShapedType::kDynamic); + llvm::SmallVector outputShape(5, RankedShapedType::kDynamic); Conv3DOp::Adaptor adaptor(operands.getValues(), attributes); - int64_t inputWidth = ShapedType::kDynamic; - int64_t inputHeight = ShapedType::kDynamic; - int64_t inputDepth = ShapedType::kDynamic; + int64_t inputWidth = RankedShapedType::kDynamic; + int64_t inputHeight = RankedShapedType::kDynamic; + int64_t inputDepth = RankedShapedType::kDynamic; - int64_t weightWidth = ShapedType::kDynamic; - int64_t weightHeight = ShapedType::kDynamic; - int64_t weightDepth = ShapedType::kDynamic; + int64_t weightWidth = RankedShapedType::kDynamic; + int64_t weightHeight = RankedShapedType::kDynamic; + int64_t weightDepth = RankedShapedType::kDynamic; // Input shape describes input width/height and batch. 
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); @@ -1149,7 +1152,7 @@ // Bias shape can describe the output channels. ShapeAdaptor biasShape = operands.getShape(adaptor.getBias()); - if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) { + if (biasShape.hasRank() && RankedShapedType::isDynamic(outputShape[4])) { outputShape[4] = biasShape.getDimSize(0); } @@ -1157,24 +1160,24 @@ llvm::ArrayRef stride = adaptor.getStride(); llvm::ArrayRef pad = adaptor.getPad(); - if (!ShapedType::isDynamic(inputDepth) && - !ShapedType::isDynamic(weightDepth)) { + if (!RankedShapedType::isDynamic(inputDepth) && + !RankedShapedType::isDynamic(weightDepth)) { int32_t inputSize = inputDepth + pad[0] + pad[1]; int32_t filterSize = (weightDepth - 1) * dilation[0] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } - if (!ShapedType::isDynamic(inputHeight) && - !ShapedType::isDynamic(weightHeight)) { + if (!RankedShapedType::isDynamic(inputHeight) && + !RankedShapedType::isDynamic(weightHeight)) { int32_t inputSize = inputHeight + pad[2] + pad[3]; int32_t filterSize = (weightHeight - 1) * dilation[1] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(weightWidth)) { + if (!RankedShapedType::isDynamic(inputWidth) && + !RankedShapedType::isDynamic(weightWidth)) { int32_t inputSize = inputWidth + pad[4] + pad[5]; int32_t filterSize = (weightWidth - 1) * dilation[2] + 1; int32_t unstridedResult = inputSize - filterSize + 1; @@ -1205,16 +1208,16 @@ MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - llvm::SmallVector outputShape(4, ShapedType::kDynamic); + llvm::SmallVector outputShape(4, RankedShapedType::kDynamic); DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes); - int64_t inputWidth = ShapedType::kDynamic; - int64_t inputHeight = ShapedType::kDynamic; - int64_t inputChannels = ShapedType::kDynamic; + int64_t inputWidth = RankedShapedType::kDynamic; + int64_t inputHeight = RankedShapedType::kDynamic; + int64_t inputChannels = RankedShapedType::kDynamic; - int64_t weightWidth = ShapedType::kDynamic; - int64_t weightHeight = ShapedType::kDynamic; - int64_t depthChannels = ShapedType::kDynamic; + int64_t weightWidth = RankedShapedType::kDynamic; + int64_t weightHeight = RankedShapedType::kDynamic; + int64_t depthChannels = RankedShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); @@ -1230,7 +1233,7 @@ if (weightShape.hasRank()) { weightHeight = weightShape.getDimSize(0); weightWidth = weightShape.getDimSize(1); - inputChannels = ShapedType::isDynamic(inputChannels) + inputChannels = RankedShapedType::isDynamic(inputChannels) ? weightShape.getDimSize(2) : inputChannels; depthChannels = weightShape.getDimSize(3); @@ -1238,15 +1241,15 @@ // If both inputChannels and depthChannels are available we can determine // the output channels. - if (!ShapedType::isDynamic(inputChannels) && - !ShapedType::isDynamic(depthChannels)) { + if (!RankedShapedType::isDynamic(inputChannels) && + !RankedShapedType::isDynamic(depthChannels)) { outputShape[3] = inputChannels * depthChannels; } // Bias shape can describe the output channels. 
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias()); if (biasShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) + outputShape[3] = RankedShapedType::isDynamic(outputShape[3]) ? biasShape.getDimSize(0) : outputShape[3]; } @@ -1255,16 +1258,16 @@ llvm::ArrayRef padding = adaptor.getPad(); llvm::ArrayRef stride = adaptor.getStride(); - if (!ShapedType::isDynamic(inputHeight) && - !ShapedType::isDynamic(weightHeight)) { + if (!RankedShapedType::isDynamic(inputHeight) && + !RankedShapedType::isDynamic(weightHeight)) { int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(weightWidth)) { + if (!RankedShapedType::isDynamic(inputWidth) && + !RankedShapedType::isDynamic(weightWidth)) { int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t unstridedResult = inputSize - filterSize + 1; @@ -1286,15 +1289,15 @@ llvm::SmallVector outputShape = convertToMlirShape(adaptor.getOutShape()); - int64_t inputWidth = ShapedType::kDynamic; - int64_t inputHeight = ShapedType::kDynamic; - int64_t weightWidth = ShapedType::kDynamic; - int64_t weightHeight = ShapedType::kDynamic; + int64_t inputWidth = RankedShapedType::kDynamic; + int64_t inputHeight = RankedShapedType::kDynamic; + int64_t weightWidth = RankedShapedType::kDynamic; + int64_t weightHeight = RankedShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); if (inputShape.hasRank()) { - outputShape[0] = ShapedType::isDynamic(outputShape[0]) + outputShape[0] = RankedShapedType::isDynamic(outputShape[0]) ? inputShape.getDimSize(0) : outputShape[0]; inputHeight = inputShape.getDimSize(1); @@ -1304,7 +1307,7 @@ // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter()); if (weightShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) + outputShape[3] = RankedShapedType::isDynamic(outputShape[3]) ? weightShape.getDimSize(0) : outputShape[3]; weightHeight = weightShape.getDimSize(1); @@ -1314,7 +1317,7 @@ // Bias shape can describe the output channels. ShapeAdaptor biasShape = operands.getShape(adaptor.getInput()); if (biasShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) + outputShape[3] = RankedShapedType::isDynamic(outputShape[3]) ? biasShape.getDimSize(0) : outputShape[3]; } @@ -1322,20 +1325,22 @@ llvm::ArrayRef padding = adaptor.getOutPad(); llvm::ArrayRef stride = adaptor.getStride(); - if (!ShapedType::isDynamic(inputHeight) && - !ShapedType::isDynamic(weightHeight)) { + if (!RankedShapedType::isDynamic(inputHeight) && + !RankedShapedType::isDynamic(weightHeight)) { int64_t calculateSize = (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight; - outputShape[1] = - ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1]; + outputShape[1] = RankedShapedType::isDynamic(outputShape[1]) + ? 
calculateSize + : outputShape[1]; } - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(weightWidth)) { + if (!RankedShapedType::isDynamic(inputWidth) && + !RankedShapedType::isDynamic(weightWidth)) { int64_t calculateSize = (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth; - outputShape[2] = - ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2]; + outputShape[2] = RankedShapedType::isDynamic(outputShape[2]) + ? calculateSize + : outputShape[2]; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -22,7 +22,7 @@ SmallVector convertFromMlirShape(ArrayRef shape) { return to_vector(llvm::map_range(shape, [](int64_t dim) { - return ShapedType::isDynamic(dim) ? -1 : dim; + return RankedShapedType::isDynamic(dim) ? -1 : dim; })); } @@ -34,16 +34,16 @@ PatternRewriter &rewriter) const override { Value input = op.getInput(); Value weight = op.getWeight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.getType().cast(); + auto inputType = input.getType().cast(); + auto weightType = weight.getType().dyn_cast(); + auto resultType = op.getType().cast(); auto numDynamic = - llvm::count_if(inputType.getShape(), ShapedType::isDynamic); + llvm::count_if(inputType.getShape(), RankedShapedType::isDynamic); if (numDynamic > 1) return rewriter.notifyMatchFailure( op, "at most one dim in input may be dynamic"); - if (!weightType.hasRank()) + if (!weightType) return rewriter.notifyMatchFailure(op, "unranked weight input"); if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; })) @@ -76,7 +76,7 @@ llvm::SmallVector newShape(inputType.getShape()); for (int i = 0, s = newShape.size(); i < s; ++i) { - if (newShape[i] != ShapedType::kDynamic) { + if (newShape[i] != RankedShapedType::kDynamic) { newShape[i] += pad[i * 2] + pad[i * 2 + 1]; } } @@ -98,7 +98,7 @@ // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. 
ArrayRef inputShape = inputType.getShape(); - int64_t combined = ShapedType::kDynamic; + int64_t combined = RankedShapedType::kDynamic; if (numDynamic == 0) combined = inputShape[0] * inputShape[1] * inputShape[2]; llvm::SmallVector revisedInputShape{combined, inputShape[3]}; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -28,9 +28,9 @@ PatternRewriter &rewriter) const override { Value input = op.getInput(); Value weight = op.getWeight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.getOutput().getType().cast(); + auto inputType = input.getType().cast(); + auto weightType = weight.getType().cast(); + auto resultType = op.getOutput().getType().cast(); if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && resultType.hasStaticShape())) { @@ -100,7 +100,7 @@ llvm::SmallVector newShape(inputType.getShape()); for (int i = 0, s = pad.size(); i < s; ++i) { - if (newShape[i / 2] != ShapedType::kDynamic) { + if (newShape[i / 2] != RankedShapedType::kDynamic) { newShape[i / 2] += pad[i]; } } @@ -133,7 +133,8 @@ .getResult(); // Reshape output to [N, H, W, C * M]. - auto outputShape = op.getOutput().getType().cast().getShape(); + auto outputShape = + op.getOutput().getType().cast().getShape(); auto outputShapeType = RankedTensorType::get( outputShape, input.getType().dyn_cast().getElementType()); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -82,10 +82,13 @@ Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + auto inputTy = input.getType().dyn_cast(); + auto weightTy = weight.getType().dyn_cast(); + auto biasTy = bias.getType().dyn_cast(); + auto resultTy = op->getResult(0).getType().dyn_cast(); + + if (!inputTy || !weightTy || !biasTy || !resultTy) + return failure(); llvm::ArrayRef stride = op.getStride(); llvm::ArrayRef pad = op.getOutPad(); @@ -145,10 +148,13 @@ Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + auto inputTy = input.getType().dyn_cast(); + auto weightTy = weight.getType().dyn_cast(); + auto biasTy = bias.getType().dyn_cast(); + auto resultTy = op->getResult(0).getType().dyn_cast(); + + if (!inputTy || !weightTy || !biasTy || !resultTy) + return failure(); Type inputETy = inputTy.getElementType(); Type weightETy = weightTy.getElementType(); @@ -230,7 +236,7 @@ weight = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(weightETy), weight, rewriter.getDenseI64ArrayAttr(weightReshapeDims1)); - ShapedType restridedWeightTy = weight.getType().cast(); + auto restridedWeightTy = weight.getType().cast(); weight = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(weightETy), weight, @@ -296,7 +302,7 @@ 
} // Factor the resulting width / height. - ShapedType convTy = conv2d.getType().cast(); + auto convTy = conv2d.getType().cast(); Type convETy = convTy.getElementType(); int64_t convHeight = convTy.getDimSize(1); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp @@ -22,8 +22,8 @@ namespace { template -DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType, - ShapedType outputType, +DenseElementsAttr transposeType(ElementsAttr attr, RankedShapedType inputType, + RankedShapedType outputType, llvm::ArrayRef permValues) { if (inputType.getNumElements() == 0) return DenseElementsAttr::get(outputType, llvm::ArrayRef{}); diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -309,7 +309,7 @@ builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. - auto shape = inputDType.dyn_cast(); + auto shape = inputDType.dyn_cast(); if (!shape) return {}; if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -41,7 +41,7 @@ // Dimensions are compatible when //. 1. One is dynamic, the rest are 1 - if (ShapedType::isDynamic(dim)) { + if (RankedShapedType::isDynamic(dim)) { if (seenDynamic || nonOneDim) return false; seenDynamic = true; @@ -81,7 +81,7 @@ // Check each dimension is consistent. for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) { - if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) { + if (RankedShapedType::isDynamic(*i1) || RankedShapedType::isDynamic(*i2)) { // One or both dimensions is unknown. Follow TensorFlow behavior: // - If either dimension is greater than 1, we assume that the program is // correct, and the other dimension will be broadcast to match it. @@ -95,7 +95,7 @@ } else if (*i2 == 1) { *iR = *i1; } else { - *iR = ShapedType::kDynamic; + *iR = RankedShapedType::kDynamic; } } else { if (*i1 == *i2 || *i2 == 1) { @@ -116,7 +116,7 @@ /// Returns the shape of the given type. Scalars will be considered as having a /// shape with zero dimensions. static ArrayRef getShape(Type type) { - if (auto sType = type.dyn_cast()) + if (auto sType = type.dyn_cast()) return sType.getShape(); return {}; } @@ -200,8 +200,8 @@ // then it is compatible, else if the inferred dim is 1 then it is also // compatible. But if the existing dim is 1 and the inferred is greater than // 1 then flag. 
- return dim1 == dim2 || ShapedType::isDynamic(dim1) || - ShapedType::isDynamic(dim2) || dim1 == 1; + return dim1 == dim2 || RankedShapedType::isDynamic(dim1) || + RankedShapedType::isDynamic(dim2) || dim1 == 1; }; if (inferred.size() != existing.size()) return false; @@ -220,7 +220,7 @@ llvm::interleave( shape, ss, [&](int64_t dim) { - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) ss << '?'; else ss << dim; diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -17,8 +17,8 @@ using namespace mlir; std::optional> -mlir::getReassociationIndicesForReshape(ShapedType sourceType, - ShapedType targetType) { +mlir::getReassociationIndicesForReshape(RankedShapedType sourceType, + RankedShapedType targetType) { if (sourceType.getRank() > targetType.getRank()) return getReassociationIndicesForCollapse(sourceType.getShape(), targetType.getShape()); @@ -48,7 +48,7 @@ int64_t currTargetShape = targetShape[targetDim]; while (sourceDim < sourceShape.size() && - sourceShape[sourceDim] != ShapedType::kDynamic && + sourceShape[sourceDim] != RankedShapedType::kDynamic && prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) { prodOfCollapsedDims *= sourceShape[sourceDim]; currIndices.push_back(sourceDim++); @@ -57,14 +57,15 @@ // If the current expanded dimension is dynamic, then the collapsed // dimensions should also be dynamic and product of all previous unprocessed // dimensions of the expanded shape should be 1. - if (sourceShape[sourceDim] == ShapedType::kDynamic && - (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1)) + if (sourceShape[sourceDim] == RankedShapedType::kDynamic && + (currTargetShape != RankedShapedType::kDynamic || + prodOfCollapsedDims != 1)) return std::nullopt; // If the collapsed dim is dynamic, the current expanded dim should also // be dynamic. - if (currTargetShape == ShapedType::kDynamic && - sourceShape[sourceDim] != ShapedType::kDynamic) + if (currTargetShape == RankedShapedType::kDynamic && + sourceShape[sourceDim] != RankedShapedType::kDynamic) return std::nullopt; // For static shapes, if the product of dimensions of the expanded shape @@ -83,7 +84,7 @@ // Process any remaining entries in the source shape. They all need to be // 1 or dynamic. for (; sourceDim < sourceShape.size(); sourceDim++) { - if (sourceShape[sourceDim] != ShapedType::kDynamic && + if (sourceShape[sourceDim] != RankedShapedType::kDynamic && sourceShape[sourceDim] != 1) return std::nullopt; // The map is empty when the target type is a scalar. 
@@ -234,7 +235,7 @@ int64_t linearizedStaticShape = 1; for (const auto &dim : llvm::enumerate( expandedShape.slice(expandedDimStart, map.value().size()))) { - if (ShapedType::isDynamic(dim.value())) { + if (RankedShapedType::isDynamic(dim.value())) { if (isExpandingReshape && dynamicShape) { return emitError("invalid to have a single dimension (" + Twine(map.index()) + @@ -248,7 +249,7 @@ } } if (dynamicShape) { - if (!ShapedType::isDynamic(collapsedShape[map.index()])) { + if (!RankedShapedType::isDynamic(collapsedShape[map.index()])) { return emitError( "expected dimension " + Twine(map.index()) + " of collapsed type to be dynamic since one or more of the " diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -58,7 +58,7 @@ return; } dynamicVec.push_back(v); - staticVec.push_back(ShapedType::kDynamic); + staticVec.push_back(RankedShapedType::kDynamic); } void dispatchIndexOpFoldResults(ArrayRef ofrs, @@ -158,7 +158,7 @@ } /// Return a vector of OpFoldResults with the same size a staticValues, but all -/// elements for which ShapedType::isDynamic is true, will be replaced by +/// elements for which RankedShapedType::isDynamic is true, will be replaced by /// dynamicValues. SmallVector getMixedValues(ArrayRef staticValues, ValueRange dynamicValues, Builder &b) { @@ -168,7 +168,7 @@ unsigned count = static_cast(staticValues.size()); for (unsigned idx = 0; idx < count; ++idx) { int64_t value = staticValues[idx]; - res.push_back(ShapedType::isDynamic(value) + res.push_back(RankedShapedType::isDynamic(value) ? OpFoldResult{dynamicValues[numDynamic++]} : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])}); } @@ -186,7 +186,7 @@ if (it.is()) { staticValues.push_back(it.get().cast().getInt()); } else { - staticValues.push_back(ShapedType::kDynamic); + staticValues.push_back(RankedShapedType::kDynamic); dynamicValues.push_back(it.get()); } } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -139,7 +139,7 @@ return succeeded(successStrides) && (strides.empty() || strides.back() == 1); } -AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, +AffineMap mlir::vector::getTransferMinorIdentityMap(RankedShapedType shapedType, VectorType vectorType) { int64_t elementVectorRank = 0; VectorType elementVectorType = @@ -916,7 +916,7 @@ VectorType rhsType = this->getRhsType(); unsigned numVecDims = lhsIdxMap.getNumDims(); - SmallVector maskShape(numVecDims, ShapedType::kDynamic); + SmallVector maskShape(numVecDims, RankedShapedType::kDynamic); // Using the information in the indexing maps, extract the size of each // dimension in the vector.contract operation from the two input operands. 
@@ -925,7 +925,7 @@ for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize; - assert(!ShapedType::isDynamicShape(maskShape) && + assert(!RankedShapedType::isDynamicShape(maskShape) && "Mask shape couldn't be computed"); return VectorType::get(maskShape, @@ -3347,7 +3347,7 @@ } static LogicalResult -verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, +verifyTransferOp(VectorTransferOpInterface op, RankedShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds) { @@ -3535,7 +3535,7 @@ LogicalResult TransferReadOp::verify() { // Consistency of elemental types in source and vector. - ShapedType shapedType = getShapedType(); + RankedShapedType shapedType = getShapedType(); VectorType vectorType = getVectorType(); VectorType maskType = getMaskType(); auto paddingType = getPadding().getType(); @@ -3994,7 +3994,7 @@ LogicalResult TransferWriteOp::verify() { // Consistency of elemental types in shape and vector. - ShapedType shapedType = getShapedType(); + RankedShapedType shapedType = getShapedType(); VectorType vectorType = getVectorType(); VectorType maskType = getMaskType(); auto permutationMap = getPermutationMap(); @@ -4602,7 +4602,7 @@ VectorType indVType = getIndexVectorType(); VectorType maskVType = getMaskVectorType(); VectorType resVType = getVectorType(); - ShapedType baseType = getBaseType(); + RankedShapedType baseType = getBaseType(); if (!baseType.isa()) return emitOpError("requires base to be a memref or ranked tensor type"); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -427,7 +427,7 @@ loc, accType, multiReductionOp.getAcc()); Value castMask; if (maskableOp.isMasked()) { - auto maskType = mask.getType().cast(); + auto maskType = mask.getType().cast(); auto castMaskType = VectorType::get(ArrayRef{1, maskType.getShape().back()}, maskType.getElementType()); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -123,7 +123,8 @@ // appropriate indices for the extract/insert operations. Value result = rewriter.create( loc, resType, rewriter.getZeroAttr(resType)); - int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); + int64_t numTransposedElements = + RankedShapedType::getNumElements(prunedInShape); for (int64_t linearIdx = 0; linearIdx < numTransposedElements; ++linearIdx) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -406,7 +406,7 @@ /// input starting at `firstDimToCollapse`. 
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse) { - ShapedType inputType = input.getType().cast(); + auto inputType = input.getType().cast(); if (inputType.getRank() == 1) return input; SmallVector reassociation; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -171,11 +171,12 @@ resStrides(bT.getRank(), 0); for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) { resShape[idx] = - (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic; - resStrides[idx] = - (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic; + (aShape[idx] == bShape[idx]) ? aShape[idx] : RankedShapedType::kDynamic; + resStrides[idx] = (aStrides[idx] == bStrides[idx]) + ? aStrides[idx] + : RankedShapedType::kDynamic; } - resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic; + resOffset = (aOffset == bOffset) ? aOffset : RankedShapedType::kDynamic; return MemRefType::get( resShape, aT.getElementType(), StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides)); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2249,7 +2249,7 @@ } static void -printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, +printDenseElementsAttrImpl(bool isSplat, RankedShapedType type, raw_ostream &os, function_ref printEltFn) { // Special case for 0-d and splat tensors. if (isSplat) @@ -2476,7 +2476,7 @@ .Case([&](RankedTensorType tensorTy) { os << "tensor<"; for (int64_t dim : tensorTy.getShape()) { - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) os << '?'; else os << dim; @@ -2498,7 +2498,7 @@ .Case([&](MemRefType memrefTy) { os << "memref<"; for (int64_t dim : memrefTy.getShape()) { - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) os << '?'; else os << dim; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -83,7 +83,8 @@ /// signals if the data is already known to be a splat. Callers to this /// function are expected to tag preknown splat values when possible, e.g. one /// element shapes. - static KeyTy getKey(ShapedType ty, ArrayRef data, bool isKnownSplat) { + static KeyTy getKey(RankedShapedType ty, ArrayRef data, + bool isKnownSplat) { // Handle an empty storage instance. if (data.empty()) return KeyTy(ty, data, 0); @@ -234,7 +235,7 @@ /// signals if the data is already known to be a splat. Callers to this /// function are expected to tag preknown splat values when possible, e.g. one /// element shapes. - static KeyTy getKey(ShapedType ty, ArrayRef data, + static KeyTy getKey(RankedShapedType ty, ArrayRef data, bool isKnownSplat) { // Handle an empty storage instance. 
if (data.empty()) diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -32,7 +32,8 @@ return elementsAttr.getType().getNumElements(); } -bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef index) { +bool ElementsAttr::isValidIndex(RankedShapedType type, + ArrayRef index) { // Verify that the rank of the indices matches the held type. int64_t rank = type.getRank(); if (rank == 0 && index.size() == 1 && index[0] == 0) @@ -53,7 +54,7 @@ } uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { - ShapedType shapeType = type.cast(); + RankedShapedType shapeType = type.cast(); assert(isValidIndex(shapeType, index) && "expected valid multi-dimensional index"); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -208,7 +208,7 @@ /// Prints a strided layout attribute. void StridedLayoutAttr::print(llvm::raw_ostream &os) const { auto printIntOrQuestion = [&](int64_t value) { - if (ShapedType::isDynamic(value)) + if (RankedShapedType::isDynamic(value)) os << "?"; else os << value; @@ -581,7 +581,8 @@ /// Returns true if 'values' corresponds to a splat, i.e. one element, or has /// the same element count as 'type'. template -static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { +static bool hasSameElementsOrSplat(RankedShapedType type, + const Values &values) { return (values.size() == 1) || (type.getNumElements() == static_cast(values.size())); } @@ -887,7 +888,7 @@ return attr.isa(); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(hasSameElementsOrSplat(type, values)); @@ -972,7 +973,7 @@ return DenseIntOrFPElementsAttr::getRaw(type, data); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(hasSameElementsOrSplat(type, values)); assert(type.getElementType().isInteger(1)); @@ -997,7 +998,7 @@ return DenseIntOrFPElementsAttr::getRaw(type, buff); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(!type.getElementType().isIntOrFloat()); return DenseStringElementsAttr::get(type, values); @@ -1006,14 +1007,14 @@ /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(type.getElementType().isIntOrIndex()); assert(hasSameElementsOrSplat(type, values)); size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef> values) { ComplexType complex = type.getElementType().cast(); assert(complex.getElementType().isa()); @@ -1027,7 +1028,7 @@ // Constructs a dense float elements attribute from an array of APFloat // values. Each APFloat value is expected to have the same bitwidth as the // element type of 'type'. 
-DenseElementsAttr DenseElementsAttr::get(ShapedType type, +DenseElementsAttr DenseElementsAttr::get(RankedShapedType type, ArrayRef values) { assert(type.getElementType().isa()); assert(hasSameElementsOrSplat(type, values)); @@ -1035,7 +1036,7 @@ return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); } DenseElementsAttr -DenseElementsAttr::get(ShapedType type, +DenseElementsAttr::get(RankedShapedType type, ArrayRef> values) { ComplexType complex = type.getElementType().cast(); assert(complex.getElementType().isa()); @@ -1050,12 +1051,13 @@ /// data for this attribute. Users should generally not use this methods as /// the expected buffer format may not be a form the user expects. DenseElementsAttr -DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef rawBuffer) { +DenseElementsAttr::getFromRawBuffer(RankedShapedType type, + ArrayRef rawBuffer) { return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); } /// Returns true if the given buffer is a valid raw buffer for the given type. -bool DenseElementsAttr::isValidRawBuffer(ShapedType type, +bool DenseElementsAttr::isValidRawBuffer(RankedShapedType type, ArrayRef rawBuffer, bool &detectedSplat) { size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); @@ -1119,14 +1121,14 @@ } /// Defaults down the subclass implementation. -DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, +DenseElementsAttr DenseElementsAttr::getRawComplex(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned) { return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, isSigned); } -DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, +DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, @@ -1203,8 +1205,8 @@ /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. -DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { - ShapedType curType = getType(); +DenseElementsAttr DenseElementsAttr::reshape(RankedShapedType newType) { + RankedShapedType curType = getType(); if (curType == newType) return *this; @@ -1215,10 +1217,10 @@ return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); } -DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { +DenseElementsAttr DenseElementsAttr::resizeSplat(RankedShapedType newType) { assert(isSplat() && "expected a splat type"); - ShapedType curType = getType(); + RankedShapedType curType = getType(); if (curType == newType) return *this; @@ -1232,7 +1234,7 @@ /// type must have the same shape and element types of the same bitwidth as the /// current type. DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { - ShapedType curType = getType(); + RankedShapedType curType = getType(); Type curElType = curType.getElementType(); if (curElType == newElType) return *this; @@ -1255,7 +1257,7 @@ return cast().mapValues(newElementType, mapping); } -ShapedType DenseElementsAttr::getType() const { +RankedShapedType DenseElementsAttr::getType() const { return static_cast(impl)->type; } @@ -1292,7 +1294,7 @@ /// Constructs a dense elements attribute from an array of raw APFloat values. /// Each APFloat value is expected to have the same bitwidth as the element /// type of 'type'. 'type' must be a vector or tensor with static shape. 
-DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, +DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(RankedShapedType type, size_t storageWidth, ArrayRef values) { std::vector data; @@ -1304,7 +1306,7 @@ /// Constructs a dense elements attribute from an array of raw APInt values. /// Each APInt value is expected to have the same bitwidth as the element type /// of 'type'. -DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, +DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(RankedShapedType type, size_t storageWidth, ArrayRef values) { std::vector data; @@ -1312,7 +1314,7 @@ return DenseIntOrFPElementsAttr::getRaw(type, data); } -DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, +DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(RankedShapedType type, ArrayRef data) { assert(type.hasStaticShape() && "type must have static shape"); bool isSplat = false; @@ -1325,7 +1327,7 @@ /// Overload of the raw 'get' method that asserts that the given type is of /// complex type. This method is used to verify type invariants that the /// templatized 'get' method cannot. -DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, +DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, @@ -1343,10 +1345,9 @@ /// Overload of the 'getRaw' method that asserts that the given type is of /// integer type. This method is used to verify type invariants that the /// templatized 'get' method cannot. -DenseElementsAttr -DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, - int64_t dataEltSize, bool isInt, - bool isSigned) { +DenseElementsAttr DenseIntOrFPElementsAttr::getRawIntOrFloat( + RankedShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, + bool isSigned) { assert( ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); @@ -1401,7 +1402,7 @@ void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( ArrayRef inRawData, MutableArrayRef outRawData, - ShapedType type) { + RankedShapedType type) { size_t numElements = type.getNumElements(); Type elementType = type.getElementType(); if (ComplexType complexTy = elementType.dyn_cast()) { @@ -1423,13 +1424,14 @@ //===----------------------------------------------------------------------===// template -static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, - Type newElementType, - llvm::SmallVectorImpl &data) { +static RankedShapedType +mappingHelper(Fn mapping, Attr &attr, RankedShapedType inType, + Type newElementType, llvm::SmallVectorImpl &data) { size_t bitWidth = getDenseElementBitWidth(newElementType); size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType); + RankedShapedType newArrayType = cast( + inType.cloneWith(inType.getShape(), newElementType)); size_t numRawElements = attr.isSplat() ? 
1 : newArrayType.getNumElements(); data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT)); @@ -1499,12 +1501,12 @@ //===----------------------------------------------------------------------===// DenseResourceElementsAttr -DenseResourceElementsAttr::get(ShapedType type, +DenseResourceElementsAttr::get(RankedShapedType type, DenseResourceElementsHandle handle) { return Base::get(type.getContext(), type, handle); } -DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type, +DenseResourceElementsAttr DenseResourceElementsAttr::get(RankedShapedType type, StringRef blobName, AsmResourceBlob blob) { // Extract the builtin dialect resource manager from context and construct a @@ -1573,7 +1575,7 @@ template DenseResourceElementsAttrBase -DenseResourceElementsAttrBase::get(ShapedType type, StringRef blobName, +DenseResourceElementsAttrBase::get(RankedShapedType type, StringRef blobName, AsmResourceBlob blob) { // Check that the blob is in the form we were expecting. assert(blob.getDataAlignment() == alignof(T) && @@ -1687,16 +1689,15 @@ return flatSparseIndices; } -LogicalResult -SparseElementsAttr::verify(function_ref emitError, - ShapedType type, DenseIntElementsAttr sparseIndices, - DenseElementsAttr values) { - ShapedType valuesType = values.getType(); +LogicalResult SparseElementsAttr::verify( + function_ref emitError, RankedShapedType type, + DenseIntElementsAttr sparseIndices, DenseElementsAttr values) { + RankedShapedType valuesType = values.getType(); if (valuesType.getRank() != 1) return emitError() << "expected 1-d tensor for sparse element values"; // Verify the indices and values shape. - ShapedType indicesType = sparseIndices.getType(); + RankedShapedType indicesType = sparseIndices.getType(); auto emitShapeError = [&]() { return emitError() << "expected shape ([" << type.getShape() << "]); inferred shape of indices literal ([" @@ -1757,7 +1758,7 @@ // AffineExpr for offset. // Static case. - if (!ShapedType::isDynamic(offset)) { + if (!RankedShapedType::isDynamic(offset)) { auto cst = getAffineConstantExpr(offset, context); expr = cst; } else { @@ -1774,7 +1775,7 @@ auto d = getAffineDimExpr(dim, context); AffineExpr mult; // Static case. - if (!ShapedType::isDynamic(stride)) + if (!RankedShapedType::isDynamic(stride)) mult = getAffineConstantExpr(stride, context); else // Dynamic case, new symbol for each new stride. 
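Before moving on to the next file, a brief sketch of where the renamed kDynamic sentinel ends up in user code that builds strided layouts like the hunk above. This is not part of the patch: it assumes the renamed RankedShapedType keeps the static members that ShapedType exposed (as the hunks in this diff indicate), and the printed form in the trailing comment is only an expectation.

// Minimal sketch, assuming this patch is applied: the dynamic sentinel is now
// spelled RankedShapedType::kDynamic when encoding dynamic strides/offsets.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  int64_t dyn = mlir::RankedShapedType::kDynamic;
  // A strided layout with a dynamic offset and a dynamic leading stride.
  auto layout = mlir::StridedLayoutAttr::get(&ctx, /*offset=*/dyn,
                                             /*strides=*/{dyn, 1});
  auto memref = mlir::MemRefType::get({dyn, 8}, mlir::Float32Type::get(&ctx),
                                      layout);
  // Expected to print something like: memref<?x8xf32, strided<[?, 1], offset: ?>>
  memref.dump();
}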
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -536,7 +536,7 @@ DenseStringElementsAttr BuiltinDialectBytecodeInterface::readDenseStringElementsAttr( DialectBytecodeReader &reader) const { - ShapedType type; + RankedShapedType type; uint64_t isSplat; if (failed(reader.readType(type)) || failed(reader.readVarInt(isSplat))) return DenseStringElementsAttr(); diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp --- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp @@ -20,14 +20,15 @@ #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc" //===----------------------------------------------------------------------===// -// ShapedType +// RankedShapedType //===----------------------------------------------------------------------===// -constexpr int64_t ShapedType::kDynamic; +constexpr int64_t RankedShapedType::kDynamic; -int64_t ShapedType::getNumElements(ArrayRef shape) { +int64_t RankedShapedType::getNumElements(ArrayRef shape) { int64_t num = 1; for (int64_t dim : shape) { + assert(!RankedShapedType::isDynamic(dim) && "expected only static dims"); num *= dim; assert(num >= 0 && "integer overflow in element count computation"); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -271,10 +271,6 @@ bool TensorType::hasRank() const { return !isa(); } -ArrayRef TensorType::getShape() const { - return cast().getShape(); -} - TensorType TensorType::cloneWith(std::optional> shape, Type elementType) const { if (auto unrankedTy = dyn_cast()) { @@ -319,7 +315,7 @@ ArrayRef shape, Type elementType, Attribute encoding) { for (int64_t s : shape) - if (s < 0 && !ShapedType::isDynamic(s)) + if (s < 0 && !RankedShapedType::isDynamic(s)) return emitError() << "invalid tensor dimension size"; if (auto v = encoding.dyn_cast_or_null()) if (failed(v.verifyEncoding(shape, elementType, emitError))) @@ -349,10 +345,6 @@ bool BaseMemRefType::hasRank() const { return !isa(); } -ArrayRef BaseMemRefType::getShape() const { - return cast().getShape(); -} - BaseMemRefType BaseMemRefType::cloneWith(std::optional> shape, Type elementType) const { if (auto unrankedTy = dyn_cast()) { @@ -426,9 +418,9 @@ if (originalType == candidateReducedType) return SliceVerificationResult::Success; - ShapedType originalShapedType = originalType.cast(); - ShapedType candidateReducedShapedType = - candidateReducedType.cast(); + auto originalShapedType = originalType.cast(); + auto candidateReducedShapedType = + candidateReducedType.cast(); // Rank and size logic is valid for all ShapedTypes. ArrayRef originalShape = originalShapedType.getShape(); @@ -617,7 +609,7 @@ // Negative sizes are not allowed except for `kDynamic`. for (int64_t s : shape) - if (s < 0 && !ShapedType::isDynamic(s)) + if (s < 0 && !RankedShapedType::isDynamic(s)) return emitError() << "invalid memref size"; assert(layout && "missing layout specification"); @@ -712,7 +704,7 @@ } /// A stride specification is a list of integer values that are either static -/// or dynamic (encoded with ShapedType::kDynamic). Strides encode +/// or dynamic (encoded with RankedShapedType::kDynamic). Strides encode /// the distance in the number of elements between successive entries along a /// particular dimension. 
/// @@ -801,12 +793,12 @@ if (auto cst = offsetExpr.dyn_cast()) offset = cst.getValue(); else - offset = ShapedType::kDynamic; + offset = RankedShapedType::kDynamic; for (auto e : strideExprs) { if (auto c = e.dyn_cast()) strides.push_back(c.getValue()); else - strides.push_back(ShapedType::kDynamic); + strides.push_back(RankedShapedType::kDynamic); } return success(); } diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -63,8 +63,8 @@ for (auto dims : llvm::zip(shape1, shape2)) { int64_t dim1 = std::get<0>(dims); int64_t dim2 = std::get<1>(dims); - if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && - dim1 != dim2) + if (!RankedShapedType::isDynamic(dim1) && + !RankedShapedType::isDynamic(dim2) && dim1 != dim2) return failure(); } return success(); @@ -85,10 +85,11 @@ if (!sType2) return failure(); - if (!sType1.hasRank() || !sType2.hasRank()) + if (!isa(sType1) || !isa(sType2)) return success(); - return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); + return verifyCompatibleShape(cast(sType1).getShape(), + cast(sType2).getShape()); } /// Returns success if the given two arrays have the same number of elements and @@ -107,10 +108,10 @@ return success(); auto staticDim = std::accumulate( dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) { - return ShapedType::isDynamic(dim) ? fold : dim; + return RankedShapedType::isDynamic(dim) ? fold : dim; }); return success(llvm::all_of(dims, [&](auto dim) { - return ShapedType::isDynamic(dim) || dim == staticDim; + return RankedShapedType::isDynamic(dim) || dim == staticDim; })); } @@ -142,23 +143,31 @@ } // Remove all unranked shapes - auto shapes = llvm::to_vector<8>(llvm::make_filter_range( - shapedTypes, [](auto shapedType) { return shapedType.hasRank(); })); + auto shapes = llvm::to_vector<8>( + llvm::make_filter_range(shapedTypes, [](auto shapedType) { + return isa(shapedType); + })); if (shapes.empty()) return success(); // All ranks should be equal - auto firstRank = shapes.front().getRank(); - if (llvm::any_of(shapes, - [&](auto shape) { return firstRank != shape.getRank(); })) + auto firstRank = cast(shapes.front()).getRank(); + if (llvm::any_of(shapes, [&](auto shape) { + return firstRank != cast(shape).getRank(); + })) return failure(); for (unsigned i = 0; i < firstRank; ++i) { // Retrieve all ranked dimensions auto dims = llvm::to_vector<8>(llvm::map_range( llvm::make_filter_range( - shapes, [&](auto shape) { return shape.getRank() >= i; }), - [&](auto shape) { return shape.getDimSize(i); })); + shapes, + [&](auto shape) { + return cast(shape).getRank() >= i; + }), + [&](auto shape) { + return cast(shape).getDimSize(i); + })); if (verifyCompatibleDims(dims).failed()) return failure(); } diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -39,20 +39,21 @@ auto shapedType = result.getType().dyn_cast(); if (!shapedType) continue; - if (!shapedType.hasRank()) { + if (!isa(shapedType)) { // Nothing to check for unranked shaped values. ++resultIdx; continue; } + auto rankedShapedType = cast(shapedType); // Assert one OpFoldResult per dimension. 
- assert(shapedType.getRank() == + assert(rankedShapedType.getRank() == static_cast(reifiedReturnShapes[resultIdx].size()) && "incorrect implementation of ReifyRankedShapedTypeOpInterface"); - for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) { + for (int64_t dim = 0; dim < rankedShapedType.getRank(); ++dim) { // reifyResultShapes must return: // * Attribute for static dimensions // * Value for dynamic dimensions - assert(shapedType.isDynamicDim(dim) == + assert(rankedShapedType.isDynamicDim(dim) == reifiedReturnShapes[resultIdx][dim].is() && "incorrect implementation of ReifyRankedShapedTypeOpInterface"); } @@ -69,7 +70,7 @@ if (val.isNull()) return false; if (auto t = val.dyn_cast()) - return t.cast().hasRank(); + return isa(t); if (val.is()) return true; return val.get()->hasRank(); @@ -88,7 +89,7 @@ void ShapeAdaptor::getDims(SmallVectorImpl &res) const { assert(hasRank()); if (auto t = val.dyn_cast()) { - ArrayRef vals = t.cast().getShape(); + ArrayRef vals = t.cast().getShape(); res.assign(vals.begin(), vals.end()); } else if (auto attr = val.dyn_cast()) { auto dattr = attr.cast(); @@ -111,7 +112,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const { assert(hasRank()); if (auto t = val.dyn_cast()) - return t.cast().getDimSize(index); + return t.cast().getDimSize(index); if (auto attr = val.dyn_cast()) return attr.cast() .getValues()[index] @@ -123,7 +124,7 @@ int64_t ShapeAdaptor::getRank() const { assert(hasRank()); if (auto t = val.dyn_cast()) - return t.cast().getRank(); + return t.cast().getRank(); if (auto attr = val.dyn_cast()) return attr.cast().size(); return val.get()->getDims().size(); @@ -134,23 +135,23 @@ return false; if (auto t = val.dyn_cast()) - return t.cast().hasStaticShape(); + return t.cast().hasStaticShape(); if (auto attr = val.dyn_cast()) { auto dattr = attr.cast(); for (auto index : dattr.getValues()) - if (ShapedType::isDynamic(index.getSExtValue())) + if (RankedShapedType::isDynamic(index.getSExtValue())) return false; return true; } auto *stc = val.get(); - return llvm::none_of(stc->getDims(), ShapedType::isDynamic); + return llvm::none_of(stc->getDims(), RankedShapedType::isDynamic); } int64_t ShapeAdaptor::getNumElements() const { assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); if (auto t = val.dyn_cast()) - return t.cast().getNumElements(); + return t.cast().getNumElements(); if (auto attr = val.dyn_cast()) { auto dattr = attr.cast(); @@ -180,7 +181,7 @@ SmallVector dims; getDims(dims); auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string { - if (ShapedType::isDynamic(dim)) + if (RankedShapedType::isDynamic(dim)) return "?"; return llvm::formatv("{0}", dim).str(); }); diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -51,8 +51,8 @@ assert(!dim.has_value() && "invalid dim value"); } else if (auto shapedType = dyn_cast(value.getType())) { assert(*dim >= 0 && "invalid dim value"); - if (shapedType.hasRank()) - assert(*dim < shapedType.getRank() && "invalid dim value"); + if (auto rankedShapedType = dyn_cast(shapedType)) + assert(*dim < rankedShapedType.getRank() && "invalid dim value"); } else { llvm_unreachable("unsupported type"); } @@ -83,8 +83,9 @@ auto shapedType = dyn_cast(value.getType()); if (shapedType) { // Static dimension: return constant directly. 
- if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) - return builder.getAffineConstantExpr(shapedType.getDimSize(*dim)); + auto rankedShapedType = dyn_cast(shapedType); + if (rankedShapedType && !rankedShapedType.isDynamicDim(*dim)) + return builder.getAffineConstantExpr(rankedShapedType.getDimSize(*dim)); } else { // Constant index value: return directly. if (auto constInt = getConstantIntValue(value)) @@ -164,9 +165,9 @@ // Check for static dim size. if (dim != kIndexValue) { - auto shapedType = cast(value.getType()); - if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) { - bound(value)[dim] == getExpr(shapedType.getDimSize(dim)); + auto rankedShapedType = cast(value.getType()); + if (rankedShapedType && !rankedShapedType.isDynamicDim(dim)) { + bound(value)[dim] == getExpr(rankedShapedType.getDimSize(dim)); continue; } } @@ -346,7 +347,9 @@ continue; } - assert(cast(value.getType()).isDynamicDim(dim) && + auto rankedShapedType = cast(value.getType()); + (void)rankedShapedType; + assert((!rankedShapedType || rankedShapedType.isDynamicDim(dim)) && "expected dynamic dim"); mapOperands.push_back(std::make_pair(value, dim)); } diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -28,7 +28,7 @@ << " values, got " << staticVals.size(); unsigned expectedNumDynamicEntries = llvm::count_if(staticVals, [&](int64_t staticVal) { - return ShapedType::isDynamic(staticVal); + return RankedShapedType::isDynamic(staticVal); }); if (values.size() != expectedNumDynamicEntries) return op->emitError("expected ") @@ -112,7 +112,7 @@ } unsigned idx = 0; llvm::interleaveComma(integers, printer, [&](int64_t integer) { - if (ShapedType::isDynamic(integer)) + if (RankedShapedType::isDynamic(integer)) printer << values[idx++]; else printer << integer; @@ -131,7 +131,7 @@ auto res = parser.parseOptionalOperand(operand); if (res.has_value() && succeeded(res.value())) { values.push_back(operand); - integerVals.push_back(ShapedType::kDynamic); + integerVals.push_back(RankedShapedType::kDynamic); } else { int64_t integer; if (failed(parser.parseInteger(integer))) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -986,14 +986,15 @@ if (auto iType = type.dyn_cast()) return (os << "size_t"), success(); if (auto tType = type.dyn_cast()) { - if (!tType.hasRank()) + auto rtType = dyn_cast(tType); + if (!rtType) return emitError(loc, "cannot emit unranked tensor type"); - if (!tType.hasStaticShape()) + if (!rtType.hasStaticShape()) return emitError(loc, "cannot emit tensor type with non static shape"); os << "Tensor<"; - if (failed(emitType(loc, tType.getElementType()))) + if (failed(emitType(loc, rtType.getElementType()))) return failure(); - auto shape = tType.getShape(); + auto shape = rtType.getShape(); for (auto dimSize : shape) { os << ", "; os << dimSize; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -233,7 +233,7 @@ if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType)) return nullptr; - ShapedType type = denseElementsAttr.getType(); + RankedShapedType type = denseElementsAttr.getType(); if (type.getNumElements() == 0) return nullptr; diff 
--git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -685,7 +685,7 @@ uint32_t resultID = 0; if (auto attr = valueAttr.dyn_cast()) { - int rank = attr.getType().dyn_cast().getRank(); + int rank = attr.getType().dyn_cast().getRank(); SmallVector index(rank); resultID = prepareDenseElementsConstant(loc, constType, attr, /*dim=*/0, index); @@ -732,7 +732,7 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, DenseElementsAttr valueAttr, int dim, MutableArrayRef index) { - auto shapedType = valueAttr.getType().dyn_cast(); + auto shapedType = valueAttr.getType().dyn_cast(); assert(dim <= shapedType.getRank()); if (shapedType.getRank() == dim) { if (auto attr = valueAttr.dyn_cast()) { diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -218,7 +218,7 @@ static_split_point = split_point dynamic_split_point = None else: - static_split_point = _get_int64_attr(ShapedType.get_dynamic_size()) + static_split_point = _get_int64_attr(RankedShapedType.get_dynamic_size()) dynamic_split_point = _get_op_result_or_value(split_point) target = _get_op_result_or_value(target) @@ -281,7 +281,7 @@ if isinstance(size, int): static_sizes.append(size) else: - static_sizes.append(ShapedType.get_dynamic_size()) + static_sizes.append(RankedShapedType.get_dynamic_size()) dynamic_sizes.append(_get_op_result_or_value(size)) sizes_attr = DenseI64ArrayAttr.get(static_sizes) diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py --- a/mlir/python/mlir/dialects/_tensor_ops_ext.py +++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py @@ -30,7 +30,7 @@ if isinstance(s, int): static_sizes.append(s) else: - static_sizes.append(ShapedType.get_dynamic_size()) + static_sizes.append(RankedShapedType.get_dynamic_size()) dynamic_sizes.append(s) result_type = RankedTensorType.get(static_sizes, element_type) op = self.build_generic( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -181,7 +181,7 @@ # indexing maps and iterators that match the rank of the first output tensor. # An operation is rank polymorphic if the iteration domain has rank zero. 
if not iterator_types_attr: - rank = ShapedType(outs[0].type).rank + rank = RankedShapedType(outs[0].type).rank iterator_types_attr = ArrayAttr.get( [Attribute.parse("#linalg.iterator_type")] * rank) scalar_map = AffineMap.get(rank, 0, []) @@ -192,7 +192,7 @@ indexing_maps.append(scalar_map) if arg_def.operand_def.is_tensor(): idx = arg_def.operand_def.registered_index - if idx < len(ins) and ShapedType(ins[idx].type).rank == 0: + if idx < len(ins) and RankedShapedType(ins[idx].type).rank == 0: indexing_maps.append(scalar_map) else: indexing_maps.append(tensor_map) diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -716,11 +716,12 @@ if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector)) return 14; if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) || - !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 || - mlirShapedTypeGetDimSize(vector, 0) != 2 || - mlirShapedTypeIsDynamicDim(vector, 0) || - mlirShapedTypeGetDimSize(vector, 1) != 3 || - !mlirShapedTypeHasStaticShape(vector)) + !mlirTypeIsARankedShaped(vector) || + mlirRankedShapedTypeGetRank(vector) != 2 || + mlirRankedShapedTypeGetDimSize(vector, 0) != 2 || + mlirRankedShapedTypeIsDynamicDim(vector, 0) || + mlirRankedShapedTypeGetDimSize(vector, 1) != 3 || + !mlirRankedShapedTypeHasStaticShape(vector)) return 15; mlirTypeDump(vector); fprintf(stderr, "\n"); @@ -741,7 +742,7 @@ MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32); if (!mlirTypeIsATensor(unrankedTensor) || !mlirTypeIsAUnrankedTensor(unrankedTensor) || - mlirShapedTypeHasRank(unrankedTensor)) + mlirTypeIsARankedShaped(unrankedTensor)) return 17; mlirTypeDump(unrankedTensor); fprintf(stderr, "\n"); diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -1,13 +1,13 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics func.func @elementsattr_non_tensor_type() -> () { - "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal must be a shaped type}} + "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal type must be a ranked shaped type with static shape}} } // ----- func.func @elementsattr_non_ranked() -> () { - "foo"(){bar = dense<[4]> : tensor} : () -> () // expected-error {{elements literal type must have static shape}} + "foo"(){bar = dense<[4]> : tensor} : () -> () // expected-error {{elements literal type must be a ranked shaped type with static shape}} } // ----- diff --git a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp --- a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp +++ b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp @@ -41,13 +41,13 @@ return; } llvm::outs() << "MemRefType offset: "; - if (ShapedType::isDynamic(offset)) + if (RankedShapedType::isDynamic(offset)) llvm::outs() << "?"; else llvm::outs() << offset; llvm::outs() << " strides: "; llvm::interleaveComma(strides, llvm::outs(), [&](int64_t v) { - if (ShapedType::isDynamic(v)) + if (RankedShapedType::isDynamic(v)) llvm::outs() << "?"; else llvm::outs() << v; diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ 
b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -113,7 +113,7 @@ if (!op.getSource().hasOneUse()) return false; - auto resultType = op.getResult().getType().cast(); + auto resultType = op.getResult().getType().cast(); constexpr int64_t kConstantFoldingMaxNumElements = 1024; return resultType.getNumElements() <= kConstantFoldingMaxNumElements; }; diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -86,7 +86,7 @@ ]> { let mnemonic = "i64_elements"; let parameters = (ins - AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type, + AttributeSelfTypeParameter<"", "::mlir::RankedShapedType">:$type, ArrayRefParameter<"uint64_t">:$elements ); let extraClassDeclaration = [{ diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -86,7 +86,8 @@ LogicalResult TestI64ElementsAttr::verify(function_ref emitError, - ShapedType type, ArrayRef elements) { + RankedShapedType type, + ArrayRef elements) { if (type.getNumElements() != static_cast(elements.size())) { return emitError() << "number of elements does not match the provided shape type, got: " diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1348,7 +1348,8 @@ if (!sval) { return emitOptionalError(location, "only shaped type operands allowed"); } - int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; + auto rsval = sval.dyn_cast(); + int64_t dim = rsval ? 
rsval.getShape().front() : RankedShapedType::kDynamic; auto type = IntegerType::get(context, 17); Attribute encoding; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -260,7 +260,7 @@ DerivedTypeAttr element_dtype = DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">; DerivedAttr num_elements = DerivedAttr<"int", - "return getOutput().getType().cast().getNumElements();", + "return getOutput().getType().cast().getNumElements();", "$_builder.getI32IntegerAttr($_self)">; } diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -24,7 +24,8 @@ # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32> # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32> @func.FuncOp.from_py_func( - RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)) + RankedTensorType.get((12, RankedShapedType.get_dynamic_size()), f32) + ) def fill_tensor(out): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result return linalg.fill(zero, outs=[out]) @@ -35,7 +36,8 @@ # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>) # CHECK-NEXT: return @func.FuncOp.from_py_func( - MemRefType.get((12, ShapedType.get_dynamic_size()), f32)) + MemRefType.get((12, RankedShapedType.get_dynamic_size()), f32) + ) def fill_buffer(out): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result linalg.fill(zero, outs=[out]) diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py --- a/mlir/test/python/dialects/shape.py +++ b/mlir/test/python/dialects/shape.py @@ -19,8 +19,10 @@ module = Module.create() f32 = F32Type.get() with InsertionPoint(module.body): + @func.FuncOp.from_py_func( - RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)) + RankedTensorType.get((12, RankedShapedType.get_dynamic_size()), f32) + ) def const_shape_tensor(arg): shape.ConstWitnessOp(False) shape.ConstSizeOp(30) @@ -31,8 +33,6 @@ DenseElementsAttr.get( np.array([3, 4], dtype=np.int64), type=IndexType.get())) - - # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>) # CHECK-DAG: shape.const_witness false # CHECK-DAG: shape.const_size 30 diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py --- a/mlir/test/python/dialects/tensor.py +++ b/mlir/test/python/dialects/tensor.py @@ -23,8 +23,13 @@ @func.FuncOp.from_py_func( RankedTensorType.get( - (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()), - f32Type)) + ( + RankedShapedType.get_dynamic_size(), + RankedShapedType.get_dynamic_size(), + ), + f32Type, + ) + ) # CHECK: func @tensor_static_dim # CHECK-SAME: %[[ARG0:.+]]: tensor # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py --- a/mlir/test/python/dialects/vector.py +++ b/mlir/test/python/dialects/vector.py @@ -36,8 +36,12 @@ with InsertionPoint(module.body): vector_type = VectorType.get([2, 3], F32Type.get()) memref_type = MemRefType.get( - [ShapedType.get_dynamic_size(), - ShapedType.get_dynamic_size()], F32Type.get()) + [ + RankedShapedType.get_dynamic_size(), + RankedShapedType.get_dynamic_size(), + ], + F32Type.get(), + ) index_type = IndexType.get() mask_type = 
VectorType.get(vector_type.shape, IntegerType.get_signless(1)) identity_map = AffineMap.get_identity(vector_type.rank) diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -74,13 +74,13 @@ try: attr = DenseElementsAttr.get_splat(non_shaped_type, element) except ValueError as e: - # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32) + # CHECK: Expected a static RankedShapedType for the shaped_type parameter: Type(f32) print(e) try: attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element) except ValueError as e: - # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>) + # CHECK: Expected a static RankedShapedType for the shaped_type parameter: Type(tensor<*xf32>) print(e) try: @@ -364,4 +364,3 @@ # CHECK: {{\[}}[1 2 3] # CHECK: {{\[}}4 5 6]] print(np.array(attr)) - diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -548,7 +548,7 @@ print(attr.strides[2]) attr = StridedLayoutAttr.get_fully_dynamic(3) - dynamic = ShapedType.get_dynamic_stride_or_offset() + dynamic = RankedShapedType.get_dynamic_stride_or_offset() # CHECK: strided<[?, ?, ?], offset: ?> print(attr) # CHECK: offset is dynamic: True diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -254,8 +254,6 @@ vector = VectorType(Type.parse("vector<2x3xf32>")) # CHECK: element type: f32 print("element type:", vector.element_type) - # CHECK: whether the given shaped type is ranked: True - print("whether the given shaped type is ranked:", vector.has_rank) # CHECK: rank: 2 print("rank:", vector.rank) # CHECK: whether the shaped type has a static shape: True @@ -270,7 +268,10 @@ print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1)) # CHECK: isinstance(ShapedType): True print("isinstance(ShapedType):", isinstance(vector, ShapedType)) - + # CHECK: isinstance(RankedShapedType): True + print("isinstance(RankedShapedType):", isinstance(vector, RankedShapedType)) + # CHECK: has_rank: True + print("has_rank:", vector.has_rank) # CHECK-LABEL: TEST: testAbstractShapedType # Tests that ShapedType operates as an abstract base class of a concrete @@ -340,27 +341,6 @@ unranked_tensor = UnrankedTensorType.get(f32) # CHECK: unranked tensor type: tensor<*xf32> print("unranked tensor type:", unranked_tensor) - try: - invalid_rank = unranked_tensor.rank - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_get_dim_size = unranked_tensor.get_dim_size(1) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. 
- print(e) - else: - print("Exception not produced") none = NoneType.get() try: @@ -423,27 +403,6 @@ unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2")) # CHECK: unranked memref type: memref<*xf32, 2> print("unranked memref type:", unranked_memref) - try: - invalid_rank = unranked_memref.rank - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_get_dim_size = unranked_memref.get_dim_size(1) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") none = NoneType.get() try: @@ -501,11 +460,11 @@ print("data:", opaque.data) -# CHECK-LABEL: TEST: testShapedTypeConstants -# Tests that ShapedType exposes magic value constants. +# CHECK-LABEL: TEST: testRankedShapedTypeConstants +# Tests that RankedShapedType exposes magic value constants. @run -def testShapedTypeConstants(): +def testRankedShapedTypeConstants(): # CHECK: - print(type(ShapedType.get_dynamic_size())) + print(type(RankedShapedType.get_dynamic_size())) # CHECK: - print(type(ShapedType.get_dynamic_stride_or_offset())) + print(type(RankedShapedType.get_dynamic_stride_or_offset())) diff --git a/mlir/unittests/Dialect/BroadcastShapeTest.cpp b/mlir/unittests/Dialect/BroadcastShapeTest.cpp --- a/mlir/unittests/Dialect/BroadcastShapeTest.cpp +++ b/mlir/unittests/Dialect/BroadcastShapeTest.cpp @@ -47,7 +47,7 @@ TEST(BroadcastShapeTest, InterleavingUnknowns) { SmallVector result; - int64_t dyn = mlir::ShapedType::kDynamic; + int64_t dyn = mlir::RankedShapedType::kDynamic; ASSERT_TRUE(getBroadcastedShape({1, 2, dyn, dyn, dyn}, {dyn, dyn, dyn, 4, 1}, result)); EXPECT_THAT(result, ElementsAre(dyn, 2, dyn, 4, dyn));
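To close out the section, a small sketch of the query idiom this rename pushes call sites toward: replacing hasRank() guards with a cast to RankedShapedType, as the TypeUtilities, InferTypeOpInterface, and ValueBoundsOpInterface hunks above do. Illustrative only; it assumes the renamed interface carries getRank/getDimSize/isDynamicDim as shown in this diff.

// Hedged sketch of the post-rename query idiom (assumes the patch is applied).
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

static void describe(mlir::Type type) {
  // Unranked shaped types (tensor<*xf32>, memref<*xf32>) still satisfy
  // ShapedType, but only RankedShapedType exposes getRank()/getDimSize().
  if (auto ranked = mlir::dyn_cast<mlir::RankedShapedType>(type)) {
    if (ranked.isDynamicDim(0))
      llvm::errs() << "rank " << ranked.getRank() << ", dim0 ?\n";
    else
      llvm::errs() << "rank " << ranked.getRank() << ", dim0 "
                   << ranked.getDimSize(0) << "\n";
    return;
  }
  llvm::errs() << "unranked or non-shaped\n";
}

int main() {
  mlir::MLIRContext ctx;
  auto f32 = mlir::Float32Type::get(&ctx);
  describe(mlir::RankedTensorType::get({mlir::RankedShapedType::kDynamic, 4},
                                       f32)); // rank 2, dim0 ?
  describe(mlir::UnrankedTensorType::get(f32)); // unranked or non-shaped
}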