diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -125,6 +125,12 @@
   unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
                                            const DataLayout &layout);
 
+  /// Check if a memref type can be converted to a bare pointer. Only ranked
+  /// memrefs whose layout is expressible as strides plus an offset, and whose
+  /// shape, strides and offset are all static, qualify; unranked memrefs never
+  /// do. The check does not depend on converter state, hence `static`.
+  static bool canConvertToBarePtr(BaseMemRefType type);
+
 protected:
   /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
@@ -191,8 +197,8 @@
   /// These types can be recomposed to a unranked memref descriptor struct.
   SmallVector<Type, 2> getUnrankedMemRefDescriptorFields();
 
-  // Convert an unranked memref type to an LLVM type that captures the
-  // runtime rank and a pointer to the static ranked memref desc
+  /// Convert an unranked memref type to an LLVM type that captures the
+  /// runtime rank and a pointer to the static ranked memref descriptor.
   Type convertUnrankedMemRefType(UnrankedMemRefType type);
 
   /// Convert a memref type to a bare pointer to the memref element type.
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -599,7 +599,10 @@
       for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
         Type oldTy = std::get<0>(it).getType();
         Value newOperand = std::get<1>(it);
-        if (oldTy.isa<MemRefType>()) {
+        // Only ranked memrefs accepted by canConvertToBarePtr (static shape,
+        // strides and offset) are returned as a bare pointer; any other ranked
+        // memref is returned as its full converted descriptor.
+        if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
+                                           oldTy.cast<BaseMemRefType>())) {
           MemRefDescriptor memrefDesc(newOperand);
           newOperand = memrefDesc.alignedPtr(rewriter, loc);
         } else if (oldTy.isa<UnrankedMemRefType>()) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -366,30 +366,39 @@
                                           getUnrankedMemRefDescriptorFields());
 }
 
-/// Convert a memref type to a bare pointer to the memref element type.
-Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
+/// Check if a memref type can be converted to a bare pointer: it must be a
+/// ranked memref with static shape, strides and offset.
+bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
   if (type.isa<UnrankedMemRefType>())
     // Unranked memref is not supported in the bare pointer calling convention.
-    return {};
+    return false;
 
   // Check that the memref has static shape, strides and offset. Otherwise, it
   // cannot be lowered to a bare pointer.
   auto memrefTy = type.cast<MemRefType>();
   if (!memrefTy.hasStaticShape())
-    return {};
+    return false;
 
   int64_t offset = 0;
   SmallVector<int64_t, 4> strides;
   if (failed(getStridesAndOffset(memrefTy, strides, offset)))
-    return {};
+    return false;
 
   for (int64_t stride : strides)
     if (ShapedType::isDynamicStrideOrOffset(stride))
-      return {};
+      return false;
 
   if (ShapedType::isDynamicStrideOrOffset(offset))
-    return {};
+    return false;
+
+  return true;
+}
+
+/// Convert a memref type to a bare pointer to the memref element type.
+Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
+  if (!canConvertToBarePtr(type))
+    return {};
 
   Type elementType = convertType(type.getElementType());
   if (!elementType)
     return {};
diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
--- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
@@ -86,3 +86,12 @@
   // BAREPTR-NEXT: llvm.return %[[res]] : !llvm.ptr<i8>
   return %res : memref<20xi8>
 }
+
+// -----
+
+// BAREPTR-LABEL: func @check_return(
+// BAREPTR-SAME: %{{.*}}: memref<?xi8>) -> memref<?xi8>
+func @check_return(%in : memref<?xi8>) -> memref<?xi8> {
+  // BAREPTR: llvm.return {{.*}} : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  return %in : memref<?xi8>
+}