diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -222,16 +222,6 @@ schedule = *omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()); - // Find the loop configuration. - llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]); - llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]); - llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]); - llvm::Type *ivType = step->getType(); - llvm::Value *chunk = - loop.schedule_chunk_var() - ? moduleTranslation.lookupValue(loop.schedule_chunk_var()) - : llvm::ConstantInt::get(ivType, 1); - // Set up the source location value for OpenMP runtime. llvm::DISubprogram *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram(); @@ -264,17 +254,40 @@ // TODO: this currently assumes WsLoop is semantically similar to SCF loop, // i.e. it has a positive step, uses signed integer semantics. Reconsider // this code when WsLoop clearly supports more cases. + auto doNothingBodyGen = [](llvm::OpenMPIRBuilder::InsertPointTy, + llvm::Value *) {}; + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + SmallVector loopInfos; + for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) { + llvm::Value *lowerBound = + moduleTranslation.lookupValue(loop.lowerBound()[i]); + llvm::Value *upperBound = + moduleTranslation.lookupValue(loop.upperBound()[i]); + llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]); + loopInfos.push_back(ompBuilder->createCanonicalLoop( + i == 0 ? ompLoc + : llvm::OpenMPIRBuilder::LocationDescription( + loopInfos.back()->getBodyIP(), llvm::DebugLoc(diLoc)), + doNothingBodyGen, lowerBound, upperBound, step, + /*IsSigned=*/true, loop.inclusive())); + } + + // Collapse loops. llvm::CanonicalLoopInfo *loopInfo = - moduleTranslation.getOpenMPBuilder()->createCanonicalLoop( - ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true, - /*InclusiveStop=*/loop.inclusive()); + ompBuilder->collapseLoops(diLoc, loopInfos, {}); + bodyGen(loopInfo->getBodyIP(), loopInfo->getIndVar()); if (failed(bodyGenStatus)) return failure(); + // Find the loop configuration. + llvm::Type *ivType = loopInfo->getIndVar()->getType(); + llvm::Value *chunk = + loop.schedule_chunk_var() + ? moduleTranslation.lookupValue(loop.schedule_chunk_var()) + : llvm::ConstantInt::get(ivType, 1); llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::InsertPointTy afterIP; - llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); if (schedule == omp::ClauseScheduleKind::Static) { loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP, !loop.nowait(), chunk);