diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -27,6 +27,7 @@ namespace mlir { +class BaseMemRefType; class ComplexType; class LLVMTypeConverter; class UnrankedMemRefType; @@ -74,15 +75,38 @@ SignatureConversion &result); /// Convert a non-empty list of types to be returned from a function into a - /// supported LLVM IR type. In particular, if more than one values is + /// supported LLVM IR type. In particular, if more than one value is /// returned, create an LLVM IR structure type with elements that correspond - /// to each of the MLIR types converted with `convertType`. + /// to each of the MLIR types converted with `convertCallingConventionType`. Type packFunctionResults(ArrayRef types); + /// Convert a type in the context of the default or bare pointer calling + /// convention. Calling convention sensitive types, such as MemRefType and + /// UnrankedMemRefType, are converted following the specific rules for the + /// calling convention. Calling convention independent types are converted + /// following the default LLVM type conversions. + Type convertCallingConventionType(Type type); + + /// Promote the bare pointer resulting from an UnrankedMemRefType to an + /// UnrankedMemRefDescriptor struct. The rank of the descriptor is + /// initialized to one since the bare pointer calling convention does not + /// preserve dynamic shape information. + Value promoteBarePtrToUnrankedMemRefDesc(Value barePtr, + UnrankedMemRefType memrefTy, + Location loc, + ConversionPatternRewriter &rewriter); + + /// Promote the bare pointers in 'values' that resulted from ranked and + /// unranked memrefs to their corresponding descriptors. 'stdTypes' holds + /// the types of 'values' before the conversion to the LLVM-IR dialect (i.e., + /// MemRefType, UnrankedMemRefType or any other Standard type).
+ void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, + Location loc, ArrayRef stdTypes, + SmallVectorImpl &values); + /// Returns the MLIR context. MLIRContext &getContext(); - /// Returns the LLVM dialect. LLVM::LLVMDialect *getDialect() { return llvmDialect; } @@ -179,6 +203,11 @@ // runtime rank and a pointer to the static ranked memref desc Type convertUnrankedMemRefType(UnrankedMemRefType type); + /// Convert a ranked or unranked memref type to a bare pointer to the memref + /// element type. Returns a null type if the memref cannot be lowered to a + /// bare pointer. + Type convertMemRefToBarePtr(BaseMemRefType type); + // Convert a 1D vector type into an LLVM vector type. Type convertVectorType(VectorType type); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -52,6 +52,14 @@ return wrappedLLVMType; } +// Creates a constant Op producing a value of `resultType` from an index-typed +// integer attribute. +static Value createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { + return builder.create( + loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); +} + /// Callback to convert function argument types. It converts a MemRef function /// argument to a list of non-aggregate types containing descriptor /// information, and an UnrankedmemRef function argument to a list containing @@ -80,37 +88,12 @@ return success(); } -/// Convert a MemRef type to a bare pointer to the MemRef element type.
-static Type convertMemRefTypeToBarePtr(LLVMTypeConverter &converter, - MemRefType type) { - int64_t offset; - SmallVector strides; - if (failed(getStridesAndOffset(type, strides, offset))) - return {}; - - LLVM::LLVMType elementType = - unwrap(converter.convertType(type.getElementType())); - if (!elementType) - return {}; - return elementType.getPointerTo(type.getMemorySpace()); -} - /// Callback to convert function argument types. It converts MemRef function /// arguments to bare pointers to the MemRef element type. LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { - // TODO: Add support for unranked memref. - if (auto memrefTy = type.dyn_cast()) { - auto llvmTy = convertMemRefTypeToBarePtr(converter, memrefTy); - if (!llvmTy) - return failure(); - - result.push_back(llvmTy); - return success(); - } - - auto llvmTy = converter.convertType(type); + auto llvmTy = converter.convertCallingConventionType(type); if (!llvmTy) return failure(); @@ -272,14 +255,14 @@ // Function has one VoidType result. If MLIR Function has more than one result, // they are into an LLVM StructType in their order of appearance. LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( - FunctionType type, bool isVariadic, + FunctionType funcTy, bool isVariadic, LLVMTypeConverter::SignatureConversion &result) { // Select the argument converter depending on the calling convetion. auto funcArgConverter = options.useBarePtrCallConv ? barePtrFuncArgTypeConverter : structFuncArgTypeConverter; // Convert argument types one by one and check for errors. - for (auto &en : llvm::enumerate(type.getInputs())) { + for (auto &en : llvm::enumerate(funcTy.getInputs())) { Type type = en.value(); SmallVector converted; if (failed(funcArgConverter(*this, type, converted))) @@ -296,9 +279,9 @@ // if it returns on element, convert it, otherwise pack the result types into // a struct.
LLVM::LLVMType resultType = - type.getNumResults() == 0 + funcTy.getNumResults() == 0 ? LLVM::LLVMType::getVoidTy(&getContext()) - : unwrap(packFunctionResults(type.getResults())); + : unwrap(packFunctionResults(funcTy.getResults())); if (!resultType) return {}; return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic); @@ -394,6 +377,29 @@ return LLVM::LLVMType::getStructTy(rankTy, ptrTy); } +/// Convert a ranked or unranked memref type to a bare pointer to the memref +/// element type. Returns a null type for a ranked memref whose shape, strides +/// or offset are not static: the bare pointer carries no size or layout +/// information, so the descriptor must be rebuilt from the type alone. +Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) { + if (auto memrefTy = type.dyn_cast()) { + if (!memrefTy.hasStaticShape()) + return {}; + int64_t offset; + SmallVector strides; + if (failed(getStridesAndOffset(memrefTy, strides, offset))) + return {}; + if (offset == MemRefType::getDynamicStrideOrOffset() || + llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) + return {}; + } + + LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); + if (!elementType) + return {}; + return elementType.getPointerTo(type.getMemorySpace()); +} + // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when // n > 1. // For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and @@ -410,6 +411,54 @@ return vectorType; } +/// Convert a type in the context of the default or bare pointer calling +/// convention. Calling convention sensitive types, such as MemRefType and +/// UnrankedMemRefType, are converted following the specific rules for the +/// calling convention. Calling convention independent types are converted +/// following the default LLVM type conversions. +Type LLVMTypeConverter::convertCallingConventionType(Type type) { + if (options.useBarePtrCallConv) + if (auto memrefTy = type.dyn_cast()) + return convertMemRefToBarePtr(memrefTy); + + return convertType(type); +} + +/// Promote the bare pointer resulting from an UnrankedMemRefType to an +/// UnrankedMemRefDescriptor struct.
The rank of the descriptor is +/// initialized to one since the bare-ptr calling convention doesn't +/// preserve dynamic shape information. +Value LLVMTypeConverter::promoteBarePtrToUnrankedMemRefDesc( + Value barePtr, UnrankedMemRefType memrefTy, Location loc, + ConversionPatternRewriter &rewriter) { + // NOTE(review): an unranked descriptor normally pairs the rank with a + // pointer to a ranked memref descriptor; here the bare data pointer is + // stored in that slot instead. Confirm the pointer types agree (or bitcast). + auto oneRank = createIndexAttrConstant(rewriter, loc, getIndexType(), 1); + return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, memrefTy, + {oneRank, barePtr}); +} + +/// Promote the bare pointers in 'values' that resulted from ranked and unranked +/// memrefs to their corresponding descriptors. 'stdTypes' holds 'values' types +/// before the conversion to the LLVM-IR dialect (i.e., MemRefType, +/// UnrankedMemRefType or any other Standard type). +void LLVMTypeConverter::promoteBarePtrsToDescriptors( + ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, + SmallVectorImpl &values) { + assert(stdTypes.size() == values.size() && + "The number of types and values doesn't match"); + for (unsigned i = 0, end = values.size(); i < end; ++i) { + Type stdTy = stdTypes[i]; + if (auto memrefTy = stdTy.dyn_cast()) + values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, + memrefTy, values[i]); + else if (auto uMemrefTy = stdTy.dyn_cast()) + values[i] = promoteBarePtrToUnrankedMemRefDesc(values[i], uMemrefTy, loc, + rewriter); + } +} + ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, @@ -547,14 +593,6 @@ setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); } -// Creates a constant Op producing a value of `resultType` from an index-typed -// integer attribute. -static Value createIndexAttrConstant(OpBuilder &builder, Location loc, - Type resultType, int64_t value) { - return builder.create( - loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); -} - /// Builds IR extracting the offset from the descriptor.
Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { return builder.create( @@ -1088,18 +1126,6 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern { protected: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using UnsignedTypePair = std::pair; - - // Gather the positions and types of memref-typed arguments in a given - // FunctionType. - void getMemRefArgIndicesAndTypes( - FunctionType type, SmallVectorImpl &argsInfo) const { - argsInfo.reserve(type.getNumInputs()); - for (auto en : llvm::enumerate(type.getInputs())) { - if (en.value().isa()) - argsInfo.push_back({en.index(), en.value()}); - } - } // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided // to this legalization pattern. @@ -1192,11 +1218,10 @@ ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); - // Store the positions and type of memref-typed arguments so that we can - // promote them to MemRef descriptor structs at the beginning of the - // function. - SmallVector promotedArgsInfo; - getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo); + // Store the original types of all the arguments before the conversion so + // that the memref-typed ones can be promoted to memref descriptors at the beginning of the function. + SmallVector oldArgTypes = + llvm::to_vector<8>(funcOp.getType().getInputs()); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) @@ -1206,27 +1231,44 @@ return success(); } - // Promote bare pointers from MemRef arguments to a MemRef descriptor struct - // at the beginning of the function so that all the MemRefs in the function - // have a uniform representation. - Block *firstBlock = &newFuncOp.getBody().front(); - rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); - auto funcLoc = funcOp.getLoc(); - for (const auto &argInfo : promotedArgsInfo) { - // TODO: Add support for unranked MemRefs.
- if (auto memrefType = argInfo.second.dyn_cast()) { - // Replace argument with a placeholder (undef), promote argument to a - // MemRef descriptor and replace placeholder with the last instruction - // of the MemRef descriptor. The placeholder is needed to avoid - // replacing argument uses in the MemRef descriptor instructions. - BlockArgument arg = firstBlock->getArgument(argInfo.first); - Value placeHolder = - rewriter.create(funcLoc, arg.getType()); - rewriter.replaceUsesOfBlockArgument(arg, placeHolder); - auto desc = MemRefDescriptor::fromStaticShape( - rewriter, funcLoc, typeConverter, memrefType, arg); - rewriter.replaceOp(placeHolder.getDefiningOp(), {desc}); + // Promote bare pointers from ranked and unranked memref arguments to the + // corresponding memref descriptors at the beginning of the function so that all the + // memrefs in the function have a uniform representation; arguments of any other type are left untouched. + Block *entryBlock = &newFuncOp.getBody().front(); + auto blockArgs = entryBlock->getArguments(); + assert(blockArgs.size() == oldArgTypes.size() && + "The number of arguments and types doesn't match"); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(entryBlock); + for (auto it : llvm::zip(blockArgs, oldArgTypes)) { + BlockArgument arg = std::get<0>(it); + Type argTy = std::get<1>(it); + + if (!argTy.isa()) + continue; + + // Replace barePtr with a placeholder (undef), promote barePtr to a ranked + // or unranked memref descriptor and replace placeholder with the last + // instruction of the MemRef descriptor. + // TODO: The placeholder is needed to avoid replacing barePtr uses in the + // MemRef descriptor instructions. We may want to have a utility in the + // rewriter to properly handle this use case.
+ Location loc = op->getLoc(); + auto placeholder = rewriter.create(loc, argTy); + rewriter.replaceUsesOfBlockArgument(arg, placeholder); + + Value desc; + if (auto memrefTy = argTy.dyn_cast()) { + desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter, + memrefTy, arg); + } else { + auto unrankedMemrefTy = argTy.cast(); + desc = typeConverter.promoteBarePtrToUnrankedMemRefDesc( + arg, unrankedMemrefTy, loc, rewriter); } + + rewriter.replaceOp(placeholder, {desc}); } rewriter.eraseOp(op); @@ -2138,12 +2180,22 @@ rewriter.getI64ArrayAttr(i))); } } - if (failed(copyUnrankedDescriptors( - rewriter, op->getLoc(), this->typeConverter, op->getResultTypes(), - results, /*toDynamic=*/false))) - return failure(); - rewriter.replaceOp(op, results); + if (this->typeConverter.getOptions().useBarePtrCallConv) { + // For the bare-ptr calling convention, promote ranked and unranked + // memref results to their respective descriptors. + assert(results.size() == resultTypes.size() && + "The number of results and types doesn't match"); + this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(), + resultTypes, results); + } else { + if (failed(copyUnrankedDescriptors( + rewriter, op->getLoc(), this->typeConverter, resultTypes, results, + /*toDynamic=*/false))) + return failure(); + } + + rewriter.replaceOp(op, results); return success(); } }; @@ -2706,11 +2758,32 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); unsigned numArguments = op->getNumOperands(); - auto updatedOperands = llvm::to_vector<4>(operands); - copyUnrankedDescriptors(rewriter, op->getLoc(), typeConverter, - op->getOperands().getTypes(), updatedOperands, - /*toDynamic=*/true); + SmallVector updatedOperands; + + if (typeConverter.getOptions().useBarePtrCallConv) { + // For the bare-ptr calling convention, extract the aligned pointers + // from the memrefs to be returned. The caller re-wraps each ranked one + // with MemRefDescriptor::fromStaticShape, which stores it as the aligned pointer.
+ for (auto it : llvm::zip(op->getOperands(), operands)) { + Type oldTy = std::get<0>(it).getType(); + Value newOperand = std::get<1>(it); + if (oldTy.isa()) { + MemRefDescriptor memrefDesc(newOperand); + newOperand = memrefDesc.alignedPtr(rewriter, loc); + } else if (oldTy.isa()) { + UnrankedMemRefDescriptor unrankedMemrefDesc(newOperand); + newOperand = unrankedMemrefDesc.memRefDescPtr(rewriter, loc); + } + updatedOperands.push_back(newOperand); + } + } else { + updatedOperands = llvm::to_vector<4>(operands); + copyUnrankedDescriptors(rewriter, loc, typeConverter, + op->getOperands().getTypes(), updatedOperands, + /*toDynamic=*/true); + } // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { @@ -2729,10 +2801,10 @@ auto packedType = typeConverter.packFunctionResults( llvm::to_vector<4>(op->getOperandTypes())); - Value packed = rewriter.create(op->getLoc(), packedType); + Value packed = rewriter.create(loc, packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( - op->getLoc(), packedType, packed, updatedOperands[i], + loc, packedType, packed, updatedOperands[i], rewriter.getI64ArrayAttr(i)); } rewriter.replaceOpWithNewOp(op, ArrayRef(), packed, @@ -3380,17 +3452,20 @@ populateStdToLLVMMemoryConversionPatterns(converter, patterns); } -// Create an LLVM IR structure type if there is more than one result. +/// Convert a non-empty list of types to be returned from a function into a +/// supported LLVM IR type. In particular, if more than one value is returned, +/// create an LLVM IR structure type with elements that correspond to each of +/// the MLIR types converted with `convertCallingConventionType`.
Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { assert(!types.empty() && "expected non-empty list of type"); if (types.size() == 1) - return convertType(types.front()); + return convertCallingConventionType(types.front()); SmallVector resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { - auto converted = convertType(t).dyn_cast(); + auto converted = convertCallingConventionType(t).dyn_cast(); if (!converted) return {}; resultTypes.push_back(converted); @@ -3426,16 +3501,28 @@ auto operand = std::get<0>(it); auto llvmOperand = std::get<1>(it); - if (operand.getType().isa()) { - UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, - promotedOperands); - continue; - } - if (auto memrefType = operand.getType().dyn_cast()) { - MemRefDescriptor::unpack(builder, loc, llvmOperand, - operand.getType().cast(), - promotedOperands); - continue; + if (options.useBarePtrCallConv) { + // For the bare-ptr calling convention, we only have to extract the + // aligned pointer of a memref.
+ if (operand.getType().isa()) { + UnrankedMemRefDescriptor desc(llvmOperand); + llvmOperand = desc.memRefDescPtr(builder, loc); + } else if (auto memrefType = operand.getType().dyn_cast()) { + MemRefDescriptor desc(llvmOperand); + llvmOperand = desc.alignedPtr(builder, loc); + } + } else { + if (operand.getType().isa()) { + UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, + promotedOperands); + continue; + } + if (auto memrefType = operand.getType().dyn_cast()) { + MemRefDescriptor::unpack(builder, loc, llvmOperand, + operand.getType().cast(), + promotedOperands); + continue; + } } promotedOperands.push_back(llvmOperand); diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -14,13 +14,13 @@ // CHECK-COUNT-5: !llvm.i64 // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-LABEL: func @check_static_return -// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.ptr { func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> { // CHECK: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT:
%[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 @@ -31,7 +31,8 @@ // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[aligned1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: llvm.return %[[aligned1]] : !llvm.ptr return %static : memref<32x18xf32> } @@ -42,13 +43,13 @@ // CHECK-COUNT-5: !llvm.i64 // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-LABEL: func @check_static_return_with_offset -// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.ptr { func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> { // CHECK: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]],
%[[base]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 @@ -59,14 +60,15 @@ // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[aligned1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: llvm.return %[[aligned1]] : !llvm.ptr return %static : memref<32x18xf32, offset:7, strides:[22,1]> } // ----- // CHECK-LABEL: func @zero_d_alloc() -> !llvm.struct<(ptr, ptr, i64)> { -// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm.struct<(ptr, ptr, i64)> { +// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm.ptr { func @zero_d_alloc() -> memref { // CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr @@ -174,7 +176,7 @@ // ----- // CHECK-LABEL: func @static_alloc() -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { -// BAREPTR-LABEL: func @static_alloc() -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// BAREPTR-LABEL: func
@static_alloc() -> !llvm.ptr { func @static_alloc() -> memref<32x18xf32> { // CHECK: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 @@ -388,3 +390,47 @@ %4 = dim %static, %c4 : memref<42x32x15x13x27xf32> return } +// Unranked memref arguments are passed as bare pointers and re-wrapped into an unranked descriptor whose rank is the constant 1. +// BAREPTR-LABEL: func @check_unranked_memref_args +// BAREPTR-SAME: %[[in:.*]]: !llvm.ptr) +func @check_unranked_memref_args(%in : memref<*xi8>) { + // BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // BAREPTR-NEXT: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> + // BAREPTR-NEXT: %[[rank:.*]] = llvm.insertvalue %[[one]], %[[undef]][0] : !llvm.struct<(i64, ptr)> + // BAREPTR-NEXT: %[[desc:.*]] = llvm.insertvalue %[[in]], %[[rank]][1] : !llvm.struct<(i64, ptr)> + return +} + +// ----- + +// BAREPTR-LABEL: func @check_unranked_memref_return +// BAREPTR-SAME: %[[in:.*]]: !llvm.ptr) -> !llvm.ptr +func @check_unranked_memref_return(%in : memref<*xi8>) -> memref<*xi8> { + // BAREPTR: llvm.insertvalue + // BAREPTR-NEXT: %[[desc:.*]] = llvm.insertvalue + // BAREPTR-NEXT: %[[barePtr:.*]] = llvm.extractvalue %[[desc]][1] : !llvm.struct<(i64, ptr)> + // BAREPTR-NEXT: llvm.return %[[barePtr]] : !llvm.ptr + return %in : memref<*xi8> +} + +// ----- + +// BAREPTR: llvm.func @foo(!llvm.ptr) -> !llvm.ptr +func @foo(memref<*xi8>) -> memref<*xi8> + +// BAREPTR-LABEL: func @check_unranked_memref_func_call +// BAREPTR-SAME: %[[in:.*]]: !llvm.ptr) -> !llvm.ptr +func @check_unranked_memref_func_call(%in : memref<*xi8>) -> memref<*xi8> { + // BAREPTR: llvm.insertvalue + // BAREPTR-NEXT: %[[inDesc:.*]] = llvm.insertvalue + // BAREPTR-NEXT: %[[barePtr:.*]] = llvm.extractvalue %[[inDesc]][1] : !llvm.struct<(i64, ptr)> + // BAREPTR-NEXT: %[[call:.*]] = llvm.call @foo(%[[barePtr]]) : (!llvm.ptr) -> !llvm.ptr + // BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // BAREPTR-NEXT: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> + //
BAREPTR-NEXT: %[[rank:.*]] = llvm.insertvalue %[[one]], %[[undef]][0] : !llvm.struct<(i64, ptr)> + // BAREPTR-NEXT: %[[outDesc:.*]] = llvm.insertvalue %[[call]], %[[rank]][1] : !llvm.struct<(i64, ptr)> + %res = call @foo(%in) : (memref<*xi8>) -> (memref<*xi8>) + // BAREPTR-NEXT: %[[res:.*]] = llvm.extractvalue %[[outDesc]][1] : !llvm.struct<(i64, ptr)> + // BAREPTR-NEXT: llvm.return %[[res]] : !llvm.ptr + return %res : memref<*xi8> +}