diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -252,25 +252,12 @@ if (loop.lowerBound().empty()) return failure(); - if (loop.getNumLoops() != 1) - return opInst.emitOpError("collapsed loops not yet supported"); - // Static is the default. omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static; if (loop.schedule_val().hasValue()) schedule = *omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()); - // Find the loop configuration. - llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]); - llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]); - llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]); - llvm::Type *ivType = step->getType(); - llvm::Value *chunk = - loop.schedule_chunk_var() - ? moduleTranslation.lookupValue(loop.schedule_chunk_var()) - : llvm::ConstantInt::get(ivType, 1); - // Set up the source location value for OpenMP runtime. llvm::DISubprogram *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram(); @@ -279,22 +266,29 @@ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(), llvm::DebugLoc(diLoc)); - // Generator of the canonical loop body. Produces an SESE region of basic - // blocks. + // Generator of the canonical loop body. // TODO: support error propagation in OpenMPIRBuilder and use it instead of // relying on captured variables. + SmallVector loopInfos; + SmallVector bodyInsertPoints; LogicalResult bodyGenStatus = success(); auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) { - llvm::IRBuilder<>::InsertPointGuard guard(builder); - // Make sure further conversions know about the induction variable. 
- moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv); + moduleTranslation.mapValue( + loop.getRegion().front().getArgument(loopInfos.size()), iv); + + // Capture the body insertion point for use in nested loops. BodyIP of the + // CanonicalLoopInfo always points to the beginning of the entry block of + // the body. + bodyInsertPoints.push_back(ip); + + if (loopInfos.size() != loop.getNumLoops() - 1) + return; + // Convert the body of the loop. llvm::BasicBlock *entryBlock = ip.getBlock(); llvm::BasicBlock *exitBlock = entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit"); - - // Convert the body of the loop. convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock, *exitBlock, builder, moduleTranslation, bodyGenStatus); }; @@ -303,21 +297,49 @@ // TODO: this currently assumes WsLoop is semantically similar to SCF loop, // i.e. it has a positive step, uses signed integer semantics. Reconsider // this code when WsLoop clearly supports more cases. + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) { + llvm::Value *lowerBound = + moduleTranslation.lookupValue(loop.lowerBound()[i]); + llvm::Value *upperBound = + moduleTranslation.lookupValue(loop.upperBound()[i]); + llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]); + + // Make sure loop trip counts are emitted in the preheader of the outermost + // loop at the latest so that they are all available for the new collapsed + // loop that will be created below. 
+ llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc; + llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP; + if (i != 0) { + loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(), + llvm::DebugLoc(diLoc)); + computeIP = loopInfos.front()->getPreheaderIP(); + } + loopInfos.push_back(ompBuilder->createCanonicalLoop( + loc, bodyGen, lowerBound, upperBound, step, + /*IsSigned=*/true, loop.inclusive(), computeIP)); + + if (failed(bodyGenStatus)) + return failure(); + } + + // Collapse loops. Store the insertion point because LoopInfos may get + // invalidated. + llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP(); llvm::CanonicalLoopInfo *loopInfo = - moduleTranslation.getOpenMPBuilder()->createCanonicalLoop( - ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true, - /*InclusiveStop=*/loop.inclusive()); - if (failed(bodyGenStatus)) - return failure(); + ompBuilder->collapseLoops(diLoc, loopInfos, {}); + // Find the loop configuration. + llvm::Type *ivType = loopInfo->getIndVar()->getType(); + llvm::Value *chunk = + loop.schedule_chunk_var() + ? 
moduleTranslation.lookupValue(loop.schedule_chunk_var()) + : llvm::ConstantInt::get(ivType, 1); llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); - llvm::OpenMPIRBuilder::InsertPointTy afterIP; - llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); if (schedule == omp::ClauseScheduleKind::Static) { - loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP, - !loop.nowait(), chunk); - afterIP = loopInfo->getAfterIP(); + ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP, + !loop.nowait(), chunk); } else { llvm::omp::OMPScheduleType schedType; switch (schedule) { @@ -338,11 +360,14 @@ break; } - afterIP = ompBuilder->createDynamicWorkshareLoop( - ompLoc, loopInfo, allocaIP, schedType, !loop.nowait(), chunk); + ompBuilder->createDynamicWorkshareLoop(ompLoc, loopInfo, allocaIP, + schedType, !loop.nowait(), chunk); } - // Continue building IR after the loop. + // Continue building IR after the loop. Note that the LoopInfo returned by + // `collapseLoops` points inside the outermost loop and is intended for + // potential further loop transformations. Use the insertion point stored + // before collapsing loops instead. builder.restoreIP(afterIP); return success(); } diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -467,6 +467,8 @@ llvm.return } +// ----- + // CHECK-LABEL: @omp_critical llvm.func @omp_critical(%x : !llvm.ptr, %xval : i32) -> () { // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0) @@ -488,6 +490,65 @@ omp.terminator } // CHECK: call void @__kmpc_end_critical({{.*}}critical_user_mutex.var{{.*}}) + llvm.return +} + +// ----- +// Check that the loop bounds are emitted in the correct location in case of +// collapse. 
This only checks the overall shape of the IR, detailed checking +// is done by the OpenMPIRBuilder. + +// CHECK-LABEL: @collapse_wsloop +// CHECK: i32* noalias %[[TIDADDR:[0-9A-Za-z.]*]] +// CHECK: load i32, i32* %[[TIDADDR]] +// CHECK: store +// CHECK: load +// CHECK: %[[LB0:.*]] = load i32 +// CHECK: %[[UB0:.*]] = load i32 +// CHECK: %[[STEP0:.*]] = load i32 +// CHECK: %[[LB1:.*]] = load i32 +// CHECK: %[[UB1:.*]] = load i32 +// CHECK: %[[STEP1:.*]] = load i32 +// CHECK: %[[LB2:.*]] = load i32 +// CHECK: %[[UB2:.*]] = load i32 +// CHECK: %[[STEP2:.*]] = load i32 +llvm.func @collapse_wsloop( + %0: i32, %1: i32, %2: i32, + %3: i32, %4: i32, %5: i32, + %6: i32, %7: i32, %8: i32, + %20: !llvm.ptr) { + omp.parallel { + // CHECK: icmp slt i32 %[[LB0]], 0 + // CHECK-COUNT-4: select + // CHECK: %[[TRIPCOUNT0:.*]] = select + // CHECK: br label %[[PREHEADER:.*]] + // + // CHECK: [[PREHEADER]]: + // CHECK: icmp slt i32 %[[LB1]], 0 + // CHECK-COUNT-4: select + // CHECK: %[[TRIPCOUNT1:.*]] = select + // CHECK: icmp slt i32 %[[LB2]], 0 + // CHECK-COUNT-4: select + // CHECK: %[[TRIPCOUNT2:.*]] = select + // CHECK: %[[PROD:.*]] = mul nuw i32 %[[TRIPCOUNT0]], %[[TRIPCOUNT1]] + // CHECK: %[[TOTAL:.*]] = mul nuw i32 %[[PROD]], %[[TRIPCOUNT2]] + // CHECK: br label %[[COLLAPSED_PREHEADER:.*]] + // + // CHECK: [[COLLAPSED_PREHEADER]]: + // CHECK: store i32 0, i32* + // CHECK: %[[TOTAL_SUB_1:.*]] = sub i32 %[[TOTAL]], 1 + // CHECK: store i32 %[[TOTAL_SUB_1]], i32* + // CHECK: call void @__kmpc_for_static_init_4u + omp.wsloop (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) { + %31 = llvm.load %20 : !llvm.ptr + %32 = llvm.add %31, %arg0 : i32 + %33 = llvm.add %32, %arg1 : i32 + %34 = llvm.add %33, %arg2 : i32 + llvm.store %34, %20 : !llvm.ptr + omp.yield + } + omp.terminator + } llvm.return }