diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -213,25 +213,12 @@
   if (loop.lowerBound().empty())
     return failure();
 
-  if (loop.getNumLoops() != 1)
-    return opInst.emitOpError("collapsed loops not yet supported");
-
   // Static is the default.
   omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static;
   if (loop.schedule_val().hasValue())
     schedule =
         *omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue());
 
-  // Find the loop configuration.
-  llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
-  llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
-  llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
-  llvm::Type *ivType = step->getType();
-  llvm::Value *chunk =
-      loop.schedule_chunk_var()
-          ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
-          : llvm::ConstantInt::get(ivType, 1);
-
   // Set up the source location value for OpenMP runtime.
   llvm::DISubprogram *subprogram =
       builder.GetInsertBlock()->getParent()->getSubprogram();
@@ -240,13 +227,16 @@
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
                                                     llvm::DebugLoc(diLoc));
 
-  // Generator of the canonical loop body. Produces an SESE region of basic
-  // blocks.
+  // Generator of the canonical loop body; outer loops of the nest stay empty.
   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
   // relying on captured variables.
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
   LogicalResult bodyGenStatus = success();
   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
     llvm::IRBuilder<>::InsertPointGuard guard(builder);
 
     // Make sure further conversions know about the induction variable.
-    moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
+    moduleTranslation.mapValue(
+        loop.getRegion().front().getArgument(loopInfos.size()), iv);
+    if (loopInfos.size() != loop.getNumLoops() - 1)
+      return;
@@ -264,17 +254,44 @@
   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
   // i.e. it has a positive step, uses signed integer semantics. Reconsider
   // this code when WsLoop clearly supports more cases.
-  llvm::CanonicalLoopInfo *loopInfo =
-      moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
-          ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
-          /*InclusiveStop=*/loop.inclusive());
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
+    llvm::Value *lowerBound =
+        moduleTranslation.lookupValue(loop.lowerBound()[i]);
+    llvm::Value *upperBound =
+        moduleTranslation.lookupValue(loop.upperBound()[i]);
+    llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);
+
+    // Inner loops are nested in the body of the previous loop, but their trip
+    // counts are emitted in the preheader of the outermost loop so that they
+    // are available where collapseLoops computes the collapsed trip count.
+    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
+    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
+    if (i != 0) {
+      loc = llvm::OpenMPIRBuilder::LocationDescription(
+          loopInfos.back()->getBodyIP(), llvm::DebugLoc(diLoc));
+      computeIP = loopInfos.front()->getPreheaderIP();
+    }
+    loopInfos.push_back(ompBuilder->createCanonicalLoop(
+        loc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
+        /*InclusiveStop=*/loop.inclusive(), computeIP));
+  }
   if (failed(bodyGenStatus))
     return failure();
 
+  // Collapse the loop nest into a single canonical loop.
+  llvm::CanonicalLoopInfo *loopInfo =
+      ompBuilder->collapseLoops(diLoc, loopInfos, {});
+
+  // Find the loop configuration; the default chunk uses the collapsed IV type.
+  llvm::Type *ivType = loopInfo->getIndVar()->getType();
+  llvm::Value *chunk =
+      loop.schedule_chunk_var()
+          ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
+          : llvm::ConstantInt::get(ivType, 1);
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::InsertPointTy afterIP;
-  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   if (schedule == omp::ClauseScheduleKind::Static) {
     loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
                                                      !loop.nowait(), chunk);