diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -90,10 +90,6 @@
 Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
                            Location loc, Value vector);
 
-/// Return true if the last dimension of the MemRefType has unit stride. Also
-/// return true for memrefs with no strides.
-bool isLastMemrefDimUnitStride(MemRefType type);
-
 /// Build the default minor identity map suitable for a vector transfer. This
 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
 /// rank of the identity map must take the vector element type into account.
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -534,9 +534,13 @@
 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                           MLIRContext *context);
 
-/// Return true if the layout for `t` is compatible with strided semantics.
+/// Return "true" if the layout for `t` is compatible with strided semantics.
 bool isStrided(MemRefType t);
 
+/// Return "true" if the last dimension of the given type has a static unit
+/// stride. Also return "true" for types with no strides.
+bool isLastMemrefDimUnitStride(MemRefType type);
+
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -92,13 +92,11 @@
 // Check if the last stride is non-unit or the memory space is not zero.
 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                            LLVMTypeConverter &converter) {
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
+  if (!isLastMemrefDimUnitStride(memRefType))
+    return failure();
   FailureOr<unsigned> addressSpace =
       converter.getMemRefAddressSpace(memRefType);
-  if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) ||
-      *addressSpace != 0)
+  if (failed(addressSpace) || *addressSpace != 0)
     return failure();
   return success();
 }
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1185,14 +1185,6 @@
   }
 };
 
-/// Return true if the last dimension of the MemRefType has unit stride.
-static bool isLastMemrefDimUnitStride(MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  auto successStrides = getStridesAndOffset(type, strides, offset);
-  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
-}
-
 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
 /// necessary in cases where a 1D vector transfer op cannot be lowered into
 /// vector load/stores due to non-unit strides or broadcasts:
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1546,17 +1546,6 @@
 // GPU_SubgroupMmaLoadMatrixOp
 //===----------------------------------------------------------------------===//
 
-/// Return true if the last dimension of the MemRefType has unit stride. Also
-/// return true for memrefs with no strides.
-static bool isLastMemrefDimUnitStride(MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  if (failed(getStridesAndOffset(type, strides, offset))) {
-    return false;
-  }
-  return strides.back() == 1;
-}
-
 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
   auto srcType = getSrcMemref().getType();
   auto resType = getRes().getType();
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -53,17 +53,6 @@
 // NVGPU_DeviceAsyncCopyOp
 //===----------------------------------------------------------------------===//
 
-/// Return true if the last dimension of the MemRefType has unit stride. Also
-/// return true for memrefs with no strides.
-static bool isLastMemrefDimUnitStride(MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  if (failed(getStridesAndOffset(type, strides, offset))) {
-    return false;
-  }
-  return strides.back() == 1;
-}
-
 LogicalResult DeviceAsyncCopyOp::verify() {
   auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
   auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -130,15 +130,6 @@
   return false;
 }
 
-/// Return true if the last dimension of the MemRefType has unit stride. Also
-/// return true for memrefs with no strides.
-bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  auto successStrides = getStridesAndOffset(type, strides, offset);
-  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
-}
-
 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                     VectorType vectorType) {
   int64_t elementVectorRank = 0;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -419,7 +419,7 @@
       return rewriter.notifyMatchFailure(read, "not a memref source");
 
     // Non-unit strides are handled by VectorToSCF.
-    if (!vector::isLastMemrefDimUnitStride(memRefType))
+    if (!isLastMemrefDimUnitStride(memRefType))
       return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");
 
     // If there is broadcasting involved then we first load the unbroadcasted
@@ -567,7 +567,7 @@
     });
 
     // Non-unit strides are handled by VectorToSCF.
-    if (!vector::isLastMemrefDimUnitStride(memRefType))
+    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -956,10 +956,16 @@
   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
 }
 
-/// Return true if the layout for `t` is compatible with strided semantics.
 bool mlir::isStrided(MemRefType t) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto res = getStridesAndOffset(t, strides, offset);
   return succeeded(res);
 }
+
+bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  auto successStrides = getStridesAndOffset(type, strides, offset);
+  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
+}
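
For downstream code, the net effect is that the unit-stride check is now part of the builtin type API next to isStrided() rather than a per-dialect static helper. A minimal call-site sketch follows, mirroring the rewritten isMemRefTypeSupported() above; checkLastDimStride is a hypothetical wrapper used only for illustration, and only isLastMemrefDimUnitStride() comes from this patch:

// Hypothetical example (not part of the patch): reject memref types whose
// innermost dimension does not have a static unit stride.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"

static mlir::LogicalResult checkLastDimStride(mlir::MemRefType memRefType) {
  // isLastMemrefDimUnitStride() returns true for a unit-stride innermost
  // dimension and for memref types with no strides at all (e.g., rank-0).
  if (!mlir::isLastMemrefDimUnitStride(memRefType))
    return mlir::failure();
  return mlir::success();
}

Consolidating the helper also removes a subtle divergence between the old copies: the GPU and NVGPU versions called strides.back() without checking for emptiness, which is not safe for a rank-0 memref, whereas the shared definition guards with strides.empty().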