diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -252,7 +252,11 @@
 
   /// If this is ranked type, return the size of the specified dimension.
   /// Otherwise, abort.
-  int64_t getDimSize(int64_t i) const;
+  int64_t getDimSize(unsigned idx) const;
+
+  /// Returns true if this dimension has a dynamic size (for ranked types);
+  /// aborts for unranked types.
+  bool isDynamicDim(unsigned idx) const;
 
   /// Returns the position of the dynamic dimension relative to just the dynamic
   /// dimensions, given its `index` within the shape.
@@ -276,7 +280,9 @@
   }
 
   /// Whether the given dimension size indicates a dynamic dimension.
-  static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
+  static constexpr bool isDynamic(int64_t dSize) {
+    return dSize == kDynamicSize;
+  }
   static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
     return dStrideOrOffset == kDynamicStrideOrOffset;
   }
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -330,11 +330,10 @@
   if (addMemRefDimBounds) {
     auto memRefType = memref.getType().cast<MemRefType>();
     for (unsigned r = 0; r < rank; r++) {
-      cst.addConstantLowerBound(r, 0);
-      int64_t dimSize = memRefType.getDimSize(r);
-      if (ShapedType::isDynamic(dimSize))
+      cst.addConstantLowerBound(/*pos=*/r, /*lb=*/0);
+      if (memRefType.isDynamicDim(r))
         continue;
-      cst.addConstantUpperBound(r, dimSize - 1);
+      cst.addConstantUpperBound(/*pos=*/r, memRefType.getDimSize(r) - 1);
     }
   }
   cst.removeTrivialRedundancy();
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1805,16 +1805,15 @@
     OperandAdaptor<DimOp> transformed(operands);
     MemRefType type = dimOp.getOperand().getType().cast<MemRefType>();
-    auto shape = type.getShape();
     int64_t index = dimOp.getIndex();
 
     // Extract dynamic size from the memref descriptor.
-    if (ShapedType::isDynamic(shape[index]))
+    if (type.isDynamicDim(index))
       rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
                                   .size(rewriter, op->getLoc(), index)});
     else
       // Use constant for static size.
-      rewriter.replaceOp(
-          op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
+      rewriter.replaceOp(op, createIndexConstant(rewriter, op->getLoc(),
+                                                 type.getDimSize(index)));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -133,7 +133,7 @@
                               unsigned index) {
   auto memRefType = memrefDefOp.getType();
   // Statically shaped.
-  if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
+  if (!memRefType.isDynamicDim(index))
     return true;
   // Get the position of the dimension among dynamic dimensions;
   unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1068,14 +1068,14 @@
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   // Constant fold dim when the size along the index referred to is a constant.
   auto opType = memrefOrTensor().getType();
-  int64_t indexSize = -1;
+  int64_t dimSize = -1;
   if (auto tensorType = opType.dyn_cast<RankedTensorType>())
-    indexSize = tensorType.getShape()[getIndex()];
+    dimSize = tensorType.getShape()[getIndex()];
   else if (auto memrefType = opType.dyn_cast<MemRefType>())
-    indexSize = memrefType.getShape()[getIndex()];
+    dimSize = memrefType.getShape()[getIndex()];
 
-  if (!ShapedType::isDynamic(indexSize))
-    return IntegerAttr::get(IndexType::get(getContext()), indexSize);
+  if (!ShapedType::isDynamic(dimSize))
+    return IntegerAttr::get(IndexType::get(getContext()), dimSize);
 
   // Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp.
   auto memrefType = opType.dyn_cast<MemRefType>();
@@ -2310,13 +2310,12 @@
 
 static LogicalResult verifyDynamicStrides(MemRefType memrefType,
                                           ArrayRef<int64_t> strides) {
-  ArrayRef<int64_t> shape = memrefType.getShape();
   unsigned rank = memrefType.getRank();
   assert(rank == strides.size());
   bool dynamicStrides = false;
   for (int i = rank - 2; i >= 0; --i) {
     // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag.
-    if (ShapedType::isDynamic(shape[i + 1]))
+    if (memrefType.isDynamicDim(i + 1))
       dynamicStrides = true;
     // If stride at dim 'i' is not dynamic, return error.
     if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset())
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -184,9 +184,14 @@ bool ShapedType::hasRank() const {
   return !isa<UnrankedMemRefType, UnrankedTensorType>();
 }
 
-int64_t ShapedType::getDimSize(int64_t i) const {
-  assert(i >= 0 && i < getRank() && "invalid index for shaped type");
-  return getShape()[i];
+int64_t ShapedType::getDimSize(unsigned idx) const {
+  assert(idx < getRank() && "invalid index for shaped type");
+  return getShape()[idx];
+}
+
+bool ShapedType::isDynamicDim(unsigned idx) const {
+  assert(idx < getRank() && "invalid index for shaped type");
+  return isDynamic(getShape()[idx]);
 }
 
 unsigned ShapedType::getDynamicDimIndex(unsigned index) const {