diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -80,27 +80,18 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) { auto builder = OpBuilder::atBlockTerminator(module.getBody()); - MLIRContext *ctx = module.getContext(); - Location loc = module.getLoc(); - - if (!module.lookupSymbol(kCreateToken)) - builder.create(loc, kCreateToken, - AsyncAPI::createTokenFunctionType(ctx)); - - if (!module.lookupSymbol(kEmplaceToken)) - builder.create(loc, kEmplaceToken, - AsyncAPI::emplaceTokenFunctionType(ctx)); - - if (!module.lookupSymbol(kAwaitToken)) - builder.create(loc, kAwaitToken, - AsyncAPI::awaitTokenFunctionType(ctx)); - - if (!module.lookupSymbol(kExecute)) - builder.create(loc, kExecute, AsyncAPI::executeFunctionType(ctx)); + auto addFuncDecl = [&](StringRef name, FunctionType type) { + if (module.lookupSymbol(name)) + return; + builder.create(module.getLoc(), name, type); + }; - if (!module.lookupSymbol(kAwaitAndExecute)) - builder.create(loc, kAwaitAndExecute, - AsyncAPI::awaitAndExecuteFunctionType(ctx)); + MLIRContext *ctx = module.getContext(); + addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); + addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); + addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); + addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); + addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx)); } //===----------------------------------------------------------------------===// @@ -116,13 +107,21 @@ static constexpr const char *kCoroFree = "llvm.coro.free"; static constexpr const char *kCoroResume = "llvm.coro.resume"; +/// Adds an LLVM function declaration to a module. +static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name, + LLVM::LLVMType ret, + ArrayRef params) { + if (module.lookupSymbol(name)) + return; + LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false); + builder.create(module.getLoc(), name, type); +} + /// Adds coroutine intrinsics declarations to the module. static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); - Location loc = module.getLoc(); - OpBuilder builder(module.getBody()->getTerminator()); auto token = LLVMTokenType::get(ctx); @@ -134,38 +133,14 @@ auto i64 = LLVMType::getInt64Ty(ctx); auto i8Ptr = LLVMType::getInt8PtrTy(ctx); - if (!module.lookupSymbol(kCoroId)) - builder.create( - loc, kCoroId, - LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false)); - - if (!module.lookupSymbol(kCoroSizeI64)) - builder.create(loc, kCoroSizeI64, - LLVMType::getFunctionTy(i64, false)); - - if (!module.lookupSymbol(kCoroBegin)) - builder.create( - loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); - - if (!module.lookupSymbol(kCoroSave)) - builder.create(loc, kCoroSave, - LLVMType::getFunctionTy(token, i8Ptr, false)); - - if (!module.lookupSymbol(kCoroSuspend)) - builder.create(loc, kCoroSuspend, - LLVMType::getFunctionTy(i8, {token, i1}, false)); - - if (!module.lookupSymbol(kCoroEnd)) - builder.create(loc, kCoroEnd, - LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false)); - - if (!module.lookupSymbol(kCoroFree)) - builder.create( - loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); - - if (!module.lookupSymbol(kCoroResume)) - builder.create(loc, kCoroResume, - LLVMType::getFunctionTy(voidTy, i8Ptr, false)); + addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); + addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); + addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); + addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); + addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); + addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); + addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); + addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); } //===----------------------------------------------------------------------===// @@ -180,21 +155,14 @@ using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); - Location loc = module.getLoc(); - OpBuilder builder(module.getBody()->getTerminator()); auto voidTy = LLVMType::getVoidTy(ctx); auto i64 = LLVMType::getInt64Ty(ctx); auto i8Ptr = LLVMType::getInt8PtrTy(ctx); - if (!module.lookupSymbol(kMalloc)) - builder.create( - loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false)); - - if (!module.lookupSymbol(kFree)) - builder.create( - loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false)); + addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); + addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); } //===----------------------------------------------------------------------===// @@ -219,8 +187,8 @@ auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); auto resumeOp = moduleBuilder.create( - loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); - SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private); + loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false)); + resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(); OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);