diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -436,6 +436,11 @@ SmallVectorImpl<int64_t> &strides, int64_t &offset); +/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t> &, +/// int64_t &) that returns the {strides, offset} pair and asserts that the +/// computation succeeded. Only use it when `t` is known to be strided. +std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t); + /// Return a version of `t` with identity layout if it can be determined /// statically that the layout is the canonical contiguous strided layout. /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -52,11 +52,7 @@ assert(type.hasStaticShape() && "unexpected dynamic shape"); // Extract all strides and offsets and verify they are static. 
- int64_t offset; - SmallVector strides; - auto result = getStridesAndOffset(type, strides, offset); - (void)result; - assert(succeeded(result) && "unexpected failure in stride computation"); + auto [strides, offset] = getStridesAndOffset(type); assert(!ShapedType::isDynamic(offset) && "expected static offset"); assert(!llvm::any_of(strides, ShapedType::isDynamic) && diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -69,11 +69,7 @@ Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - assert(succeeded(successStrides) && "unexpected non-strided memref"); - (void)successStrides; + auto [strides, offset] = getStridesAndOffset(type); MemRefDescriptor memRefDescriptor(memRefDesc); Value base = memRefDescriptor.alignedPtr(rewriter, loc); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2538,11 +2538,7 @@ assert(staticStrides.size() == rank && "staticStrides length mismatch"); // Extract source offset and strides. - int64_t sourceOffset; - SmallVector sourceStrides; - auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); - assert(succeeded(res) && "SubViewOp expected strided memref type"); - (void)res; + auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType); // Compute target offset whose value is: // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 
@@ -3098,12 +3094,8 @@ AffineMap permutationMap) { auto rank = memRefType.getRank(); auto originalSizes = memRefType.getShape(); - int64_t offset; - SmallVector originalStrides; - auto res = getStridesAndOffset(memRefType, originalStrides, offset); - assert(succeeded(res) && - originalStrides.size() == static_cast(rank)); - (void)res; + auto [originalStrides, offset] = getStridesAndOffset(memRefType); + assert(originalStrides.size() == static_cast(rank)); // Compute permuted sizes and strides. SmallVector sizes(rank, 0); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -60,14 +60,7 @@ auto newExtractStridedMetadata = rewriter.create(origLoc, source); - SmallVector sourceStrides; - int64_t sourceOffset; - - bool hasKnownStridesAndOffset = - succeeded(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)); - (void)hasKnownStridesAndOffset; - assert(hasKnownStridesAndOffset && - "getStridesAndOffset must work on valid subviews"); + auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType); // Compute the new strides and offset from the base strides and offset: // newStride#i = baseStride#i * subStride#i @@ -265,13 +258,7 @@ // Collect the statically known information about the original stride. 
Value source = expandShape.getSrc(); auto sourceType = source.getType().cast(); - SmallVector strides; - int64_t offset; - bool hasKnownStridesAndOffset = - succeeded(getStridesAndOffset(sourceType, strides, offset)); - (void)hasKnownStridesAndOffset; - assert(hasKnownStridesAndOffset && - "getStridesAndOffset must work on valid expand_shape"); + auto [strides, offset] = getStridesAndOffset(sourceType); OpFoldResult origStride = ShapedType::isDynamic(strides[groupId]) @@ -414,13 +401,7 @@ Value source = collapseShape.getSrc(); auto sourceType = source.getType().cast(); - SmallVector strides; - int64_t offset; - bool hasKnownStridesAndOffset = - succeeded(getStridesAndOffset(sourceType, strides, offset)); - (void)hasKnownStridesAndOffset; - assert(hasKnownStridesAndOffset && - "getStridesAndOffset must work on valid collapse_shape"); + auto [strides, offset] = getStridesAndOffset(sourceType); SmallVector collapsedStride; int64_t innerMostDimForGroup = reassocGroup.back(); @@ -473,13 +454,7 @@ rewriter.create(origLoc, source); // Collect statically known information. 
- SmallVector strides; - int64_t offset; - bool hasKnownStridesAndOffset = - succeeded(getStridesAndOffset(sourceType, strides, offset)); - (void)hasKnownStridesAndOffset; - assert(hasKnownStridesAndOffset && - "getStridesAndOffset must work on valid reassociative_reshape_like"); + auto [strides, offset] = getStridesAndOffset(sourceType); MemRefType reshapeType = reshape.getResultType(); unsigned reshapeRank = reshapeType.getRank(); diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -804,6 +804,19 @@ return success(); } +std::pair<SmallVector<int64_t>, int64_t> +mlir::getStridesAndOffset(MemRefType t) { + SmallVector<int64_t> strides; + // Initialized so that, with assertions disabled, a failed computation does + // not hand back an indeterminate value. + int64_t offset = 0; + LogicalResult status = getStridesAndOffset(t, strides, offset); + (void)status; + assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset"); + // Move the vector into the result instead of copying it. + return {std::move(strides), offset}; +} + //===----------------------------------------------------------------------===// /// TupleType //===----------------------------------------------------------------------===//