diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3710,7 +3710,10 @@
     target.addLegalDialect();
 
     // required NOPs for applying a full conversion
-    target.addLegalOp();
+    target.addDynamicallyLegalOp([](mlir::ModuleOp op) {
+      // TODO Check for other attributes that need lowering
+      return !op->hasAttr("omp.requires");
+    });
 
     // If we're on Windows, we might need to rename some libm calls.
     bool isMSVC = fir::getTargetTriple(mod).isOSMSVCRT();
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -175,6 +175,111 @@
     return success();
   }
 };
+
+/// Rewrite pattern that processes `builtin.module` operations if they contain
+/// the `omp.requires` attribute, removing it afterwards to mark the operation
+/// as already-processed.
+///
+/// For host modules, the `requires` clauses are forwarded to the OpenMP
+/// offloading runtime through a module constructor that calls
+/// `__tgt_register_requires` with the corresponding flags. Device modules only
+/// get the attribute removed.
+struct ModuleRequiresAttrConversion : public ConvertOpToLLVMPattern<ModuleOp> {
+  using ConvertOpToLLVMPattern<ModuleOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(ModuleOp curOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Nothing to do without the attribute. Report a match failure instead of
+    // returning success without having changed the IR.
+    if (!curOp->hasAttr("omp.requires"))
+      return failure();
+
+    auto offloadMod = cast<omp::OffloadModuleInterface>(curOp.getOperation());
+    if (!offloadMod.getIsDevice())
+      emitRequiresRegistration(curOp, offloadMod, rewriter);
+
+    // Removing the attribute is the only in-place change to the module op.
+    rewriter.updateRootInPlace(curOp.getOperation(),
+                               [&] { curOp->removeAttr("omp.requires"); });
+    return success();
+  }
+
+private:
+  /// Inserts at the start of the host module `curOp` the declaration of
+  /// `__tgt_register_requires(i64)`, the definition of
+  /// `.omp_offloading.requires_reg()` calling it with the flags derived from
+  /// the module's `requires` clauses, and the registration of that function as
+  /// a global constructor with priority 0.
+  void emitRequiresRegistration(ModuleOp curOp,
+                                omp::OffloadModuleInterface offloadMod,
+                                ConversionPatternRewriter &rewriter) const {
+    // Ad-hoc definition of 'requires' flags expected by the OpenMP runtime.
+    enum : int64_t {
+      NONE = 0x001,
+      REVERSE_OFFLOAD = 0x002,
+      UNIFIED_ADDRESS = 0x004,
+      UNIFIED_SHARED_MEMORY = 0x008,
+      DYNAMIC_ALLOCATORS = 0x010
+    };
+
+    // Populate flags from `omp.requires` attribute.
+    int64_t requiresFlags = 0;
+    auto requiresClauses = offloadMod.getRequires();
+    if (bitEnumContainsAll(requiresClauses,
+                           omp::ClauseRequires::reverse_offload))
+      requiresFlags |= REVERSE_OFFLOAD;
+    if (bitEnumContainsAll(requiresClauses,
+                           omp::ClauseRequires::unified_address))
+      requiresFlags |= UNIFIED_ADDRESS;
+    if (bitEnumContainsAll(requiresClauses,
+                           omp::ClauseRequires::unified_shared_memory))
+      requiresFlags |= UNIFIED_SHARED_MEMORY;
+    if (bitEnumContainsAll(requiresClauses,
+                           omp::ClauseRequires::dynamic_allocators))
+      requiresFlags |= DYNAMIC_ALLOCATORS;
+
+    // The runtime reserves 0 for "undefined", so a clause-less `requires` is
+    // reported as NONE.
+    if (requiresFlags == 0)
+      requiresFlags = NONE;
+
+    MLIRContext *ctx = getContext();
+    Location loc = curOp->getLoc();
+    auto voidType = LLVM::LLVMVoidType::get(ctx);
+    auto i64Type = IntegerType::get(ctx, 64);
+
+    // Create every operation through the conversion rewriter rather than a
+    // detached OpBuilder, so the conversion driver is notified of them.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(curOp.getBody());
+
+    // Declaration of `__tgt_register_requires(i64)`; the runtime entry point
+    // takes a 64-bit flags argument.
+    auto regFn = rewriter.create<LLVM::LLVMFuncOp>(
+        loc, "__tgt_register_requires",
+        LLVM::LLVMFunctionType::get(voidType, {i64Type}),
+        LLVM::linkage::Linkage::External);
+
+    // Definition of `.omp_offloading.requires_reg()`.
+    auto ctorFn = rewriter.create<LLVM::LLVMFuncOp>(
+        loc, ".omp_offloading.requires_reg",
+        LLVM::LLVMFunctionType::get(voidType, {}),
+        LLVM::linkage::Linkage::Internal);
+    rewriter.setInsertionPointToStart(ctorFn.addEntryBlock());
+    auto flagsConstOp = rewriter.create<LLVM::ConstantOp>(
+        loc, i64Type, rewriter.getIntegerAttr(i64Type, requiresFlags));
+    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(regFn),
+                                  flagsConstOp.getResult());
+    rewriter.create<LLVM::ReturnOp>(loc, nullptr);
+
+    // Registration of `.omp_offloading.requires_reg()` as constructor.
+    rewriter.setInsertionPointAfter(ctorFn);
+    auto ctors = rewriter.getArrayAttr(
+        {FlatSymbolRefAttr::get(ctx, ctorFn.getSymName())});
+    auto priorities = rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)});
+    rewriter.create<LLVM::GlobalCtorsOp>(loc, ctors, priorities);
+  }
+};
 } // namespace
 
 void mlir::configureOpenMPToLLVMConversionLegality(
@@ -227,7 +332,8 @@
       RegionLessOpWithVarOperandsConversion,
       RegionLessOpConversion,
       RegionLessOpConversion,
-      RegionLessOpConversion>(converter);
+      RegionLessOpConversion, ModuleRequiresAttrConversion>(
+      converter);
 }
 
 namespace {