diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -176,6 +176,82 @@ /// it if it does not exist. llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name); + /// Common CRTP base class for ModuleTranslation stack frames. + class StackFrame { + public: + virtual ~StackFrame() {} + TypeID getTypeID() const { return typeID; } + + protected: + explicit StackFrame(TypeID typeID) : typeID(typeID) {} + + private: + const TypeID typeID; + virtual void anchor(); + }; + + /// Concrete CRTP base class for ModuleTranslation stack frames. When + /// translating operations with regions, users of ModuleTranslation can store + /// state on ModuleTranslation stack before entering the region and inspect + /// it when converting operations nested within that region. Users are + /// expected to derive this class and put any relevant information into fields + /// of the derived class. The usual isa/dyn_cast functionality is available + /// for instances of derived classes. + template <typename Derived> + class StackFrameBase : public StackFrame { + public: + explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {} + }; + + /// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must + /// be derived from `StackFrameBase<T>` and constructible from the provided + /// arguments. Doing this before entering the region of the op being + /// translated makes the frame available when translating ops within that + /// region. + template <typename T, typename... Args> + void stackPush(Args &&... args) { + static_assert( + std::is_base_of<StackFrame, T>::value, + "can only push instances of StackFrame on ModuleTranslation stack"); + stack.push_back(std::make_unique<T>(std::forward<Args>(args)...)); + } + + /// Pops the last element from the ModuleTranslation stack. 
+ void stackPop() { stack.pop_back(); } + + /// Calls `callback` for every ModuleTranslation stack frame of type `T` + /// starting from the top of the stack. + template <typename T> + WalkResult + stackWalk(llvm::function_ref<WalkResult(const T &)> callback) const { + static_assert(std::is_base_of<StackFrame, T>::value, + "expected T derived from StackFrame"); + if (!callback) + return WalkResult::skip(); + for (const std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) { + if (T *ptr = dyn_cast_or_null<T>(frame.get())) { + WalkResult result = callback(*ptr); + if (result.wasInterrupted()) + return result; + } + } + return WalkResult::advance(); + } + + /// RAII object calling stackPush/stackPop on construction/destruction. + template <typename T> + struct SaveStack { + template <typename... Args> + explicit SaveStack(ModuleTranslation &m, Args &&...args) + : moduleTranslation(m) { + moduleTranslation.stackPush<T>(std::forward<Args>(args)...); + } + ~SaveStack() { moduleTranslation.stackPop(); } + + private: + ModuleTranslation &moduleTranslation; + }; + private: ModuleTranslation(Operation *module, std::unique_ptr<llvm::Module> llvmModule); @@ -233,6 +309,10 @@ /// metadata. The metadata is attached to Latch block branches with this /// attribute. DenseMap<Attribute, llvm::MDNode *> loopOptionsMetadataMapping; + + /// Stack of user-specified state elements, useful when translating operations + /// with regions. 
+ SmallVector<std::unique_ptr<StackFrame>> stack; }; namespace detail { @@ -270,4 +350,14 @@ } // namespace LLVM } // namespace mlir +namespace llvm { +template <typename T> +struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> { + static inline bool + doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) { + return frame.getTypeID() == ::mlir::TypeID::get<T>(); + } +}; +} // namespace llvm + #endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -23,6 +23,42 @@ using namespace mlir; +namespace { +/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the +/// insertion points for allocas. +class OpenMPAllocaStackFrame + : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> { +public: + explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP) + : allocaInsertPoint(allocaIP) {} + llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint; +}; +} // namespace + +/// Find the insertion point for allocas given the current insertion point for +/// normal operations in the builder. +static llvm::OpenMPIRBuilder::InsertPointTy +findAllocaInsertPoint(llvm::IRBuilderBase &builder, + const LLVM::ModuleTranslation &moduleTranslation) { + // If there is an alloca insertion point on stack, i.e. we are in a nested + // operation and a specific point was provided by some surrounding operation, + // use it. + llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint; + WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>( + [&](const OpenMPAllocaStackFrame &frame) { + allocaInsertPoint = frame.allocaInsertPoint; + return WalkResult::interrupt(); + }); + if (walkResult.wasInterrupted()) + return allocaInsertPoint; + + // Otherwise, insert to the entry block of the surrounding function. 
+ llvm::BasicBlock &funcEntryBlock = + builder.GetInsertBlock()->getParent()->getEntryBlock(); + return llvm::OpenMPIRBuilder::InsertPointTy( + &funcEntryBlock, funcEntryBlock.getFirstInsertionPt()); +} + /// Converts the given region that appears within an OpenMP dialect operation to /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the /// region, and a branch from any block with an successor-less OpenMP terminator @@ -91,6 +127,11 @@ auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, llvm::BasicBlock &continuationBlock) { + // Save the alloca insertion point on ModuleTranslation stack for use in + // nested regions. + LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame( + moduleTranslation, allocaIP); + // ParallelOp has only one region associated with it. auto &region = cast<omp::ParallelOp>(opInst).getRegion(); convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(), @@ -124,18 +165,14 @@ pbKind = llvm::omp::getProcBindKind(bind.getValue()); // TODO: Is the Parallel construct cancellable? bool isCancellable = false; - // TODO: Determine the actual alloca insertion point, e.g., the function - // entry or the alloca insertion point as provided by the body callback - // above. - llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP()); - if (failed(bodyGenStatus)) - return failure(); + llvm::OpenMPIRBuilder::LocationDescription ompLoc( builder.saveIP(), builder.getCurrentDebugLocation()); builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel( - ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind, - isCancellable)); - return success(); + ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB, + privCB, finiCB, ifCond, numThreads, pbKind, isCancellable)); + + return bodyGenStatus; } /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder. @@ -233,7 +270,6 @@ // TODO: this currently assumes WsLoop is semantically similar to SCF loop, // i.e. 
it has a positive step, uses signed integer semantics. Reconsider // this code when WsLoop clearly supports more cases. - llvm::BasicBlock *insertBlock = builder.GetInsertBlock(); llvm::CanonicalLoopInfo *loopInfo = moduleTranslation.getOpenMPBuilder()->createCanonicalLoop( ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true, @@ -241,12 +277,8 @@ if (failed(bodyGenStatus)) return failure(); - // TODO: get the alloca insertion point from the parallel operation builder. - // If we insert the at the top of the current function, they will be passed as - // extra arguments into the function the parallel operation builder outlines. - // Put them at the start of the current block for now. - llvm::OpenMPIRBuilder::InsertPointTy allocaIP( - insertBlock, insertBlock->getFirstInsertionPt()); + llvm::OpenMPIRBuilder::InsertPointTy allocaIP = + findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::InsertPointTy afterIP; llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); if (isStatic) { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -755,6 +755,8 @@ return llvmModule->getOrInsertNamedMetadata(name); } +void ModuleTranslation::StackFrame::anchor() {} + static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -151,6 +151,13 @@ // CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]]) llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () { +// Check that the allocas are emitted by the OpenMPIRBuilder at the top of the +// function, before the condition. 
Allocas are only emitted by the builder when +// the `if` clause is present. We match specific SSA value names since LLVM +// actually produces those names. +// CHECK: %tid.addr{{.*}} = alloca i32 +// CHECK: %zero.addr{{.*}} = alloca i32 + // CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0 %0 = llvm.mlir.constant(0 : index) : i32 %1 = llvm.icmp "slt" %arg0, %0 : i32 @@ -184,6 +191,60 @@ // CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]] // CHECK: call void @__kmpc_barrier +// ----- + +// CHECK-LABEL: @test_nested_alloca_ip +llvm.func @test_nested_alloca_ip(%arg0: i32) -> () { + + // Check that the allocas are emitted by the OpenMPIRBuilder at the top of + // the function, before the condition. Allocas are only emitted by the + // builder when the `if` clause is present. We match specific SSA value names + // since LLVM actually produces those names and ensure they come before the + // "icmp" that is the first operation we emit. + // CHECK: %tid.addr{{.*}} = alloca i32 + // CHECK: %zero.addr{{.*}} = alloca i32 + // CHECK: icmp slt i32 %{{.*}}, 0 + %0 = llvm.mlir.constant(0 : index) : i32 + %1 = llvm.icmp "slt" %arg0, %0 : i32 + + omp.parallel if(%1 : i1) { + // The "parallel" operation will be outlined, check that the function is + // produced. Inside that function, further allocas should be placed before + // another "icmp". + // CHECK: define + // CHECK: %tid.addr{{.*}} = alloca i32 + // CHECK: %zero.addr{{.*}} = alloca i32 + // CHECK: icmp slt i32 %{{.*}}, 1 + %2 = llvm.mlir.constant(1 : index) : i32 + %3 = llvm.icmp "slt" %arg0, %2 : i32 + + omp.parallel if(%3 : i1) { + // One more nesting level. 
+ // CHECK: define + // CHECK: %tid.addr{{.*}} = alloca i32 + // CHECK: %zero.addr{{.*}} = alloca i32 + // CHECK: icmp slt i32 %{{.*}}, 2 + + %4 = llvm.mlir.constant(2 : index) : i32 + %5 = llvm.icmp "slt" %arg0, %4 : i32 + + omp.parallel if(%5 : i1) { + omp.barrier + omp.terminator + } + + omp.barrier + omp.terminator + } + omp.barrier + omp.terminator + } + + llvm.return +} + +// ----- + // CHECK-LABEL: define void @test_omp_parallel_3() llvm.func @test_omp_parallel_3() -> () { // CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})