[mlir][Async] Allocate coroutine frames with aligned_alloc

The coroutine lowering does not account for the alignment of the coroutine
frame (https://llvm.org/PR53148), so allocate frames with aligned_alloc using
a fixed 64-byte alignment instead of malloc. Declarations of aligned_alloc and
free come from the shared LLVM::lookupOrCreate*Fn helpers instead of the
hand-rolled addCRuntimeDeclarations().

aligned_alloc requires the requested size to be an integral multiple of the
alignment (C11 7.22.3.1), and llvm.coro.size is arbitrary, so the frame size
is rounded up to a multiple of 64 before the call: (size + 63) & -64.

NOTE(review): reconstructed from a whitespace-collapsed copy whose template
arguments had been stripped; verify context-line indentation and the restored
<...> arguments against the tree before applying.

diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
@@ -227,37 +228,6 @@
               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
 }
 
-//===----------------------------------------------------------------------===//
-// Add malloc/free declarations to the module.
-//===----------------------------------------------------------------------===//
-
-static constexpr const char *kMalloc = "malloc";
-static constexpr const char *kFree = "free";
-
-static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
-                            StringRef name, Type ret, ArrayRef<Type> params) {
-  if (module.lookupSymbol(name))
-    return;
-  Type type = LLVM::LLVMFunctionType::get(ret, params);
-  builder.create<LLVM::LLVMFuncOp>(name, type);
-}
-
-/// Adds malloc/free declarations to the module.
-static void addCRuntimeDeclarations(ModuleOp module) {
-  using namespace mlir::LLVM;
-
-  MLIRContext *ctx = module.getContext();
-  auto builder =
-      ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
-
-  auto voidTy = LLVMVoidType::get(ctx);
-  auto i64 = IntegerType::get(ctx, 64);
-  auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
-
-  addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
-  addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
-}
-
 //===----------------------------------------------------------------------===//
 // Coroutine resume function wrapper.
 //===----------------------------------------------------------------------===//
@@ -365,11 +335,33 @@
     // Get coroutine frame size: @llvm.coro.size.i64.
     auto coroSize =
         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
+    // The coroutine lowering doesn't properly account for alignment of the
+    // frame, so align everything to 64 bytes which ought to be enough for
+    // everyone. https://llvm.org/PR53148
+    constexpr int64_t kCoroFrameAlign = 64;
+    auto i64Type = rewriter.getI64Type();
+    auto makeI64Constant = [&](int64_t value) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, i64Type, rewriter.getI64IntegerAttr(value));
+    };
+    Value coroAlign = makeI64Constant(kCoroFrameAlign);
+
+    // aligned_alloc requires the size to be an integral multiple of the
+    // alignment (C11 7.22.3.1); some C libraries fail the allocation
+    // otherwise. Round the frame size up: (size + align - 1) & -align.
+    Value alignMinusOne = makeI64Constant(kCoroFrameAlign - 1);
+    Value paddedSize = rewriter.create<LLVM::AddOp>(
+        loc, i64Type, coroSize.getResult(), alignMinusOne);
+    Value alignMask = makeI64Constant(-kCoroFrameAlign);
+    Value coroAllocSize =
+        rewriter.create<LLVM::AndOp>(loc, i64Type, paddedSize, alignMask);
 
     // Allocate memory for the coroutine frame.
+    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
+        op->getParentOfType<ModuleOp>(), i64Type);
     auto coroAlloc = rewriter.create<LLVM::CallOp>(
-        loc, i8Ptr, SymbolRefAttr::get(rewriter.getContext(), kMalloc),
-        ValueRange(coroSize.getResult()));
+        loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
+        ValueRange{coroAlign, coroAllocSize});
 
     // Begin a coroutine: @llvm.coro.begin.
     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
@@ -401,9 +393,11 @@
         rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
 
     // Free the memory.
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
-        op, TypeRange(), SymbolRefAttr::get(rewriter.getContext(), kFree),
-        ValueRange(coroMem.getResult()));
+    auto freeFuncOp =
+        LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
+                                              SymbolRefAttr::get(freeFuncOp),
+                                              ValueRange(coroMem.getResult()));
 
     return success();
   }
@@ -968,7 +962,6 @@
   // We delay adding the resume function until it's needed because it currently
   // fails to compile unless '-O0' is specified.
   addAsyncRuntimeApiDeclarations(module);
-  addCRuntimeDeclarations(module);
 
   // Lower async.runtime and async.coro operations to Async Runtime API and
   // LLVM coroutine intrinsics.
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
--- a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
@@ -14,7 +14,12 @@
   // CHECK: %[[ID:.*]] = llvm.intr.coro.id
   %0 = async.coro.id
   // CHECK: %[[SIZE:.*]] = llvm.intr.coro.size : i64
-  // CHECK: %[[ALLOC:.*]] = llvm.call @malloc(%[[SIZE]])
+  // CHECK: %[[ALIGN:.*]] = llvm.mlir.constant(64 : i64) : i64
+  // CHECK: %[[ALIGN_M1:.*]] = llvm.mlir.constant(63 : i64) : i64
+  // CHECK: %[[PADDED:.*]] = llvm.add %[[SIZE]], %[[ALIGN_M1]]
+  // CHECK: %[[MASK:.*]] = llvm.mlir.constant(-64 : i64) : i64
+  // CHECK: %[[ALIGNED_SIZE:.*]] = llvm.and %[[PADDED]], %[[MASK]]
+  // CHECK: %[[ALLOC:.*]] = llvm.call @aligned_alloc(%[[ALIGN]], %[[ALIGNED_SIZE]])
   // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin %[[ID]], %[[ALLOC]]
   %1 = async.coro.begin %0
   return