Index: flang/lib/Frontend/FrontendActions.cpp
===================================================================
--- flang/lib/Frontend/FrontendActions.cpp
+++ flang/lib/Frontend/FrontendActions.cpp
@@ -233,14 +233,20 @@
   // Fetch module from lb, so we can set
   mlirModule = std::make_unique<mlir::ModuleOp>(lb.getModule());
 
+  // Set up the target machine first: the OpenMP module attributes below read
+  // the target CPU and the target feature string from it.
+  if (!setUpTargetMachine())
+    return false;
+
   if (ci.getInvocation().getFrontendOpts().features.IsEnabled(
           Fortran::common::LanguageFeature::OpenMP)) {
     mlir::omp::OpenMPDialect::setIsDevice(
         *mlirModule, ci.getInvocation().getLangOpts().OpenMPIsDevice);
+    mlir::omp::OpenMPDialect::setTargetCpu(*mlirModule, tm->getTargetCPU());
+    mlir::omp::OpenMPDialect::setTargetCpuFeatures(
+        *mlirModule, tm->getTargetFeatureString());
   }
 
-  if (!setUpTargetMachine())
-    return false;
   const llvm::DataLayout &dl = tm->createDataLayout();
   setMLIRDataLayout(*mlirModule, dl);
 
Index: flang/test/Lower/OpenMP/target_cpu_features.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/target_cpu_features.f90
@@ -0,0 +1,16 @@
+!REQUIRES: amdgpu-registered-target
+!RUN: %flang_fc1 -emit-fir -triple amdgcn-amd-amdhsa -target-cpu gfx908 -fopenmp %s -o - | FileCheck %s
+
+!===============================================================================
+! Target CPU and target CPU features module attributes
+!===============================================================================
+
+!CHECK: omp.target_cpu = "gfx908",
+!CHECK-SAME: omp.target_cpu_features = "+dot3-insts,+dot4-insts,+s-memtime-inst,
+!CHECK-SAME: +16-bit-insts,+s-memrealtime,+dot6-insts,+dl-insts,+wavefrontsize64,
+!CHECK-SAME: +gfx9-insts,+gfx8-insts,+ci-insts,+dot10-insts,+dot7-insts,
+!CHECK-SAME: +dot1-insts,+dot5-insts,+mai-insts,+dpp,+dot2-insts"
+!CHECK-LABEL: func.func @_QPomp_target_simple() {
+subroutine omp_target_simple
+end subroutine omp_target_simple
+
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1931,6 +1931,13 @@
   /// \param Name Name of the variable.
   GlobalVariable *getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
                                               unsigned AddressSpace = 0);
+
+  /// Add function attribute \p AttributeName with value \p AttributeValue to
+  /// every function defined in the module that does not already have it.
+  /// \param AttributeName Name of the attribute; nothing is added if empty.
+  /// \param AttributeValue Value of the attribute; nothing is added if empty.
+  void addAttributeToModuleFunctions(StringRef AttributeName,
+                                     StringRef AttributeValue);
 };
 
 /// Data structure to contain the information needed to uniquely identify
Index: llvm/include/llvm/Transforms/Utils/CodeExtractor.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -173,6 +173,13 @@
                                  const Function &NewFunc,
                                  AssumptionCache *AC);
 
+    /// Copy the function attributes of \p oldFunction onto \p newFunction,
+    /// keeping all target-dependent attributes and the allow-listed
+    /// target-independent ones (e.g. an extracted region that calls an x86.sse
+    /// instruction needs "target-features" on the new function to be lowered).
+    static void inheritTargetDependentAttributes(const Function *oldFunction,
+                                                 Function *newFunction);
+
     /// Test whether this code extractor is eligible.
     ///
     /// Based on the blocks used when constructing the code extractor,
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1405,6 +1405,7 @@
       (Twine(OutlinedFn.getName()) + ".wrapper").str(),
       FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
   Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
+  CodeExtractor::inheritTargetDependentAttributes(&OutlinedFn, WrapperFunc);
   PointerType *WrapperFuncBitcastType =
       FunctionType::get(Builder.getInt32Ty(),
                         {Builder.getInt32Ty(), Builder.getInt8PtrTy()}, false)
@@ -5041,6 +5042,19 @@
   }
 }
 
+void OpenMPIRBuilder::addAttributeToModuleFunctions(StringRef AttributeName,
+                                                    StringRef AttributeValue) {
+  if (AttributeName.empty() || AttributeValue.empty())
+    return;
+  for (Function &F : M.functions()) {
+    if (F.isDeclaration())
+      continue;
+    if (F.hasFnAttribute(AttributeName))
+      continue;
+    F.addFnAttr(AttributeName, AttributeValue);
+  }
+}
+
 void TargetRegionEntryInfo::getTargetRegionEntryFnName(
     SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
     unsigned FileID, unsigned Line, unsigned Count) {
Index: llvm/lib/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -807,91 +807,8 @@
   }
 }
 
-/// constructFunction - make a function based on inputs and outputs, as follows:
-/// f(in0, ..., inN, out0, ..., outN)
-Function *CodeExtractor::constructFunction(const ValueSet &inputs,
-                                           const ValueSet &outputs,
-                                           BasicBlock *header,
-                                           BasicBlock *newRootNode,
-                                           BasicBlock *newHeader,
-                                           Function *oldFunction,
-                                           Module *M) {
-  LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
-  LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
-
-  // This function returns unsigned, outputs will go back by reference.
-  switch (NumExitBlocks) {
-  case 0:
-  case 1: RetTy = Type::getVoidTy(header->getContext()); break;
-  case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
-  default: RetTy = Type::getInt16Ty(header->getContext()); break;
-  }
-
-  std::vector<Type *> ParamTy;
-  std::vector<Type *> AggParamTy;
-  ValueSet StructValues;
-  const DataLayout &DL = M->getDataLayout();
-
-  // Add the types of the input values to the function's argument list
-  for (Value *value : inputs) {
-    LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
-    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
-      AggParamTy.push_back(value->getType());
-      StructValues.insert(value);
-    } else
-      ParamTy.push_back(value->getType());
-  }
-
-  // Add the types of the output values to the function's argument list.
-  for (Value *output : outputs) {
-    LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
-    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
-      AggParamTy.push_back(output->getType());
-      StructValues.insert(output);
-    } else
-      ParamTy.push_back(
-          PointerType::get(output->getType(), DL.getAllocaAddrSpace()));
-  }
-
-  assert(
-      (ParamTy.size() + AggParamTy.size()) ==
-          (inputs.size() + outputs.size()) &&
-      "Number of scalar and aggregate params does not match inputs, outputs");
-  assert((StructValues.empty() || AggregateArgs) &&
-         "Expeced StructValues only with AggregateArgs set");
-
-  // Concatenate scalar and aggregate params in ParamTy.
-  size_t NumScalarParams = ParamTy.size();
-  StructType *StructTy = nullptr;
-  if (AggregateArgs && !AggParamTy.empty()) {
-    StructTy = StructType::get(M->getContext(), AggParamTy);
-    ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace()));
-  }
-
-  LLVM_DEBUG({
-    dbgs() << "Function type: " << *RetTy << " f(";
-    for (Type *i : ParamTy)
-      dbgs() << *i << ", ";
-    dbgs() << ")\n";
-  });
-
-  FunctionType *funcType = FunctionType::get(
-      RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
-
-  std::string SuffixToUse =
-      Suffix.empty()
-          ? (header->getName().empty() ? "extracted" : header->getName().str())
-          : Suffix;
-  // Create the new function
-  Function *newFunction = Function::Create(
-      funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
-      oldFunction->getName() + "." + SuffixToUse, M);
-
-  // Inherit all of the target dependent attributes and white-listed
-  // target independent attributes.
-  //  (e.g. If the extracted region contains a call to an x86.sse
-  //  instruction we need to make sure that the extracted region has the
-  //  "target-features" attribute allowing it to be lowered.
+void CodeExtractor::inheritTargetDependentAttributes(
+    const Function *oldFunction, Function *newFunction) {
   // FIXME: This should be changed to check to see if a specific
   //           attribute can not be inherited.
for (const auto &Attr : oldFunction->getAttributes().getFnAttrs()) { @@ -997,9 +914,94 @@ case Attribute::TombstoneKey: llvm_unreachable("Not a function attribute"); } - newFunction->addFnAttr(Attr); } +} +/// constructFunction - make a function based on inputs and outputs, as follows: +/// f(in0, ..., inN, out0, ..., outN) +Function *CodeExtractor::constructFunction(const ValueSet &inputs, + const ValueSet &outputs, + BasicBlock *header, + BasicBlock *newRootNode, + BasicBlock *newHeader, + Function *oldFunction, Module *M) { + LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); + LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); + // This function returns unsigned, outputs will go back by reference. + switch (NumExitBlocks) { + case 0: + case 1: + RetTy = Type::getVoidTy(header->getContext()); + break; + case 2: + RetTy = Type::getInt1Ty(header->getContext()); + break; + default: + RetTy = Type::getInt16Ty(header->getContext()); + break; + } + + std::vector ParamTy; + std::vector AggParamTy; + ValueSet StructValues; + const DataLayout &DL = M->getDataLayout(); + + // Add the types of the input values to the function's argument list + for (Value *value : inputs) { + LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n"); + if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) { + AggParamTy.push_back(value->getType()); + StructValues.insert(value); + } else + ParamTy.push_back(value->getType()); + } + + // Add the types of the output values to the function's argument list. 
+  for (Value *output : outputs) {
+    LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+      AggParamTy.push_back(output->getType());
+      StructValues.insert(output);
+    } else
+      ParamTy.push_back(
+          PointerType::get(output->getType(), DL.getAllocaAddrSpace()));
+  }
+
+  assert(
+      (ParamTy.size() + AggParamTy.size()) ==
+          (inputs.size() + outputs.size()) &&
+      "Number of scalar and aggregate params does not match inputs, outputs");
+  assert((StructValues.empty() || AggregateArgs) &&
+         "Expeced StructValues only with AggregateArgs set");
+
+  // Concatenate scalar and aggregate params in ParamTy.
+  size_t NumScalarParams = ParamTy.size();
+  StructType *StructTy = nullptr;
+  if (AggregateArgs && !AggParamTy.empty()) {
+    StructTy = StructType::get(M->getContext(), AggParamTy);
+    ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace()));
+  }
+
+  LLVM_DEBUG({
+    dbgs() << "Function type: " << *RetTy << " f(";
+    for (Type *i : ParamTy)
+      dbgs() << *i << ", ";
+    dbgs() << ")\n";
+  });
+
+  FunctionType *funcType = FunctionType::get(
+      RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
+
+  std::string SuffixToUse =
+      Suffix.empty()
+          ? (header->getName().empty() ? "extracted" : header->getName().str())
+          : Suffix;
+  // Create the new function
+  Function *newFunction = Function::Create(
+      funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
+      oldFunction->getName() + "." + SuffixToUse, M);
+
+  inheritTargetDependentAttributes(oldFunction, newFunction);
   newFunction->insert(newFunction->end(), newRootNode);
 
   // Create scalar and aggregate iterators to name all of the arguments we
Index: mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
===================================================================
--- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -36,6 +36,21 @@
     // Return the value of the omp.is_device attribute stored in the module if it
     // exists, otherwise return false by default
     static bool getIsDevice(Operation* module);
+
+    // Set the omp.target_cpu attribute on the module with the specified string
+    static void setTargetCpu(Operation* module, StringRef cpu);
+
+    // Return the value of the omp.target_cpu attribute stored in the module if it
+    // exists, otherwise return an empty string
+    static std::string getTargetCpu(Operation* module);
+
+    // Set the omp.target_cpu_features attribute on the module with
+    // the specified string
+    static void setTargetCpuFeatures(Operation* module, StringRef cpuFeatures);
+
+    // Return the value of the omp.target_cpu_features attribute stored in
+    // the module if it exists, otherwise return an empty string
+    static std::string getTargetCpuFeatures(Operation* module);
   }];
 }
 
Index: mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
===================================================================
--- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1437,6 +1437,40 @@
   return false;
 }
 
+// Set the omp.target_cpu attribute on the module with the specified string
+void OpenMPDialect::setTargetCpu(Operation *module, llvm::StringRef cpu) {
+  module->setAttr(mlir::StringAttr::get(module->getContext(),
+                                        llvm::Twine{"omp.target_cpu"}),
+                  mlir::StringAttr::get(module->getContext(), cpu));
+}
+
+// Return the value of the omp.target_cpu attribute stored in the module if it
+// exists, otherwise return an empty string
+std::string OpenMPDialect::getTargetCpu(Operation *module) {
+  if (auto targetCpu =
+          module->getAttrOfType<mlir::StringAttr>("omp.target_cpu"))
+    return targetCpu.getValue().str();
+  return {};
+}
+
+// Set the omp.target_cpu_features attribute on the module with
+// the specified string
+void OpenMPDialect::setTargetCpuFeatures(Operation *module,
+                                         llvm::StringRef cpuFeatures) {
+  module->setAttr(mlir::StringAttr::get(module->getContext(),
+                                        llvm::Twine{"omp.target_cpu_features"}),
+                  mlir::StringAttr::get(module->getContext(), cpuFeatures));
+}
+
+// Return the value of the omp.target_cpu_features attribute stored in the
+// module if it exists, otherwise return an empty string
+std::string OpenMPDialect::getTargetCpuFeatures(Operation *module) {
+  if (auto targetCpuFeatures =
+          module->getAttrOfType<mlir::StringAttr>("omp.target_cpu_features"))
+    return targetCpuFeatures.getValue().str();
+  return {};
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
Index: mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1556,10 +1556,54 @@
   LogicalResult
   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final;
+  LogicalResult
+  amendOperation(Operation *op, NamedAttribute attribute,
+                 LLVM::ModuleTranslation &moduleTranslation) const final;
 };
 
 } // namespace
 
+/// Translate the OpenMP module attribute omp.target_cpu (resp.
+/// omp.target_cpu_features) into the LLVM function attribute "target-cpu"
+/// (resp. "target-features") on every function defined in the LLVM module
+/// that does not already have it. Other attribute names are ignored.
+static void
+trySetLLVMFunctionTargetAttr(NamedAttribute namedAttr, StringAttr valueAttr,
+                             LLVM::ModuleTranslation &moduleTranslation) {
+  StringRef llvmFuncAttrName;
+  if (namedAttr.getName() == "omp.target_cpu")
+    llvmFuncAttrName = "target-cpu";
+  else if (namedAttr.getName() == "omp.target_cpu_features")
+    llvmFuncAttrName = "target-features";
+  else
+    return;
+
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  ompBuilder->addAttributeToModuleFunctions(llvmFuncAttrName,
+                                            valueAttr.getValue());
+}
+
+/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
+/// calls, or operation amendments.
+LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
+    Operation *op, NamedAttribute attribute,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+  return llvm::TypeSwitch<Attribute, LogicalResult>(attribute.getValue())
+      .Case([&](mlir::StringAttr attr) {
+        // String attributes named omp.target_cpu or omp.target_cpu_features
+        // become LLVM function attributes; any other string attribute is
+        // accepted unchanged.
+        trySetLLVMFunctionTargetAttr(attribute, attr, moduleTranslation);
+        return success();
+      })
+      .Default([](Attribute) {
+        // Fall through for omp attributes that do not require lowering and/or
+        // have no concrete definition and thus no type to define a case on,
+        // e.g. omp.is_device.
+        return success();
+      });
+}
+
 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
 /// (including OpenMP runtime calls).
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
Index: mlir/test/Target/LLVMIR/openmp-llvm.mlir
===================================================================
--- mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2454,3 +2454,38 @@
   }
   llvm.return
 }
+
+// -----
+
+// CHECK: @omp_target_features_test() #0 {
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.target_cpu = "gfx908", omp.target_cpu_features = "+dot3-insts,+dot4-insts,+s-memtime-inst,+16-bit-insts,+s-memrealtime,+dot6-insts,+dl-insts,+wavefrontsize64,+gfx9-insts,+gfx8-insts,+ci-insts,+dot10-insts,+dot7-insts,+dot1-insts,+dot5-insts,+mai-insts,+dpp,+dot2-insts"} {
+  llvm.func @omp_target_features_test() {
+    llvm.return
+  }
+}
+
+// CHECK: attributes #0 = { "target-cpu"="gfx908"
+// CHECK-SAME: "target-features"="+dot3-insts,+dot4-insts,+s-memtime-inst,
+// CHECK-SAME: +16-bit-insts,+s-memrealtime,+dot6-insts,+dl-insts,
+// CHECK-SAME: +wavefrontsize64,+gfx9-insts,+gfx8-insts,+ci-insts,+dot10-insts,
+// CHECK-SAME: +dot7-insts,+dot1-insts,+dot5-insts,+mai-insts,+dpp,+dot2-insts"
+
+// -----
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.target_cpu = "gfx908", omp.target_cpu_features = "+dot3-insts,+dot4-insts,+s-memtime-inst,+16-bit-insts,+s-memrealtime,+dot6-insts,+dl-insts,+wavefrontsize64,+gfx9-insts,+gfx8-insts,+ci-insts,+dot10-insts,+dot7-insts,+dot1-insts,+dot5-insts,+mai-insts,+dpp,+dot2-insts"} {
+// CHECK: @test_omp_attr() #0 {
+// CHECK: @test_omp_attr..omp_par() #0 {
+// CHECK: @test_omp_attr..omp_par.wrapper(i32 %{{.*}}) #0 {
+
+  llvm.func @test_omp_attr() {
+    omp.task {
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+// CHECK: attributes #0 = { "target-cpu"="gfx908"
+// CHECK-SAME: "target-features"="+dot3-insts,+dot4-insts,+s-memtime-inst,
+// CHECK-SAME: +16-bit-insts,+s-memrealtime,+dot6-insts,+dl-insts,
+// CHECK-SAME: +wavefrontsize64,+gfx9-insts,+gfx8-insts,+ci-insts,+dot10-insts,
+// CHECK-SAME: +dot7-insts,+dot1-insts,+dot5-insts,+mai-insts,+dpp,+dot2-insts"