diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -185,6 +185,13 @@ ]; let regions = (region AnyRegion:$region); + + let verifier = [{ return ::verifyWsLoop(*this); }]; + + let extraClassDeclaration = [{ + /// Returns the number of loops in the worksharing loop nest. + unsigned getNumLoops() { return lowerBound().size(); } + }]; } def YieldOp : OpenMP_Op<"yield", [NoSideEffect, ReturnLike, Terminator, diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -91,6 +91,8 @@ llvm::IRBuilder<> &builder); virtual LogicalResult convertOmpParallel(Operation &op, llvm::IRBuilder<> &builder); + virtual LogicalResult convertOmpWsLoop(Operation &opInst, + llvm::IRBuilder<> &builder); /// Converts the type from MLIR LLVM dialect to LLVM. 
llvm::Type *convertType(LLVMType type); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -390,5 +390,18 @@ state.addAttributes(attributes); } +static LogicalResult verifyWsLoop(WsLoopOp op) { + if (op.lowerBound().size() != op.upperBound().size()) + return op.emitOpError( + "expects the same number of lower and upper bound operands"); + if (op.lowerBound().size() != op.step().size()) + return op.emitOpError( + "expects the same number of lower bound and step operands"); + if (op.lowerBound().size() != op.getRegion().front().getNumArguments()) + return op.emitOpError("expects the same number of lower bound operands and " + "entry block arguments"); + return success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -464,6 +464,107 @@ return success(); } +LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst, + llvm::IRBuilder<> &builder) { + auto loop = cast<omp::WsLoopOp>(opInst); + if (loop.lowerBound().empty()) + return failure(); + + llvm::Function *func = builder.GetInsertBlock()->getParent(); + llvm::LLVMContext &llvmContext = llvmModule->getContext(); + + auto noBody = [&](llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *iv) {}; + + // Build the structure of nested loops. 
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos; + loopInfos.reserve(loop.getNumLoops()); + for (int i = 0, e = loop.getNumLoops(); i < e; ++i) { + llvm::CanonicalLoopInfo *loopInfo = ompBuilder->createCanonicalLoop( + builder, noBody, valueMapping.lookup(loop.lowerBound()[i]), + valueMapping.lookup(loop.upperBound()[i]), + valueMapping.lookup(loop.step()[i]), + /*IsSigned=*/true, + /*InclusiveStop=*/false); + valueMapping[loop.getRegion().front().getArgument(i)] = + loopInfo->getIndVar(); + + // If it's not the first loop, we will need to create a branch to the latch + // block of the previous loop in the after block of this loop. + if (i != 0) { + assert(!loopInfos.empty() && "expected loop info from the previous loop"); + builder.restoreIP(loopInfo->getAfterIP()); + builder.CreateBr(loopInfos.back()->getLatch()); + } + loopInfos.push_back(loopInfo); + + // Creating a canonical loop will insert a branch to its preheader at the + // insertion point of the given builder. Create a new block, update the body + // block to branch to this block instead of the latch block and set + // insertion point to that block so it branches further to the child loop + // when created. This makes us have an extra block, but makes the loop + // structure more apparent and avoids an obscure memory error triggered + // if we attempt to erase the terminator of the body block and create a new + // one instead. If we are emitting the last loop, a similar trick will be + // performed below, so skip an extra block here. + if (i != e - 1) { + llvm::BasicBlock *dispatchBlock = llvm::BasicBlock::Create( + llvmContext, "omp.wsloop.forward" + llvm::Twine(i + 1), func); + llvm::BranchInst *branch = + cast<llvm::BranchInst>(loopInfo->getBody()->getTerminator()); + branch->setSuccessor(0, dispatchBlock); + + // This needs to be at the end of the loop so that the next operation + // creates the branch to the preheader block at the right place. + builder.SetInsertPoint(dispatchBlock); + } + } + + // Convert the body of the loop. 
+ Region &region = loop.region(); + for (Block &bb : region) { + llvm::BasicBlock *llvmBB = + llvm::BasicBlock::Create(llvmContext, "omp.wsloop.region", func); + blockMapping[&bb] = llvmBB; + + // Retarget the branch in the body block automatically created by + // OpenMPIRBuilder to branch to the entry block of the loop region instead. + // TODO: this is a hacky workaround for createCanonicalLoop forcibly + // inserting a branch to the latch block first, and then expecting the + // caller to only populate the _single_ body block. We need more than one. + if (bb.isEntryBlock()) { + llvm::BranchInst *branch = + cast<llvm::BranchInst>(loopInfos.back()->getBody()->getTerminator()); + branch->setSuccessor(0, llvmBB); + } + } + + // Block conversion creates a new IRBuilder every time so need not bother + // about maintaining the insertion point. + llvm::SetVector<Block *> blocks = topologicalSort(region); + for (Block *bb : blocks) { + if (failed(convertBlock(*bb, bb->isEntryBlock()))) + return failure(); + + // Special handling for `omp.yield` terminators (we may have more than one): + // they return the control to the parent WsLoop operation so replace them + // with the branch to the latch block. We handle this here to avoid relying + // on inter-function communication through the ModuleTranslation class to set + // up the correct insertion point. This is also consistent with MLIR's + // idiom of handling special region terminators in the same code that + // handles the region-owning operation. + if (isa<omp::YieldOp>(bb->getTerminator())) { + llvm::BasicBlock *llvmBB = blockMapping[bb]; + builder.SetInsertPoint(llvmBB, llvmBB->end()); + builder.CreateBr(loopInfos.back()->getLatch()); + } + } + connectPHINodes(region, valueMapping, blockMapping); + + // Set builder to insert instructions after the outermost loop. + builder.restoreIP(loopInfos.front()->getAfterIP()); + return success(); +} + /// Given an OpenMP MLIR operation, create the corresponding LLVM IR + /// (including OpenMP runtime calls). 
LogicalResult @@ -504,6 +605,13 @@ }) .Case( [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); }) + .Case([&](omp::WsLoopOp) { return convertOmpWsLoop(opInst, builder); }) + .Case([&](omp::YieldOp op) { + // Yields are loop terminators that can be just omitted. The loop + // structure was created in the function that handles WsLoopOp. + assert(op.getNumOperands() == 0 && "unexpected yield with operands"); + return success(); + }) .Default([&](Operation *inst) { return inst->emitError("unsupported OpenMP operation: ") << inst->getName(); diff --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir --- a/mlir/test/Target/openmp-llvm.mlir +++ b/mlir/test/Target/openmp-llvm.mlir @@ -264,3 +264,39 @@ } llvm.return } + +// TODO: add FileCheck tests for this, for now just make sure we don't crash + +llvm.func @payload(!llvm.i64) attributes {sym_visibility = "private"} +llvm.func @foo(%arg0: !llvm.i64, %arg1: !llvm.i64, %arg2: !llvm.i64) { + omp.parallel { + "omp.wsloop"(%arg0, %arg2, %arg1) ( { + ^bb0(%arg3: !llvm.i64): + llvm.call @payload(%arg3) : (!llvm.i64) -> () + llvm.br ^bb1 + ^bb1: + %0 = llvm.add %arg3, %arg3 : !llvm.i64 + llvm.call @payload(%0) : (!llvm.i64) -> () + omp.yield + }) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (!llvm.i64, !llvm.i64, !llvm.i64) -> () + omp.terminator + } + llvm.return +} + +llvm.func @bar(%arg0: !llvm.i64, %arg1: !llvm.i64, %arg2: !llvm.i64) { + omp.parallel { + "omp.wsloop"(%arg0, %arg0, %arg2, %arg2, %arg1, %arg1) ( { + ^bb0(%arg3: !llvm.i64, %arg5: !llvm.i64): + llvm.call @payload(%arg3) : (!llvm.i64) -> () + llvm.br ^bb1 + ^bb1: + %0 = llvm.add %arg5, %arg5 : !llvm.i64 + llvm.call @payload(%0) : (!llvm.i64) -> () + omp.yield + }) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64) -> () + omp.terminator + } + llvm.return +} +