diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -185,6 +185,11 @@
   ];
 
   let regions = (region AnyRegion:$region);
+
+  let extraClassDeclaration = [{
+    /// Returns the number of loops in the workshare loop nest.
+    unsigned getNumLoops() { return lowerBound().size(); }
+  }];
 }
 
 def YieldOp : OpenMP_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -91,6 +91,8 @@
                                            llvm::IRBuilder<> &builder);
   virtual LogicalResult convertOmpParallel(Operation &op,
                                            llvm::IRBuilder<> &builder);
+  virtual LogicalResult convertOmpWsLoop(Operation &opInst,
+                                         llvm::IRBuilder<> &builder);
 
   /// Converts the type from MLIR LLVM dialect to LLVM.
   llvm::Type *convertType(LLVMType type);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -465,6 +465,176 @@
   return success();
 }
 
+/// Returns an LLVM function to call for initializing loop bounds using OpenMP
+/// static scheduling depending on `type`.
+static llvm::FunctionCallee
+getKmpcForStaticInitForType(Type type, llvm::Module &llvmModule,
+                            llvm::OpenMPIRBuilder &ompBuilder) {
+  // So far, all code is being lowered from `index` types for loop induction
+  // variables, which are considered signed.
+  // TODO: model signed/unsigned difference if necessary.
+  unsigned bitwidth = type.cast<LLVMIntegerType>().getBitWidth();
+  if (bitwidth == 32)
+    return ompBuilder.getOrCreateRuntimeFunction(
+        llvmModule,
+        llvm::omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4);
+  if (bitwidth == 64)
+    return ompBuilder.getOrCreateRuntimeFunction(
+        llvmModule,
+        llvm::omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8);
+
+  // TODO: we should have a verifier on WsLoopOp for this.
+  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
+}
+
+/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
+LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst,
+                                                  llvm::IRBuilder<> &builder) {
+  auto loop = cast<omp::WsLoopOp>(opInst);
+  // TODO: this should be in the op verifier instead.
+  if (loop.lowerBound().empty())
+    return failure();
+
+  if (loop.getNumLoops() != 1)
+    return opInst.emitOpError("collapsed loops not yet supported");
+
+  if (loop.schedule_val().hasValue() &&
+      omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()) !=
+          omp::ClauseScheduleKind::Static)
+    return opInst.emitOpError(
+        "only static (default) loop schedule is currently supported");
+
+  llvm::Function *func = builder.GetInsertBlock()->getParent();
+  llvm::LLVMContext &llvmContext = llvmModule->getContext();
+
+  // Declare the necessary OpenMP runtime functions.
+  llvm::Module &llvmModule =
+      *builder.GetInsertBlock()->getParent()->getParent();
+  llvm::FunctionCallee staticInit = getKmpcForStaticInitForType(
+      loop.step()[0].getType(), llvmModule, *ompBuilder);
+  llvm::FunctionCallee staticFini = ompBuilder->getOrCreateRuntimeFunction(
+      llvmModule, llvm::omp::OMPRTL___kmpc_for_static_fini);
+  llvm::FunctionCallee globalThreadNum = ompBuilder->getOrCreateRuntimeFunction(
+      llvmModule, llvm::omp::OMPRTL___kmpc_global_thread_num);
+
+  // Store loop bounds in `alloca`ed memory as expected by the upcoming call
+  // to the scheduler, which may modify them. Note that the upper bound is
+  // expected to be inclusive by the runtime. Also prepare other arguments to
+  // the runtime call.
+  llvm::Type *i32Type = llvm::Type::getInt32Ty(builder.getContext());
+  llvm::Value *lastIter = builder.CreateAlloca(i32Type, nullptr, "p.lastiter");
+  llvm::Value *step = valueMapping.lookup(loop.step()[0]);
+  llvm::Type *ivType = step->getType();
+  llvm::Value *lowerBound = builder.CreateAlloca(ivType);
+  llvm::Value *upperBound = builder.CreateAlloca(ivType);
+  llvm::Value *stride = builder.CreateAlloca(ivType);
+  llvm::Constant *one = llvm::ConstantInt::get(ivType, 1);
+  llvm::Value *inclusiveUpperBound =
+      builder.CreateSub(valueMapping[loop.upperBound()[0]], one);
+  builder.CreateStore(valueMapping[loop.lowerBound()[0]], lowerBound);
+  builder.CreateStore(inclusiveUpperBound, upperBound);
+  builder.CreateStore(valueMapping[loop.step()[0]], stride);
+  llvm::Value *chunk = loop.schedule_chunk_var()
+                           ? valueMapping[loop.schedule_chunk_var()]
+                           : llvm::ConstantInt::get(ivType, 1);
+
+  // Set up the source location value for the OpenMP runtime.
+  llvm::DISubprogram *subprogram =
+      builder.GetInsertBlock()->getParent()->getSubprogram();
+  const llvm::DILocation *diLoc =
+      debugTranslation->translateLoc(opInst.getLoc(), subprogram);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
+                                                    llvm::DebugLoc(diLoc));
+  llvm::Constant *srcLocStr = ompBuilder->getOrCreateSrcLocStr(ompLoc);
+  llvm::Value *srcLoc = ompBuilder->getOrCreateIdent(srcLocStr);
+
+  // Get the global thread id of the current thread.
+  llvm::Value *threadNum = builder.CreateCall(globalThreadNum, {srcLoc});
+
+  // TODO: extract scheduling type and map it to OMP constant. This is
+  // currently happening in kmp.h and its ilk and needs to be moved to
+  // OpenMP.td first.
+  constexpr int kStaticSchedType = 34;
+  llvm::Constant *schedulingType =
+      llvm::ConstantInt::get(i32Type, kStaticSchedType);
+
+  // Call the runtime function to compute new loop bounds according to the
+  // scheduling policy.
+  builder.CreateCall(staticInit, {srcLoc, threadNum, schedulingType, lastIter,
+                                  lowerBound, upperBound, stride, step, chunk});
+  llvm::Value *lowerBoundVal = builder.CreateLoad(lowerBound);
+  llvm::Value *upperBoundVal = builder.CreateLoad(upperBound);
+
+  // Generator of the canonical loop body. Produces an SESE region of basic
+  // blocks.
+  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
+  // relying on captured variables.
+  LogicalResult bodyGenStatus = success();
+  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
+    llvm::IRBuilder<>::InsertPointGuard guard(builder);
+
+    // Make sure further conversions know about the induction variable.
+    valueMapping[loop.getRegion().front().getArgument(0)] = iv;
+
+    llvm::BasicBlock *entryBlock = ip.getBlock();
+    llvm::BasicBlock *exitBlock =
+        entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
+
+    // Convert the body of the loop.
+    Region &region = loop.region();
+    for (Block &bb : region) {
+      llvm::BasicBlock *llvmBB =
+          llvm::BasicBlock::Create(llvmContext, "omp.wsloop.region", func);
+      blockMapping[&bb] = llvmBB;
+
+      // Retarget the branch of the entry block to the entry block of the
+      // converted region (regions are single-entry).
+      if (bb.isEntryBlock()) {
+        auto *branch = cast<llvm::BranchInst>(entryBlock->getTerminator());
+        branch->setSuccessor(0, llvmBB);
+      }
+    }
+
+    // Block conversion creates a new IRBuilder every time, so we need not
+    // bother about maintaining the insertion point.
+    llvm::SetVector<Block *> blocks = topologicalSort(region);
+    for (Block *bb : blocks) {
+      if (failed(convertBlock(*bb, bb->isEntryBlock()))) {
+        bodyGenStatus = failure();
+        return;
+      }
+
+      // Special handling for `omp.yield` terminators (we may have more than
+      // one): they return control to the parent WsLoop operation, so replace
+      // them with a branch to the exit block. We handle this here to avoid
+      // relying on inter-function communication through the ModuleTranslation
+      // class to set up the correct insertion point. This is also consistent
+      // with MLIR's idiom of handling special region terminators in the same
+      // code that handles the region-owning operation.
+      if (isa<omp::YieldOp>(bb->getTerminator())) {
+        llvm::BasicBlock *llvmBB = blockMapping[bb];
+        builder.SetInsertPoint(llvmBB, llvmBB->end());
+        builder.CreateBr(exitBlock);
+      }
+    }
+
+    connectPHINodes(region, valueMapping, blockMapping);
+  };
+
+  // Delegate actual loop construction to the OpenMP IRBuilder.
+  llvm::CanonicalLoopInfo *loopInfo = ompBuilder->createCanonicalLoop(
+      builder, bodyGen, lowerBoundVal, upperBoundVal,
+      valueMapping[loop.step()[0]], /*IsSigned=*/true,
+      /*InclusiveStop=*/true);
+  if (failed(bodyGenStatus))
+    return failure();
+
+  // Notify the scheduler when the loop is complete.
+  builder.restoreIP(loopInfo->getAfterIP());
+  builder.CreateCall(staticFini, {srcLoc, threadNum});
+
+  return success();
+}
+
 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
 /// (including OpenMP runtime calls).
 LogicalResult
@@ -505,6 +675,14 @@
       })
       .Case(
           [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
+      .Case([&](omp::WsLoopOp) { return convertOmpWsLoop(opInst, builder); })
+      .Case([&](omp::YieldOp op) {
+        // Yields are loop terminators that can simply be omitted. The loop
+        // structure was created in the function that handles WsLoopOp.
+        assert(op.getNumOperands() == 0 && "unexpected yield with operands");
+        return success();
+      })
+
       .Default([&](Operation *inst) {
         return inst->emitError("unsupported OpenMP operation: ")
                << inst->getName();
diff --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir
--- a/mlir/test/Target/openmp-llvm.mlir
+++ b/mlir/test/Target/openmp-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: define void @test_stand_alone_directives()
 llvm.func @test_stand_alone_directives() {
@@ -264,3 +264,42 @@
   }
   llvm.return
 }
+
+// -----
+
+// CHECK: %struct.ident_t = type
+// CHECK: @[[$parallel_loc:.*]] = private unnamed_addr constant {{.*}} c";LLVMDialectModule;wsloop_simple;{{[0-9]+}};{{[0-9]+}};;\00"
+// CHECK: @[[$parallel_loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$parallel_loc]], {{.*}}
+
+// CHECK: @[[$wsloop_loc:.*]] = private unnamed_addr constant {{.*}} c";LLVMDialectModule;wsloop_simple;{{[0-9]+}};{{[0-9]+}};;\00"
+// CHECK: @[[$wsloop_loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$wsloop_loc]], {{.*}}
+
+// CHECK-LABEL: @wsloop_simple
+llvm.func @wsloop_simple(%arg0: !llvm.ptr<float>) {
+  %0 = llvm.mlir.constant(42 : index) : !llvm.i64
+  %1 = llvm.mlir.constant(10 : index) : !llvm.i64
+  %2 = llvm.mlir.constant(1 : index) : !llvm.i64
+  omp.parallel {
+    "omp.wsloop"(%1, %0, %2) ( {
+    ^bb0(%arg1: !llvm.i64):
+      // CHECK: %p.lastiter = alloca i32
+      // CHECK: %[[PLOWER:.*]] = alloca i64
+      // CHECK: %[[PUPPER:.*]] = alloca i64
+      // CHECK: %[[PSTRIDE:.*]] = alloca i64
+      // CHECK: store i64 {{.*}}, i64* %[[PLOWER]]
+      // CHECK: store i64 {{.*}}, i64* %[[PUPPER]]
+      // CHECK: store i64 {{.*}}, i64* %[[PSTRIDE]]
+      // CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[$wsloop_loc_struct]])
+      // CHECK: call void @__kmpc_for_static_init_8(%struct.ident_t* @[[$wsloop_loc_struct]], i32 %[[TID]], i32 34, i32* %p.lastiter, i64* %[[PLOWER]], i64* %[[PUPPER]], i64* %[[PSTRIDE]], i64 1, i64 1)
+      // CHECK: load i64, i64* %[[PLOWER]]
+      // CHECK: load i64, i64* %[[PUPPER]]
+      %3 = llvm.mlir.constant(2.000000e+00 : f32) : !llvm.float
+      %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+      llvm.store %3, %4 : !llvm.ptr<float>
+      omp.yield
+      // CHECK: call void @__kmpc_for_static_fini(%struct.ident_t* @[[$wsloop_loc_struct]], i32 %[[TID]])
+    }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (!llvm.i64, !llvm.i64, !llvm.i64) -> ()
+    omp.terminator
+  }
+  llvm.return
+}
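
Reviewer note (not part of the patch): the new `wsloop_simple` test case corresponds
roughly to the C++/OpenMP source below. The function and parameter names simply
mirror the MLIR above; the snippet is illustrative only. With schedule(static)
(kmp_sch_static, the literal 34 checked above), Clang lowers such a loop to a
similar __kmpc_for_static_init_8 / __kmpc_for_static_fini sequence, which is
what the CHECK lines expect from the new translation path.

    // Illustrative approximation of the wsloop_simple test above: same
    // bounds, step, and body. Compile with -fopenmp.
    void wsloop_simple(float *arg0) {
      #pragma omp parallel
      {
        // omp.wsloop iterates over the half-open range [10, 42) with step 1;
        // the translation subtracts 1 from the upper bound because the
        // __kmpc_for_static_init_* entry points expect an inclusive bound.
        #pragma omp for schedule(static)
        for (long long i = 10; i < 42; ++i)
          arg0[i] = 2.0f;
      }
    }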