diff --git a/mlir/include/mlir/IR/ImplicitLocOpBuilder.h b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h @@ -0,0 +1,122 @@ +//===- ImplicitLocOpBuilder.h - Convenience OpBuilder -----------*- C++ -*-===// +// +// Helper class to create ops with a modally set location. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_IMPLICITLOCOPBUILDER_H +#define MLIR_IR_IMPLICITLOCOPBUILDER_H + +#include "mlir/IR/Builders.h" + +namespace mlir { + +/// ImplicitLocOpBuilder maintains a 'current location', allowing use of the +/// create<> method without specifying the location. It is otherwise the same +/// as OpBuilder. +class ImplicitLocOpBuilder : public mlir::OpBuilder { + using OpBuilder = mlir::OpBuilder; + using Location = mlir::Location; + using Block = mlir::Block; + using Value = mlir::Value; + +public: + /// Create an ImplicitLocOpBuilder using the insertion point and listener from + /// an existing OpBuilder. + ImplicitLocOpBuilder(Location loc, const OpBuilder &builder) + : OpBuilder(builder), curLoc(loc) {} + + /// OpBuilder has a bunch of convenience constructors - we support them all + /// with the additional Location. + template + ImplicitLocOpBuilder(Location loc, T &&operand, Listener *listener = nullptr) + : OpBuilder(operand, listener), curLoc(loc) {} + + ImplicitLocOpBuilder(Location loc, Block *block, Block::iterator insertPoint, + Listener *listener = nullptr) + : OpBuilder(block, insertPoint, listener), curLoc(loc) {} + + /// Create a builder and set the insertion point to before the first operation + /// in the block but still inside the block. 
+ static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, + Listener *listener = nullptr) { + return ImplicitLocOpBuilder(loc, block, block->begin(), listener); + } + + /// Create a builder and set the insertion point to after the last operation + /// in the block but still inside the block. + static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, + Listener *listener = nullptr) { + return ImplicitLocOpBuilder(loc, block, block->end(), listener); + } + + /// Create a builder and set the insertion point to before the block + /// terminator. + static ImplicitLocOpBuilder atBlockTerminator(Location loc, Block *block, + Listener *listener = nullptr) { + auto *terminator = block->getTerminator(); + assert(terminator != nullptr && "the block has no terminator"); + return ImplicitLocOpBuilder(loc, block, Block::iterator(terminator), + listener); + } + + /// Accessors for the implied location. + Location getLoc() const { return curLoc; } + void setLoc(Location loc) { curLoc = loc; } + + // We allow clients to use the explicit-loc version of create as well. + using OpBuilder::create; + using OpBuilder::createOrFold; + + /// Create an operation of specific op type at the current insertion point and + /// location. + template + OpTy create(Args &&... args) { + return OpBuilder::create(curLoc, std::forward(args)...); + } + + /// Create an operation of specific op type at the current insertion point, + /// and immediately try to fold it. This function populates 'results' with + /// the results after folding the operation. + template + void createOrFold(llvm::SmallVectorImpl &results, Args &&... args) { + OpBuilder::createOrFold(results, curLoc, std::forward(args)...); + } + + /// Overload to create or fold a single result operation. + template + typename std::enable_if(), + Value>::type + createOrFold(Args &&... args) { + return OpBuilder::createOrFold(curLoc, std::forward(args)...); + } + + /// Overload to create or fold a zero result operation. 
+ template + typename std::enable_if(), + OpTy>::type + createOrFold(Args &&... args) { + return OpBuilder::createOrFold(curLoc, std::forward(args)...); + } + + /// This builder can also be used to emit diagnostics to the current location. + mlir::InFlightDiagnostic + emitError(const llvm::Twine &message = llvm::Twine()) { + return mlir::emitError(curLoc, message); + } + mlir::InFlightDiagnostic + emitWarning(const llvm::Twine &message = llvm::Twine()) { + return mlir::emitWarning(curLoc, message); + } + mlir::InFlightDiagnostic + emitRemark(const llvm::Twine &message = llvm::Twine()) { + return mlir::emitRemark(curLoc, message); + } + +private: + Location curLoc; +}; + +} // namespace mlir + +#endif // MLIR_IR_IMPLICITLOCOPBUILDER_H \ No newline at end of file diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -13,7 +13,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -112,12 +112,13 @@ // Adds Async Runtime C API declarations to the module. 
static void addAsyncRuntimeApiDeclarations(ModuleOp module) { - auto builder = OpBuilder::atBlockTerminator(module.getBody()); + auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), + module.getBody()); auto addFuncDecl = [&](StringRef name, FunctionType type) { if (module.lookupSymbol(name)) return; - builder.create(module.getLoc(), name, type).setPrivate(); + builder.create(name, type).setPrivate(); }; MLIRContext *ctx = module.getContext(); @@ -149,13 +150,13 @@ static constexpr const char *kCoroResume = "llvm.coro.resume"; /// Adds an LLVM function declaration to a module. -static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name, - LLVM::LLVMType ret, +static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, + StringRef name, LLVM::LLVMType ret, ArrayRef params) { if (module.lookupSymbol(name)) return; LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false); - builder.create(module.getLoc(), name, type); + builder.create(name, type); } /// Adds coroutine intrinsics declarations to the module. 
@@ -163,7 +164,8 @@ using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); - OpBuilder builder(module.getBody()->getTerminator()); + ImplicitLocOpBuilder builder(module.getLoc(), + module.getBody()->getTerminator()); auto token = LLVMTokenType::get(ctx); auto voidTy = LLVMType::getVoidTy(ctx); @@ -196,7 +198,8 @@ using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); - OpBuilder builder(module.getBody()->getTerminator()); + ImplicitLocOpBuilder builder(module.getLoc(), + module.getBody()->getTerminator()); auto voidTy = LLVMType::getVoidTy(ctx); auto i64 = LLVMType::getInt64Ty(ctx); @@ -232,13 +235,13 @@ resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(); - OpBuilder blockBuilder = OpBuilder::atBlockEnd(block); + auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); - blockBuilder.create(loc, TypeRange(), + blockBuilder.create(TypeRange(), blockBuilder.getSymbolRefAttr(kCoroResume), resumeOp.getArgument(0)); - blockBuilder.create(loc, ValueRange()); + blockBuilder.create(ValueRange()); } //===----------------------------------------------------------------------===// @@ -302,13 +305,12 @@ Block *entryBlock = func.addEntryBlock(); Location loc = func.getBody().getLoc(); - OpBuilder builder = OpBuilder::atBlockBegin(entryBlock); + auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock); // ------------------------------------------------------------------------ // // Allocate async tokens/values that we will return from a ramp function. // ------------------------------------------------------------------------ // - auto createToken = - builder.create(loc, kCreateToken, TokenType::get(ctx)); + auto createToken = builder.create(kCreateToken, TokenType::get(ctx)); // ------------------------------------------------------------------------ // // Initialize coroutine: allocate frame, get coroutine handle. @@ -316,28 +318,28 @@ // Constants for initializing coroutine frame. 
auto constZero = - builder.create(loc, i32, builder.getI32IntegerAttr(0)); + builder.create(i32, builder.getI32IntegerAttr(0)); auto constFalse = - builder.create(loc, i1, builder.getBoolAttr(false)); - auto nullPtr = builder.create(loc, i8Ptr); + builder.create(i1, builder.getBoolAttr(false)); + auto nullPtr = builder.create(i8Ptr); // Get coroutine id: @llvm.coro.id auto coroId = builder.create( - loc, token, builder.getSymbolRefAttr(kCoroId), + token, builder.getSymbolRefAttr(kCoroId), ValueRange({constZero, nullPtr, nullPtr, nullPtr})); // Get coroutine frame size: @llvm.coro.size.i64 auto coroSize = builder.create( - loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); + i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); // Allocate memory for coroutine frame. - auto coroAlloc = builder.create( - loc, i8Ptr, builder.getSymbolRefAttr(kMalloc), - ValueRange(coroSize.getResult(0))); + auto coroAlloc = + builder.create(i8Ptr, builder.getSymbolRefAttr(kMalloc), + ValueRange(coroSize.getResult(0))); // Begin a coroutine: @llvm.coro.begin auto coroHdl = builder.create( - loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin), + i8Ptr, builder.getSymbolRefAttr(kCoroBegin), ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); Block *cleanupBlock = func.addBlock(); @@ -350,15 +352,14 @@ // Get a pointer to the coroutine frame memory: @llvm.coro.free. auto coroMem = builder.create( - loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree), + i8Ptr, builder.getSymbolRefAttr(kCoroFree), ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); // Free the memory. - builder.create(loc, TypeRange(), - builder.getSymbolRefAttr(kFree), + builder.create(TypeRange(), builder.getSymbolRefAttr(kFree), ValueRange(coroMem.getResult(0))); // Branch into the suspend block. 
- builder.create(loc, suspendBlock); + builder.create(suspendBlock); // ------------------------------------------------------------------------ // // Coroutine suspend block: mark the end of a coroutine and return allocated @@ -367,17 +368,17 @@ builder.setInsertionPointToStart(suspendBlock); // Mark the end of a coroutine: @llvm.coro.end. - builder.create(loc, i1, builder.getSymbolRefAttr(kCoroEnd), + builder.create(i1, builder.getSymbolRefAttr(kCoroEnd), ValueRange({coroHdl.getResult(0), constFalse})); // Return created `async.token` from the suspend block. This will be the // return value of a coroutine ramp function. - builder.create(loc, createToken.getResult(0)); + builder.create(createToken.getResult(0)); // Branch from the entry block to the cleanup block to create a valid CFG. builder.setInsertionPointToEnd(entryBlock); - builder.create(loc, cleanupBlock); + builder.create(cleanupBlock); // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks. @@ -471,8 +472,6 @@ MLIRContext *ctx = module.getContext(); Location loc = execute.getLoc(); - OpBuilder moduleBuilder(module.getBody()->getTerminator()); - // Collect all outlined function inputs. llvm::SetVector functionInputs(execute.dependencies().begin(), execute.dependencies().end()); @@ -484,13 +483,13 @@ SmallVector inputTypes(typesRange.begin(), typesRange.end()); auto outputTypes = execute.getResultTypes(); - auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes); + auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); auto funcAttrs = ArrayRef(); // TODO: Derive outlined function name from the parent FuncOp (support // multiple nested async.execute operations). 
FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); - symbolTable.insert(func, moduleBuilder.getInsertionPoint()); + symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); @@ -502,21 +501,21 @@ // Async execute API (execution will be resumed in a thread managed by the // async runtime). Block *entryBlock = &func.getBlocks().front(); - OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock); + auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); // A pointer to coroutine resume intrinsic wrapper. auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); - auto resumePtr = builder.create( - loc, resumeFnTy.getPointerTo(), kResume); + auto resumePtr = + builder.create(resumeFnTy.getPointerTo(), kResume); // Save the coroutine state: @llvm.coro.save auto coroSave = builder.create( - loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), + LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), ValueRange({coro.coroHandle})); // Call async runtime API to execute a coroutine in the managed thread. SmallVector executeArgs = {coro.coroHandle, resumePtr.res()}; - builder.create(loc, TypeRange(), kExecute, executeArgs); + builder.create(TypeRange(), kExecute, executeArgs); // Split the entry block before the terminator. auto *terminatorOp = entryBlock->getTerminator(); @@ -528,7 +527,7 @@ // Await on all dependencies before starting to execute the body region. builder.setInsertionPointToStart(resume); for (size_t i = 0; i < execute.dependencies().size(); ++i) - builder.create(loc, func.getArgument(i)); + builder.create(func.getArgument(i)); // Map from function inputs defined above the execute op to the function // arguments. @@ -540,17 +539,16 @@ // to async runtime to emplace the result token. 
for (Operation &op : execute.body().getOps()) { if (isa(op)) { - builder.create(loc, kEmplaceToken, TypeRange(), coro.asyncToken); + builder.create(kEmplaceToken, TypeRange(), coro.asyncToken); continue; } builder.clone(op, valueMapping); } // Replace the original `async.execute` with a call to outlined function. - OpBuilder callBuilder(execute); - auto callOutlinedFunc = - callBuilder.create(loc, func.getName(), execute.getResultTypes(), - functionInputs.getArrayRef()); + ImplicitLocOpBuilder callBuilder(loc, execute); + auto callOutlinedFunc = callBuilder.create( + func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); execute.replaceAllUsesWith(callOutlinedFunc.getResults()); execute.erase(); @@ -744,24 +742,24 @@ if (isInCoroutine) { const CoroMachinery &coro = outlined->getSecond(); - OpBuilder builder(op, rewriter.getListener()); + ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); MLIRContext *ctx = op->getContext(); // A pointer to coroutine resume intrinsic wrapper. auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); - auto resumePtr = builder.create( - loc, resumeFnTy.getPointerTo(), kResume); + auto resumePtr = + builder.create(resumeFnTy.getPointerTo(), kResume); // Save the coroutine state: @llvm.coro.save auto coroSave = builder.create( - loc, LLVM::LLVMTokenType::get(ctx), - builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle)); + LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), + ValueRange(coro.coroHandle)); // Call async runtime API to resume a coroutine in the managed thread when // the async await argument becomes ready. SmallVector awaitAndExecuteArgs = {operands[0], coro.coroHandle, resumePtr.res()}; - builder.create(loc, TypeRange(), coroAwaitFuncName, + builder.create(TypeRange(), coroAwaitFuncName, awaitAndExecuteArgs); Block *suspended = op->getBlock();