diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -53,20 +53,23 @@
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
   Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
+                                bool useBarePtrCallConv,
                                 SignatureConversion &result);
 
   /// Convert a non-empty list of types to be returned from a function into a
   /// supported LLVM IR type. In particular, if more than one value is
   /// returned, create an LLVM IR structure type with elements that correspond
   /// to each of the MLIR types converted with `convertType`.
-  Type packFunctionResults(TypeRange types);
+  Type packFunctionResults(TypeRange types,
+                           bool useBarePointerCallConv = false);
 
   /// Convert a type in the context of the default or bare pointer calling
   /// convention. Calling convention sensitive types, such as MemRefType and
   /// UnrankedMemRefType, are converted following the specific rules for the
   /// calling convention. Calling convention independent types are converted
   /// following the default LLVM type conversions.
-  Type convertCallingConventionType(Type type);
+  Type convertCallingConventionType(Type type,
+                                    bool useBarePointerCallConv = false);
 
   /// Promote the bare pointers in 'values' that resulted from memrefs to
   /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
@@ -95,8 +98,8 @@
   /// of the platform-specific C/C++ ABI lowering related to struct argument
   /// passing.
   SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
-                                        ValueRange operands,
-                                        OpBuilder &builder);
+                                        ValueRange operands, OpBuilder &builder,
+                                        bool useBarePtrCallConv = false);
 
   /// Promote the LLVM struct representation of one MemRef descriptor to stack
   /// and use pointer to struct to avoid the complexity of the platform-specific
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -58,6 +58,14 @@
 static constexpr StringRef varargsAttrName = "func.varargs";
 static constexpr StringRef linkageAttrName = "llvm.linkage";
+static constexpr StringRef barePtrAttrName = "func.bareptr";
+
+/// Return `true` if the `op` should use the bare pointer calling convention.
+static bool shouldUseBarePtrCallConv(Operation *op,
+                                     LLVMTypeConverter *typeConverter) {
+  return (op && op->hasAttr(barePtrAttrName)) ||
+         typeConverter->getOptions().useBarePtrCallConv;
+}
 
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
@@ -267,6 +275,55 @@
   }
 }
 
+/// Modifies the body of the function to construct the `MemRefDescriptor` from
+/// the bare pointer calling convention lowering of `memref` types.
+static void modifyFuncOpToUseBarePtrCallingConv(
+    ConversionPatternRewriter &rewriter, Location loc,
+    LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
+    ArrayRef<Type> oldArgTypes) {
+  if (funcOp.getBody().empty())
+    return;
+
+  // Promote bare pointers from memref arguments to memref descriptors at the
+  // beginning of the function so that all the memrefs in the function have a
+  // uniform representation.
+  Block *entryBlock = &funcOp.getBody().front();
+  auto blockArgs = entryBlock->getArguments();
+  assert(blockArgs.size() == oldArgTypes.size() &&
+         "The number of arguments and types doesn't match");
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(entryBlock);
+  for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
+    BlockArgument arg = std::get<0>(it);
+    Type argTy = std::get<1>(it);
+
+    // Unranked memrefs are not supported in the bare pointer calling
+    // convention. We should have bailed out before in the presence of
+    // unranked memrefs.
+    assert(!argTy.isa<UnrankedMemRefType>() &&
+           "Unranked memref is not supported");
+    auto memrefTy = argTy.dyn_cast<MemRefType>();
+    if (!memrefTy)
+      continue;
+
+    // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
+    // or unranked memref descriptor and replace placeholder with the last
+    // instruction of the memref descriptor.
+    // TODO: The placeholder is needed to avoid replacing barePtr uses in the
+    // MemRef descriptor instructions. We may want to have a utility in the
+    // rewriter to properly handle this use case.
+    auto placeholder = rewriter.create<LLVM::UndefOp>(
+        loc, typeConverter.convertType(memrefTy));
+    rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+
+    Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
+                                                   memrefTy, arg);
+    rewriter.replaceOp(placeholder, {desc});
+  }
+}
+
 namespace {
 
 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
@@ -284,7 +341,7 @@
     TypeConverter::SignatureConversion result(funcOp.getNumArguments());
     auto llvmType = getTypeConverter()->convertFunctionSignature(
         funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
-        result);
+        shouldUseBarePtrCallConv(funcOp, getTypeConverter()), result);
     if (!llvmType)
       return nullptr;
 
@@ -415,89 +472,24 @@
     if (!newFuncOp)
       return failure();
 
-    if (funcOp->getAttrOfType<UnitAttr>(
-            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
-      if (newFuncOp.isVarArg())
-        return funcOp->emitError("C interface for variadic functions is not "
-                                 "supported yet.");
-
-      if (newFuncOp.isExternal())
-        wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
-                             funcOp, newFuncOp);
-      else
-        wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
-                               funcOp, newFuncOp);
-    }
-
-    rewriter.eraseOp(funcOp);
-    return success();
-  }
-};
-
-/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
-/// to the MemRef element type. This will impact the calling convention and ABI.
-struct BarePtrFuncOpConversion : public FuncOpConversionBase {
-  using FuncOpConversionBase::FuncOpConversionBase;
-
-  LogicalResult
-  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    // TODO: bare ptr conversion could be handled by argument materialization
-    // and most of the code below would go away. But to do this, we would need a
-    // way to distinguish between FuncOp and other regions in the
-    // addArgumentMaterialization hook.
+    if (!shouldUseBarePtrCallConv(funcOp, this->getTypeConverter())) {
+      if (funcOp->getAttrOfType<UnitAttr>(
+              LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
+        if (newFuncOp.isVarArg())
+          return funcOp->emitError("C interface for variadic functions is not "
+                                   "supported yet.");
 
-    // Store the type of memref-typed arguments before the conversion so that we
-    // can promote them to MemRef descriptor at the beginning of the function.
-    SmallVector<Type, 8> oldArgTypes =
-        llvm::to_vector<8>(funcOp.getFunctionType().getInputs());
-
-    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
-    if (!newFuncOp)
-      return failure();
-    if (newFuncOp.getBody().empty()) {
-      rewriter.eraseOp(funcOp);
-      return success();
-    }
-
-    // Promote bare pointers from memref arguments to memref descriptors at the
-    // beginning of the function so that all the memrefs in the function have a
-    // uniform representation.
-    Block *entryBlock = &newFuncOp.getBody().front();
-    auto blockArgs = entryBlock->getArguments();
-    assert(blockArgs.size() == oldArgTypes.size() &&
-           "The number of arguments and types doesn't match");
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(entryBlock);
-    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
-      BlockArgument arg = std::get<0>(it);
-      Type argTy = std::get<1>(it);
-
-      // Unranked memrefs are not supported in the bare pointer calling
-      // convention. We should have bailed out before in the presence of
-      // unranked memrefs.
-      assert(!argTy.isa<UnrankedMemRefType>() &&
-             "Unranked memref is not supported");
-      auto memrefTy = argTy.dyn_cast<MemRefType>();
-      if (!memrefTy)
-        continue;
-
-      // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
-      // or unranked memref descriptor and replace placeholder with the last
-      // instruction of the memref descriptor.
-      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
-      // MemRef descriptor instructions. We may want to have a utility in the
-      // rewriter to properly handle this use case.
-      Location loc = funcOp.getLoc();
-      auto placeholder = rewriter.create<LLVM::UndefOp>(
-          loc, getTypeConverter()->convertType(memrefTy));
-      rewriter.replaceUsesOfBlockArgument(arg, placeholder);
-
-      Value desc = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, *getTypeConverter(), memrefTy, arg);
-      rewriter.replaceOp(placeholder, {desc});
+        if (newFuncOp.isExternal())
+          wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
+                               funcOp, newFuncOp);
+        else
+          wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
+                                 funcOp, newFuncOp);
+      }
+    } else {
+      modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp.getLoc(),
+                                          *getTypeConverter(), newFuncOp,
+                                          funcOp.getFunctionType().getInputs());
     }
 
     rewriter.eraseOp(funcOp);
@@ -535,23 +527,24 @@
   using Super = CallOpInterfaceLowering<CallOpType>;
   using Base = ConvertOpToLLVMPattern<CallOpType>;
 
-  LogicalResult
-  matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewriteImpl(CallOpType callOp,
+                                    typename CallOpType::Adaptor adaptor,
+                                    ConversionPatternRewriter &rewriter,
+                                    bool useBarePtrCallConv = false) const {
     // Pack the result types into a struct.
     Type packedResult = nullptr;
     unsigned numResults = callOp.getNumResults();
     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
 
     if (numResults != 0) {
-      if (!(packedResult =
-                this->getTypeConverter()->packFunctionResults(resultTypes)))
+      if (!(packedResult = this->getTypeConverter()->packFunctionResults(
+                resultTypes, useBarePtrCallConv)))
         return failure();
     }
 
     auto promoted = this->getTypeConverter()->promoteOperands(
         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
-        adaptor.getOperands(), rewriter);
+        adaptor.getOperands(), rewriter, useBarePtrCallConv);
     auto newOp = rewriter.create<LLVM::CallOp>(
         callOp.getLoc(), packedResult ?
             TypeRange(packedResult) : TypeRange(),
         promoted, callOp->getAttrs());
@@ -570,7 +563,7 @@
       }
     }
 
-    if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
+    if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, promote memref results to
      // descriptors.
       assert(results.size() == resultTypes.size() &&
@@ -590,11 +583,28 @@
 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
   using Super::Super;
+
+  LogicalResult
+  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    bool useBarePtrCallConv = false;
+    if (Operation *callee = SymbolTable::lookupNearestSymbolFrom(
+            callOp, callOp.getCalleeAttr())) {
+      useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter());
+    }
+    return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
+  }
 };
 
 struct CallIndirectOpLowering
     : public CallOpInterfaceLowering<func::CallIndirectOp> {
   using Super::Super;
+
+  LogicalResult
+  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
+  }
 };
 
 struct UnrealizedConversionCastOpLowering
@@ -640,7 +650,10 @@
     unsigned numArguments = op.getNumOperands();
     SmallVector<Value, 4> updatedOperands;
 
-    if (getTypeConverter()->getOptions().useBarePtrCallConv) {
+    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+    bool useBarePtrCallConv =
+        shouldUseBarePtrCallConv(funcOp, getTypeConverter());
+    if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, extract the aligned pointer to
       // be returned from the memref descriptor.
       for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
@@ -692,10 +705,7 @@
 void mlir::populateFuncToLLVMFuncOpConversionPattern(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  if (converter.getOptions().useBarePtrCallConv)
-    patterns.add<BarePtrFuncOpConversion>(converter);
-  else
-    patterns.add<FuncOpConversion>(converter);
+  patterns.add<FuncOpConversion>(converter);
 }
 
 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -47,7 +47,8 @@
   TypeConverter::SignatureConversion signatureConversion(
       gpuFuncOp.front().getNumArguments());
   Type funcType = getTypeConverter()->convertFunctionSignature(
-      gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion);
+      gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
+      /*useBarePtrCallConv=*/false, signatureConversion);
 
   // Create the new function operation. Only copy those attributes that are
   // not specific to function modeling.
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -209,8 +209,8 @@
 // pointer-to-function types.
 Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   SignatureConversion conversion(type.getNumInputs());
-  Type converted =
-      convertFunctionSignature(type, /*isVariadic=*/false, conversion);
+  Type converted = convertFunctionSignature(
+      type, /*isVariadic=*/false, options.useBarePtrCallConv, conversion);
   if (!converted)
     return {};
   return getPointerType(converted);
@@ -221,12 +221,12 @@
 // Function has one VoidType result.
 // If MLIR Function has more than one result,
 // they are packed into an LLVM StructType in their order of appearance.
 Type LLVMTypeConverter::convertFunctionSignature(
-    FunctionType funcTy, bool isVariadic,
+    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
     LLVMTypeConverter::SignatureConversion &result) {
   // Select the argument converter depending on the calling convention.
-  auto funcArgConverter = options.useBarePtrCallConv
-                              ? barePtrFuncArgTypeConverter
-                              : structFuncArgTypeConverter;
+  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
+  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
+                                             : structFuncArgTypeConverter;
   // Convert argument types one by one and check for errors.
   for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
     SmallVector<Type, 8> converted;
@@ -238,9 +238,10 @@
   // If function does not return anything, create the void result type,
   // if it returns one element, convert it, otherwise pack the result types
   // into a struct.
-  Type resultType = funcTy.getNumResults() == 0
-                        ? LLVM::LLVMVoidType::get(&getContext())
-                        : packFunctionResults(funcTy.getResults());
+  Type resultType =
+      funcTy.getNumResults() == 0
+          ? LLVM::LLVMVoidType::get(&getContext())
+          : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
   if (!resultType)
     return {};
   return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
@@ -472,8 +473,9 @@
 /// UnrankedMemRefType, are converted following the specific rules for the
 /// calling convention. Calling convention independent types are converted
 /// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(Type type) {
-  if (options.useBarePtrCallConv)
+Type LLVMTypeConverter::convertCallingConventionType(Type type,
+                                                     bool useBarePtrCallConv) {
+  if (useBarePtrCallConv)
     if (auto memrefTy = type.dyn_cast<MemRefType>())
       return convertMemRefToBarePtr(memrefTy);
@@ -498,16 +500,18 @@
 /// supported LLVM IR type. In particular, if more than one value is returned,
 /// create an LLVM IR structure type with elements that correspond to each of
 /// the MLIR types converted with `convertType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types) {
+Type LLVMTypeConverter::packFunctionResults(TypeRange types,
+                                            bool useBarePtrCallConv) {
   assert(!types.empty() && "expected non-empty list of type");
 
+  useBarePtrCallConv |= options.useBarePtrCallConv;
   if (types.size() == 1)
-    return convertCallingConventionType(types.front());
+    return convertCallingConventionType(types.front(), useBarePtrCallConv);
 
   SmallVector<Type, 8> resultTypes;
   resultTypes.reserve(types.size());
   for (auto t : types) {
-    auto converted = convertCallingConventionType(t);
+    auto converted = convertCallingConventionType(t, useBarePtrCallConv);
     if (!converted || !LLVM::isCompatibleType(converted))
       return {};
     resultTypes.push_back(converted);
   }
@@ -530,17 +534,18 @@
   return allocated;
 }
 
-SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
-                                                         ValueRange opOperands,
-                                                         ValueRange operands,
-                                                         OpBuilder &builder) {
+SmallVector<Value, 4>
+LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
+                                   ValueRange operands, OpBuilder &builder,
+                                   bool useBarePtrCallConv) {
   SmallVector<Value, 4> promotedOperands;
   promotedOperands.reserve(operands.size());
+  useBarePtrCallConv |= options.useBarePtrCallConv;
   for (auto it : llvm::zip(opOperands, operands)) {
     auto operand = std::get<0>(it);
     auto llvmOperand = std::get<1>(it);
 
-    if (options.useBarePtrCallConv) {
+    if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, we only have to extract the
       // aligned pointer of a memref.
       if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
@@ -603,7 +608,8 @@
 LogicalResult
 mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
                                   SmallVectorImpl<Type> &result) {
-  auto llvmTy = converter.convertCallingConventionType(type);
+  auto llvmTy =
+      converter.convertCallingConventionType(type, /*useBarePtrCallConv=*/true);
   if (!llvmTy)
     return failure();
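
Usage sketch (not part of the patch): with the changes above, an individual `func.func` can opt into the bare-pointer convention by carrying the `func.bareptr` attribute that `shouldUseBarePtrCallConv` checks, without enabling the global `useBarePtrCallConv` option. A minimal, hypothetical C++ helper illustrating that; the function name and surrounding pass context are assumptions, not part of this change:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"

// Tag one func.func so FuncOpConversion lowers its memref arguments/results
// with the bare-pointer convention. shouldUseBarePtrCallConv() only calls
// hasAttr("func.bareptr"), so a UnitAttr is sufficient.
static void markBarePtrCallConv(mlir::func::FuncOp funcOp) {
  funcOp->setAttr("func.bareptr", mlir::UnitAttr::get(funcOp.getContext()));
}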