diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -45,6 +45,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -1274,23 +1275,41 @@ if (!ompBuilder) { ompBuilder = std::make_unique(*llvmModule); + llvm::StringRef targetTriple = ""; bool isDevice = false; llvm::StringRef hostIRFilePath = ""; - if (Attribute deviceAttr = mlirModule->getAttr("omp.is_device")) - if (::llvm::isa(deviceAttr)) - isDevice = ::llvm::dyn_cast(deviceAttr).getValue(); + if (auto targetAttr = mlirModule->getAttrOfType( + mlir::LLVM::LLVMDialect::getTargetTripleAttrName())) + targetTriple = targetAttr.getValue(); - if (Attribute filepath = mlirModule->getAttr("omp.host_ir_filepath")) - if (::llvm::isa(filepath)) - hostIRFilePath = - ::llvm::dyn_cast(filepath).getValue(); + if (auto deviceAttr = + mlirModule->getAttrOfType("omp.is_device")) + isDevice = deviceAttr.getValue(); + + if (auto filepathAttr = + mlirModule->getAttrOfType("omp.host_ir_filepath")) + hostIRFilePath = filepathAttr.getValue(); ompBuilder->initialize(hostIRFilePath); + bool isTargetCodegen; + switch (llvm::Triple(targetTriple).getArch()) { + case llvm::Triple::nvptx: + case llvm::Triple::nvptx64: + case llvm::Triple::amdgcn: + assert(isDevice && + "OpenMP AMDGPU/NVPTX is only prepared to deal with device code."); + isTargetCodegen = true; + break; + default: + isTargetCodegen = false; + break; + } + // TODO: set the flags when available llvm::OpenMPIRBuilderConfig config( - isDevice, /* IsTargetCodegen */ false, + isDevice, isTargetCodegen, /* HasRequiresUnifiedSharedMemory */ false, /* OpenMPOffloadMandatory */ false); ompBuilder->setConfig(config);