diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -517,7 +517,7 @@ ConversionPatternRewriter &rewriter) const; /// Returns if the given memref type is supported. - bool isSupportedMemRefType(MemRefType type) const; + bool isConvertibleAndHasIdentityMaps(MemRefType type) const; /// Returns the type of a pointer to an element of the memref. Type getElementPtrType(MemRefType type) const; diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -348,7 +348,7 @@ MemRefType memRefType = allocOp.getType(); if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) || - !isSupportedMemRefType(memRefType) || + !isConvertibleAndHasIdentityMaps(memRefType) || failed(isAsyncWithOneDependency(rewriter, allocOp))) return failure(); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1089,7 +1089,8 @@ // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. 
-bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const { +bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( + MemRefType type) const { if (!typeConverter->convertType(type.getElementType())) return false; return type.getAffineMaps().empty() || @@ -1107,7 +1108,7 @@ Location loc, MemRefType memRefType, ArrayRef dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes, SmallVectorImpl &strides, Value &sizeBytes) const { - assert(isSupportedMemRefType(memRefType) && + assert(isConvertibleAndHasIdentityMaps(memRefType) && "layout maps must have been normalized away"); sizes.reserve(memRefType.getRank()); @@ -1977,7 +1978,7 @@ LogicalResult match(Operation *op) const override { MemRefType memRefType = getMemRefResultType(op); - return success(isSupportedMemRefType(memRefType)); + return success(isConvertibleAndHasIdentityMaps(memRefType)); } // An `alloc` is converted into a definition of a memref descriptor value and @@ -2411,7 +2412,7 @@ matchAndRewrite(GlobalMemrefOp global, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MemRefType type = global.type().cast(); - if (!isSupportedMemRefType(type)) + if (!isConvertibleAndHasIdentityMaps(type)) return failure(); LLVM::LLVMType arrayTy = @@ -3027,12 +3028,12 @@ template struct LoadStoreOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using ConvertOpToLLVMPattern::isSupportedMemRefType; + using ConvertOpToLLVMPattern::isConvertibleAndHasIdentityMaps; using Base = LoadStoreOpLowering; LogicalResult match(Derived op) const override { MemRefType type = op.getMemRefType(); - return isSupportedMemRefType(type) ? success() : failure(); + return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); } };