diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -27,6 +27,7 @@ namespace mlir { +class BaseMemRefType; class ComplexType; class LLVMTypeConverter; class UnrankedMemRefType; @@ -74,15 +75,28 @@ SignatureConversion &result); /// Convert a non-empty list of types to be returned from a function into a - /// supported LLVM IR type. In particular, if more than one values is + /// supported LLVM IR type. In particular, if more than one value is /// returned, create an LLVM IR structure type with elements that correspond - /// to each of the MLIR types converted with `convertType`. + /// to each of the MLIR types converted with `convertCallingConventionType`. Type packFunctionResults(ArrayRef types); + /// Convert a type in the context of the default or bare pointer calling + /// convention. Calling convention sensitive types, such as MemRefType and + /// UnrankedMemRefType, are converted following the specific rules for the + /// calling convention. Calling convention independent types are converted + /// following the default LLVM type conversions. + Type convertCallingConventionType(Type type); + + /// Promote the bare pointers in 'values' that resulted from memrefs to + /// descriptors. 'stdTypes' holds the types of 'values' before the conversion + /// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type). + void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, + Location loc, ArrayRef stdTypes, + SmallVectorImpl &values); + /// Returns the MLIR context. MLIRContext &getContext(); - /// Returns the LLVM dialect.
LLVM::LLVMDialect *getDialect() { return llvmDialect; } @@ -179,6 +193,9 @@ // runtime rank and a pointer to the static ranked memref desc Type convertUnrankedMemRefType(UnrankedMemRefType type); + /// Convert a memref type to a bare pointer to the memref element type. + Type convertMemRefToBarePtr(BaseMemRefType type); + // Convert a 1D vector type into an LLVM vector type. Type convertVectorType(VectorType type); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -80,37 +80,12 @@ return success(); } -/// Convert a MemRef type to a bare pointer to the MemRef element type. -static Type convertMemRefTypeToBarePtr(LLVMTypeConverter &converter, - MemRefType type) { - int64_t offset; - SmallVector strides; - if (failed(getStridesAndOffset(type, strides, offset))) - return {}; - - LLVM::LLVMType elementType = - unwrap(converter.convertType(type.getElementType())); - if (!elementType) - return {}; - return elementType.getPointerTo(type.getMemorySpace()); -} - /// Callback to convert function argument types. It converts MemRef function /// arguments to bare pointers to the MemRef element type. LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { - // TODO: Add support for unranked memref. - if (auto memrefTy = type.dyn_cast()) { - auto llvmTy = convertMemRefTypeToBarePtr(converter, memrefTy); - if (!llvmTy) - return failure(); - - result.push_back(llvmTy); - return success(); - } - - auto llvmTy = converter.convertType(type); + auto llvmTy = converter.convertCallingConventionType(type); if (!llvmTy) return failure(); @@ -272,14 +247,14 @@ // Function has one VoidType result. If MLIR Function has more than one result, // they are into an LLVM StructType in their order of appearance.
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( - FunctionType type, bool isVariadic, + FunctionType funcTy, bool isVariadic, LLVMTypeConverter::SignatureConversion &result) { // Select the argument converter depending on the calling convetion. auto funcArgConverter = options.useBarePtrCallConv ? barePtrFuncArgTypeConverter : structFuncArgTypeConverter; // Convert argument types one by one and check for errors. - for (auto &en : llvm::enumerate(type.getInputs())) { + for (auto &en : llvm::enumerate(funcTy.getInputs())) { Type type = en.value(); SmallVector converted; if (failed(funcArgConverter(*this, type, converted))) @@ -296,9 +271,9 @@ // if it returns on element, convert it, otherwise pack the result types into // a struct. LLVM::LLVMType resultType = - type.getNumResults() == 0 + funcTy.getNumResults() == 0 ? LLVM::LLVMType::getVoidTy(&getContext()) - : unwrap(packFunctionResults(type.getResults())); + : unwrap(packFunctionResults(funcTy.getResults())); if (!resultType) return {}; return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic); @@ -394,6 +369,36 @@ return LLVM::LLVMType::getStructTy(rankTy, ptrTy); } +/// Convert a memref type to a bare pointer to the memref element type. +Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) { + if (type.isa()) + // Unranked memref is not supported in the bare pointer calling convention. + return {}; + + // Check that the memref has static shape, strides and offset. Otherwise, it + // cannot be lowered to a bare pointer.
+ auto memrefTy = type.cast(); + if (!memrefTy.hasStaticShape()) + return {}; + + int64_t offset = 0; + SmallVector strides; + if (failed(getStridesAndOffset(memrefTy, strides, offset))) + return {}; + + for (int64_t stride : strides) + if (ShapedType::isDynamicStrideOrOffset(stride)) + return {}; + + if (ShapedType::isDynamicStrideOrOffset(offset)) + return {}; + + LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); + if (!elementType) + return {}; + return elementType.getPointerTo(type.getMemorySpace()); +} + // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when // n > 1. // For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and @@ -410,6 +415,37 @@ return vectorType; } +/// Convert a type in the context of the default or bare pointer calling +/// convention. Calling convention sensitive types, such as MemRefType and +/// UnrankedMemRefType, are converted following the specific rules for the +/// calling convention. Calling convention independent types are converted +/// following the default LLVM type conversions. +Type LLVMTypeConverter::convertCallingConventionType(Type type) { + if (options.useBarePtrCallConv) + if (auto memrefTy = type.dyn_cast()) + return convertMemRefToBarePtr(memrefTy); + + return convertType(type); +} + +/// Promote the bare pointers in 'values' that resulted from memrefs to +/// descriptors. 'stdTypes' holds the types of 'values' before the conversion +/// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type).
+void LLVMTypeConverter::promoteBarePtrsToDescriptors( + ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, + SmallVectorImpl &values) { + assert(stdTypes.size() == values.size() && + "The number of types and values doesn't match"); + for (unsigned i = 0, end = values.size(); i < end; ++i) { + Type stdTy = stdTypes[i]; + // Only memrefs were lowered to bare pointers. Values of any other type + // are already in their final form and are left untouched. + if (auto memrefTy = stdTy.dyn_cast()) + values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, + memrefTy, values[i]); + } +} ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, @@ -1088,18 +1124,6 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern { protected: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using UnsignedTypePair = std::pair; - - // Gather the positions and types of memref-typed arguments in a given - // FunctionType. - void getMemRefArgIndicesAndTypes( - FunctionType type, SmallVectorImpl &argsInfo) const { - argsInfo.reserve(type.getNumInputs()); - for (auto en : llvm::enumerate(type.getInputs())) { - if (en.value().isa()) - argsInfo.push_back({en.index(), en.value()}); - } - } // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided // to this legalization pattern. @@ -1192,11 +1216,10 @@ ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); - // Store the positions and type of memref-typed arguments so that we can - // promote them to MemRef descriptor structs at the beginning of the - // function. - SmallVector promotedArgsInfo; - getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo); + // Store the types of all the arguments before the conversion so that the + // memref-typed ones can be promoted to descriptors at the function entry.
+ SmallVector oldArgTypes = + llvm::to_vector<8>(funcOp.getType().getInputs()); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) @@ -1206,27 +1229,42 @@ return success(); } - // Promote bare pointers from MemRef arguments to a MemRef descriptor struct - // at the beginning of the function so that all the MemRefs in the function - // have a uniform representation. - Block *firstBlock = &newFuncOp.getBody().front(); - rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); - auto funcLoc = funcOp.getLoc(); - for (const auto &argInfo : promotedArgsInfo) { - // TODO: Add support for unranked MemRefs. - if (auto memrefType = argInfo.second.dyn_cast()) { - // Replace argument with a placeholder (undef), promote argument to a - // MemRef descriptor and replace placeholder with the last instruction - // of the MemRef descriptor. The placeholder is needed to avoid - // replacing argument uses in the MemRef descriptor instructions. - BlockArgument arg = firstBlock->getArgument(argInfo.first); - Value placeHolder = - rewriter.create(funcLoc, arg.getType()); - rewriter.replaceUsesOfBlockArgument(arg, placeHolder); - auto desc = MemRefDescriptor::fromStaticShape( - rewriter, funcLoc, typeConverter, memrefType, arg); - rewriter.replaceOp(placeHolder.getDefiningOp(), {desc}); - } + // Promote bare pointers from memref arguments to memref descriptors at the + // beginning of the function so that all the memrefs in the function have a + // uniform representation.
+ Block *entryBlock = &newFuncOp.getBody().front(); + auto blockArgs = entryBlock->getArguments(); + assert(blockArgs.size() == oldArgTypes.size() && + "The number of arguments and types doesn't match"); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(entryBlock); + for (auto it : llvm::zip(blockArgs, oldArgTypes)) { + BlockArgument arg = std::get<0>(it); + Type argTy = std::get<1>(it); + + // Unranked memrefs are not supported in the bare pointer calling + // convention. We should have bailed out before in the presence of + // unranked memrefs. + assert(!argTy.isa() && + "Unranked memref is not supported"); + auto memrefTy = argTy.dyn_cast(); + if (!memrefTy) + continue; + + // Replace barePtr with a placeholder (undef), promote barePtr to a ranked + // memref descriptor and replace the placeholder with the last instruction + // of the memref descriptor. + // TODO: The placeholder is needed to avoid replacing barePtr uses in the + // MemRef descriptor instructions. We may want to have a utility in the + // rewriter to properly handle this use case. + Location loc = op->getLoc(); + auto placeholder = rewriter.create(loc, memrefTy); + rewriter.replaceUsesOfBlockArgument(arg, placeholder); + + Value desc = MemRefDescriptor::fromStaticShape( + rewriter, loc, typeConverter, memrefTy, arg); + rewriter.replaceOp(placeholder, {desc}); } rewriter.eraseOp(op); @@ -2138,12 +2176,22 @@ rewriter.getI64ArrayAttr(i))); } } - if (failed(copyUnrankedDescriptors( - rewriter, op->getLoc(), this->typeConverter, op->getResultTypes(), - results, /*toDynamic=*/false))) + + if (this->typeConverter.getOptions().useBarePtrCallConv) { + // For the bare-ptr calling convention, promote memref results to + // descriptors.
+ assert(results.size() == resultTypes.size() && + "The number of results and types doesn't match"); + this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(), + resultTypes, results); + } else if (failed(copyUnrankedDescriptors(rewriter, op->getLoc(), + this->typeConverter, resultTypes, + results, + /*toDynamic=*/false))) { return failure(); - rewriter.replaceOp(op, results); + } + rewriter.replaceOp(op, results); return success(); } }; @@ -2706,11 +2754,32 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); unsigned numArguments = op->getNumOperands(); - auto updatedOperands = llvm::to_vector<4>(operands); - copyUnrankedDescriptors(rewriter, op->getLoc(), typeConverter, - op->getOperands().getTypes(), updatedOperands, - /*toDynamic=*/true); + SmallVector updatedOperands; + + if (typeConverter.getOptions().useBarePtrCallConv) { + // For the bare-ptr calling convention, extract the aligned pointer to + // be returned from the memref descriptor. + for (auto it : llvm::zip(op->getOperands(), operands)) { + Type oldTy = std::get<0>(it).getType(); + Value newOperand = std::get<1>(it); + if (oldTy.isa()) { + MemRefDescriptor memrefDesc(newOperand); + newOperand = memrefDesc.alignedPtr(rewriter, loc); + } else if (oldTy.isa()) { + // Unranked memref is not supported in the bare pointer calling + // convention. + return failure(); + } + updatedOperands.push_back(newOperand); + } + } else { + updatedOperands = llvm::to_vector<4>(operands); + copyUnrankedDescriptors(rewriter, loc, typeConverter, + op->getOperands().getTypes(), updatedOperands, + /*toDynamic=*/true); + } // If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments == 0) { @@ -2729,10 +2798,10 @@ auto packedType = typeConverter.packFunctionResults( llvm::to_vector<4>(op->getOperandTypes())); - Value packed = rewriter.create(op->getLoc(), packedType); + Value packed = rewriter.create(loc, packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( - op->getLoc(), packedType, packed, updatedOperands[i], + loc, packedType, packed, updatedOperands[i], rewriter.getI64ArrayAttr(i)); } rewriter.replaceOpWithNewOp(op, TypeRange(), packed, @@ -3380,17 +3449,21 @@ populateStdToLLVMMemoryConversionPatterns(converter, patterns); } -// Create an LLVM IR structure type if there is more than one result. +/// Convert a non-empty list of types to be returned from a function into a +/// supported LLVM IR type. In particular, if more than one value is returned, +/// create an LLVM IR structure type with elements that correspond to each of +/// the MLIR types converted with `convertCallingConventionType`. Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { assert(!types.empty() && "expected non-empty list of type"); if (types.size() == 1) - return convertType(types.front()); + return convertCallingConventionType(types.front()); SmallVector resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { - auto converted = convertType(t).dyn_cast_or_null(); + auto converted = + convertCallingConventionType(t).dyn_cast_or_null(); if (!converted) return {}; resultTypes.push_back(converted); @@ -3426,16 +3499,27 @@ auto operand = std::get<0>(it); auto llvmOperand = std::get<1>(it); - if (operand.getType().isa()) { - UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, - promotedOperands); - continue; - } - if (auto memrefType = operand.getType().dyn_cast()) { - MemRefDescriptor::unpack(builder, loc, llvmOperand, - operand.getType().cast(), - promotedOperands); - continue; + if (options.useBarePtrCallConv) { + // For the bare-ptr calling convention, we only have to extract the + // aligned pointer of a
memref. + if (auto memrefType = operand.getType().dyn_cast()) { + MemRefDescriptor desc(llvmOperand); + llvmOperand = desc.alignedPtr(builder, loc); + } else if (operand.getType().isa()) { + llvm_unreachable("Unranked memrefs are not supported"); + } + } else { + if (operand.getType().isa()) { + UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, + promotedOperands); + continue; + } + if (auto memrefType = operand.getType().dyn_cast()) { + MemRefDescriptor::unpack(builder, loc, llvmOperand, + operand.getType().cast(), + promotedOperands); + continue; + } } promotedOperands.push_back(llvmOperand); diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -14,13 +14,13 @@ // CHECK-COUNT-5: !llvm.i64 // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-LABEL: func @check_static_return -// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.ptr { func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> { // CHECK: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue
%[[arg]], %[[base0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 @@ -31,7 +31,8 @@ // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr return %static : memref<32x18xf32> } @@ -42,13 +43,13 @@ // CHECK-COUNT-5: !llvm.i64 // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-LABEL: func @check_static_return_with_offset -// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.ptr { func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> { // CHECK: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm.struct<(ptr,
ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 @@ -59,14 +60,15 @@ // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr return %static : memref<32x18xf32, offset:7, strides:[22,1]> } // ----- // CHECK-LABEL: func @zero_d_alloc() -> !llvm.struct<(ptr, ptr, i64)> { -// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm.struct<(ptr, ptr, i64)> { +// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm.ptr { func @zero_d_alloc() -> memref { // CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr @@ -174,7 +176,7 @@ // ----- // CHECK-LABEL: func @static_alloc() -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { -// BAREPTR-LABEL: func @static_alloc() -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// BAREPTR-LABEL: func @static_alloc() -> !llvm.ptr { func
@static_alloc() -> memref<32x18xf32> { // CHECK: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 @@ -388,3 +390,29 @@ %4 = dim %static, %c4 : memref<42x32x15x13x27xf32> return } + +// ----- + +// BAREPTR: llvm.func @foo(!llvm.ptr) -> !llvm.ptr +func @foo(memref<10xi8>) -> memref<20xi8> + +// BAREPTR-LABEL: func @check_memref_func_call +// BAREPTR-SAME: %[[in:.*]]: !llvm.ptr) -> !llvm.ptr +func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> { + // BAREPTR: %[[inDesc:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] + // BAREPTR-NEXT: %[[barePtr:.*]] = llvm.extractvalue %[[inDesc]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // BAREPTR-NEXT: %[[call:.*]] = llvm.call @foo(%[[barePtr]]) : (!llvm.ptr) -> !llvm.ptr + // BAREPTR-NEXT: %[[desc0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // BAREPTR-NEXT: %[[desc1:.*]] = llvm.insertvalue %[[call]], %[[desc0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // BAREPTR-NEXT: %[[desc2:.*]] = llvm.insertvalue %[[call]], %[[desc1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 + // BAREPTR-NEXT: %[[desc4:.*]] = llvm.insertvalue %[[c0]], %[[desc2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // BAREPTR-NEXT: %[[c20:.*]] = llvm.mlir.constant(20 : index) : !llvm.i64 + // BAREPTR-NEXT: %[[desc6:.*]] = llvm.insertvalue %[[c20]], %[[desc4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // BAREPTR-NEXT: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // BAREPTR-NEXT: %[[outDesc:.*]] = llvm.insertvalue %[[c1]], %[[desc6]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %res = call @foo(%in) : (memref<10xi8>) -> (memref<20xi8>) + // BAREPTR-NEXT: %[[res:.*]] = llvm.extractvalue
%[[outDesc]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // BAREPTR-NEXT: llvm.return %[[res]] : !llvm.ptr + return %res : memref<20xi8> +}