diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -48,6 +48,7 @@ struct LLVMTypeStorage; struct LLVMDialectImpl; struct BitmaskEnumStorage; +struct LoopOptionAttrStorage; } // namespace detail /// An attribute that specifies LLVM instruction fastmath flags. @@ -64,6 +65,35 @@ static Attribute parse(DialectAsmParser &parser); }; +/// An attribute that specifies LLVM loop codegen options. +class LoopOptionAttr + : public Attribute::AttrBase<LoopOptionAttr, Attribute, + detail::LoopOptionAttrStorage> { +public: + using Base::Base; + + /// Factory methods; each creates an attribute holding one loop option. + static LoopOptionAttr getParallelAccess(MLIRContext *context, + bool parallel = true); + static LoopOptionAttr getDisableUnroll(MLIRContext *context, + bool disable = true); + static LoopOptionAttr getDisableLICM(MLIRContext *context, + bool disable = true); + static LoopOptionAttr getInterleaveCount(MLIRContext *context, int32_t count); + + /// Returns which loop option this attribute holds. + LoopOptionCase getCase() const; + /// Returns the value of a boolean option (parallel_access, disable_unroll or + /// disable_licm). + bool getBool() const; + /// Returns the value of an integer option (interleave_count). + int32_t getInt() const; + + /// Prints and parses the `loopopt<option = value>` attribute syntax. + void print(DialectAsmPrinter &p) const; + static Attribute parse(DialectAsmParser &parser); +}; + } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -32,6 +32,11 @@ static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; } static StringRef getAlignAttrName() { return "llvm.align"; } static StringRef getNoAliasAttrName() { return "llvm.noalias"; } + /// Name of the attribute, read from the module, that maps loop names to + /// arrays of loop options. + static StringRef getLoopOptionsAttrName() { return "llvm.loops"; } + /// Name of the branch attribute that refers to an entry of `llvm.loops`. + static StringRef getLoopAttrName() { return "llvm.loop"; } /// Verifies if the given string is a well-formed data layout descriptor. /// Uses `reportError` to report errors.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -47,6 +47,19 @@ "::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())"; } +def LOptParallelAccess : I32EnumAttrCase<"parallel_access", 1>; +def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 2>; +def LOptDisableLICM : I32EnumAttrCase<"disable_licm", 3>; +def LOptInterleaveCount : I32EnumAttrCase<"interleave_count", 4>; + +def LoopOptionCase : I32EnumAttr< + "LoopOptionCase", + "LLVM loop option", + [LOptParallelAccess, LOptDisableUnroll, LOptDisableLICM, LOptInterleaveCount + ]> { + let cppNamespace = "::mlir::LLVM"; +} + class LLVM_Builder<string builder> { string llvmBuilder = builder; } diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -69,6 +69,8 @@ return nullptr; if (failed(translator.convertGlobals())) return nullptr; + if (failed(translator.createLoopMetadata())) + return nullptr; if (failed(translator.convertFunctions())) return nullptr; @@ -136,6 +138,11 @@ return branchMapping.lookup(op); } + /// Finds the LLVM metadata corresponding to a loop; null if there is none. + llvm::MDNode *lookupLoopMetadata(StringRef loop) const { + return loopMetadataMapping.lookup(loop); + } + /// Converts the type from MLIR LLVM dialect to LLVM. llvm::Type *convertType(Type type); @@ -205,6 +212,13 @@ LogicalResult convertGlobals(); LogicalResult convertOneFunction(LLVMFuncOp func); + /// Creates LLVM metadata for loops annotated in this module if the module + /// contains an `llvm.loops` attribute. Each element of this attribute is a + /// NamedAttribute and describes the attributes for a distinct loop.
The first + /// element of the NamedAttribute is a string attribute identifying the loop + /// and the second is an array attribute holding the loop options. + LogicalResult createLoopMetadata(); + /// Translates dialect attributes attached to the given operation. LogicalResult convertDialectAttributes(Operation *op); @@ -241,6 +255,11 @@ /// they are converted to. This allows for connecting PHI nodes to the source /// values after all operations are converted. DenseMap<Operation *, llvm::Instruction *> branchMapping; + + /// Mapping from a named loop to its LLVM metadata. This map is populated on + /// module entry and used to annotate the branches whose `llvm.loop` + /// attribute names the loop. + llvm::StringMap<llvm::MDNode *> loopMetadataMapping; }; namespace detail { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -56,6 +56,26 @@ KeyTy value = 0; }; + +struct LoopOptionAttrStorage : public AttributeStorage { + using KeyTy = std::pair<uint64_t, int32_t>; + + explicit LoopOptionAttrStorage(uint64_t option, int32_t value) + : option(option), value(value) {} + + bool operator==(const KeyTy &key) const { + return key == KeyTy(option, value); + } + + static LoopOptionAttrStorage * + construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate<LoopOptionAttrStorage>()) + LoopOptionAttrStorage(key.first, key.second); + } + + uint64_t option; + int32_t value; +}; } // namespace detail } // namespace LLVM } // namespace mlir @@ -2079,7 +2099,7 @@ //===----------------------------------------------------------------------===// void LLVMDialect::initialize() { - addAttributes<FMFAttr>(); + addAttributes<FMFAttr, LoopOptionAttr>(); // clang-format off addTypes(); + if (!dictAttr) + return op->emitOpError() << "expected '" << getLoopOptionsAttrName() + << "' to be a dictionary attribute"; + for (auto keyValue : dictAttr) { + auto loopOptions = keyValue.second.dyn_cast<ArrayAttr>(); + if
(!loopOptions) + return op->emitOpError() + << "expected loop options for loop '" << keyValue.first + << "' in attribute '" << getLoopOptionsAttrName() + << "' to be an array attribute"; + for (auto loopOption : loopOptions) { + if (!loopOption.dyn_cast<LoopOptionAttr>()) + return op->emitOpError() + << "expected loop option for loop '" << keyValue.first + << "' in attribute '" << getLoopOptionsAttrName() + << "' to be a loop-option attribute"; + } + } + return success(); + } + // If the data layout attribute is present, it must use the LLVM data layout // syntax. Try parsing it and report errors in case of failure. Users of this // attribute may assume it is well-formed and can pass it to the (asserting) @@ -2264,6 +2307,115 @@ return FMFAttr::get(flags, parser.getBuilder().getContext()); } +LoopOptionAttr LoopOptionAttr::getParallelAccess(MLIRContext *context, + bool parallel) { + auto option = LoopOptionCase::parallel_access; + return Base::get(context, static_cast<uint64_t>(option), + static_cast<int32_t>(parallel)); +} + +LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context, + bool disable) { + auto option = LoopOptionCase::disable_unroll; + return Base::get(context, static_cast<uint64_t>(option), + static_cast<int32_t>(disable)); +} + +LoopOptionAttr LoopOptionAttr::getDisableLICM(MLIRContext *context, + bool disable) { + auto option = LoopOptionCase::disable_licm; + return Base::get(context, static_cast<uint64_t>(option), + static_cast<int32_t>(disable)); +} + +LoopOptionAttr LoopOptionAttr::getInterleaveCount(MLIRContext *context, + int32_t count) { + auto option = LoopOptionCase::interleave_count; + return Base::get(context, static_cast<uint64_t>(option), + static_cast<int32_t>(count)); +} + +LoopOptionCase LoopOptionAttr::getCase() const { + return static_cast<LoopOptionCase>(getImpl()->option); +} + +bool LoopOptionAttr::getBool() const { + LoopOptionCase option = getCase(); + (void)option; + assert(option == LoopOptionCase::parallel_access || + option == LoopOptionCase::disable_licm || + option == LoopOptionCase::disable_unroll); + return
static_cast<bool>(getImpl()->value); +} + +int32_t LoopOptionAttr::getInt() const { + LoopOptionCase option = getCase(); + (void)option; + assert(option == LoopOptionCase::interleave_count); + return getImpl()->value; +} + +void LoopOptionAttr::print(DialectAsmPrinter &printer) const { + printer << "loopopt<" << stringifyEnum(getCase()) << " = "; + switch (getCase()) { + case LoopOptionCase::parallel_access: + case LoopOptionCase::disable_licm: + case LoopOptionCase::disable_unroll: + printer << (getBool() ? "true" : "false"); + break; + case LoopOptionCase::interleave_count: + printer << getInt(); + break; + } + printer << ">"; +} + +Attribute LoopOptionAttr::parse(DialectAsmParser &parser) { + if (failed(parser.parseLess())) + return {}; + + StringRef optionName; + if (failed(parser.parseKeyword(&optionName))) + return {}; + + auto option = symbolizeLoopOptionCase(optionName); + if (!option) { + parser.emitError(parser.getNameLoc(), "Unknown loop option: ") + << optionName; + return {}; + } + + if (failed(parser.parseEqual())) + return {}; + + int32_t value = 0; + switch (*option) { + case LoopOptionCase::parallel_access: + case LoopOptionCase::disable_licm: + case LoopOptionCase::disable_unroll: + if (succeeded(parser.parseOptionalKeyword("true"))) { + value = 1; + } else if (succeeded(parser.parseOptionalKeyword("false"))) { + value = 0; + } else { + parser.emitError(parser.getCurrentLocation(), + "expected 'true' or 'false'"); + return {}; + } + break; + case LoopOptionCase::interleave_count: + if (failed(parser.parseInteger(value))) + return {}; + break; + } + + if (failed(parser.parseGreater())) + return {}; + + return Base::get(parser.getBuilder().getContext(), + static_cast<uint64_t>(*option), value); +} + Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser, Type type) const { if (type) { @@ -2277,14 +2429,18 @@ if (attrKind == "fastmath") return FMFAttr::parse(parser); - parser.emitError(parser.getNameLoc(), "Unknown attrribute type: ") - << attrKind; + if (attrKind == "loopopt") + return LoopOptionAttr::parse(parser); + +
parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind; return {}; } void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { if (auto fmf = attr.dyn_cast<FMFAttr>()) fmf.print(os); + else if (auto lopt = attr.dyn_cast<LoopOptionAttr>()) + lopt.print(os); else llvm_unreachable("Unknown attribute type"); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -167,6 +167,28 @@ return ret; } +/// Attaches to `llvmInst` the metadata of the loop named by the `llvm.loop` +/// attribute of `opInst`, if that attribute is present. Emits an error if the +/// attribute is not a string or names a loop without registered metadata. +static LogicalResult +setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst, + LLVM::ModuleTranslation &moduleTranslation) { + Attribute attr = opInst.getAttr(LLVMDialect::getLoopAttrName()); + if (!attr) + return success(); + auto loopName = attr.dyn_cast<StringAttr>(); + if (!loopName) + return opInst.emitError() << "expected '" << LLVMDialect::getLoopAttrName() + << "' to be a string attribute"; + llvm::MDNode *loopMD = + moduleTranslation.lookupLoopMetadata(loopName.getValue()); + if (!loopMD) + return opInst.emitError() << "no loop metadata found for loop '" + << loopName.getValue() << "'"; + llvmInst.setMetadata(llvm::LLVMContext::MD_loop, loopMD); + return success(); +} + static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -294,7 +316,7 @@ llvm::BranchInst *branch = builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor())); moduleTranslation.mapBranch(&opInst, branch); - return success(); + return setLoopMetadata(opInst, *branch, moduleTranslation); } if (auto condbrOp = dyn_cast(opInst)) { auto weights = condbrOp.branch_weights(); @@ -315,7 +337,7 @@ moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)), moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights); moduleTranslation.mapBranch(&opInst, branch); - return success(); + return setLoopMetadata(opInst, *branch, moduleTranslation); } if (auto switchOp =
dyn_cast(opInst)) { llvm::MDNode *branchWeights = nullptr; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -624,6 +624,85 @@ return success(); } +/// Returns the `!{!"<name>", <value>}` metadata node describing `option`. +/// The parallel-access option is handled separately by the caller. +static llvm::MDNode *getLoopOptionMetadata(llvm::LLVMContext &ctx, + LoopOptionAttr option) { + StringRef name; + llvm::Constant *value = nullptr; + switch (option.getCase()) { + case LoopOptionCase::parallel_access: + llvm_unreachable("incorrectly handled loop option: parallel_access"); + case LoopOptionCase::disable_licm: + name = "llvm.licm.disable"; + value = llvm::ConstantInt::getBool(ctx, option.getBool()); + break; + case LoopOptionCase::disable_unroll: + name = "llvm.loop.unroll.disable"; + value = llvm::ConstantInt::getBool(ctx, option.getBool()); + break; + case LoopOptionCase::interleave_count: + name = "llvm.loop.interleave.count"; + value = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, /*NumBits=*/32), + option.getInt()); + break; + } + return llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), + llvm::ConstantAsMetadata::get(value)}); +} + +LogicalResult ModuleTranslation::createLoopMetadata() { + auto loopsAttr = mlirModule->getAttrOfType<DictionaryAttr>( + LLVMDialect::getLoopOptionsAttrName()); + if (!loopsAttr) + return success(); + + llvm::LLVMContext &ctx = llvmModule->getContext(); + for (auto nameAndMetadata : loopsAttr) { + StringRef loopName = nameAndMetadata.first.strref(); + auto loopOptionsAttr = nameAndMetadata.second.dyn_cast<ArrayAttr>(); + if (!loopOptionsAttr) + return mlirModule->emitError() + << "expected loop options for loop '" << loopName + << "' to be an array attribute"; + + llvm::SmallVector<llvm::Metadata *, 4> loopOptions; + + // Reserve operand 0 for the loop id self reference. + auto dummy = llvm::MDNode::getTemporary(ctx, llvm::None); + loopOptions.push_back(dummy.get()); +
+ // Handle the parallel-access option as a special case: it refers to a + // fresh, distinct access group instead of carrying a constant value. + auto options = loopOptionsAttr.getAsRange<LoopOptionAttr>(); + bool isParallelAccess = llvm::any_of(options, [](LoopOptionAttr option) { + return option.getCase() == LoopOptionCase::parallel_access && + option.getBool(); + }); + if (isParallelAccess) { + llvm::MDNode *accessGroup = llvm::MDNode::getDistinct(ctx, {}); + llvm::MDNode *parallelAccess = llvm::MDNode::get( + ctx, {llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"), + accessGroup}); + loopOptions.push_back(parallelAccess); + } + + for (LoopOptionAttr loopOption : options) + if (loopOption.getCase() != LoopOptionCase::parallel_access) + loopOptions.push_back(getLoopOptionMetadata(ctx, loopOption)); + + // Create the distinct loop id and make its first operand refer to itself. + llvm::MDNode *metadata = llvm::MDNode::getDistinct(ctx, loopOptions); + metadata->replaceOperandWith(0, metadata); + + // Store the metadata for later reference; it is attached to the branches + // whose `llvm.loop` attribute names this loop. + loopMetadataMapping.insert({loopName, metadata}); + } + + return success(); +} + llvm::Type *ModuleTranslation::convertType(Type type) { return typeTranslator.translateType(type); } diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -419,3 +419,7 @@ %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32 return } + +// CHECK: module attributes {llvm.loops = {loop1 = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>, #llvm.loopopt<parallel_access = true>]}} +module attributes {llvm.loops = {loop1 = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>, #llvm.loopopt<parallel_access = true>]}} { +} diff --git a/mlir/test/Target/llvmir-invalid.mlir b/mlir/test/Target/llvmir-invalid.mlir --- a/mlir/test/Target/llvmir-invalid.mlir +++ b/mlir/test/Target/llvmir-invalid.mlir @@ -64,3 +64,21 @@ // expected-error @+1 {{expected arrays within 'passthrough' to contain two strings}} llvm.func @passthrough_wrong_type() attributes {passthrough = [[42, 42]]} + +//
----- + +// expected-error @+1 {{expected 'llvm.loops' to be a dictionary attribute}} +module attributes {llvm.loops = []} { +} + +// ----- + +// expected-error @+1 {{expected loop options for loop 'loop1' in attribute 'llvm.loops' to be an array attribute}} +module attributes {llvm.loops = {loop1 = "test"}} { +} + +// ----- + +// expected-error @+1 {{expected loop option for loop 'loop1' in attribute 'llvm.loops' to be a loop-option attribute}} +module attributes {llvm.loops = {loop1 = ["test"]}} { +} diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1469,3 +1469,29 @@ } // CHECK: ![[SWITCH_WEIGHT_NODE]] = !{!"branch_weights", i32 13, i32 17, i32 19} + +// ----- + +module attributes {llvm.loops = {loop1 = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>, #llvm.loopopt<parallel_access = true>]}} { + llvm.func @loopOptions(%arg1 : i32, %arg2 : i32) { + %0 = llvm.mlir.constant(0 : i32) : i32 + llvm.br ^bb3(%0 : i32) + ^bb3(%1: i32): + %2 = llvm.icmp "slt" %1, %arg1 : i32 + // CHECK: br i1 {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] + llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = "loop1"} + ^bb4: + %3 = llvm.add %1, %arg2 : i32 + // CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]] + llvm.br ^bb3(%3 : i32) {llvm.loop = "loop1"} + ^bb5: + llvm.return + } +} + +// CHECK: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], ![[PA_NODE:[0-9]+]], ![[UNROLL_DISABLE_NODE:[0-9]+]], ![[LICM_DISABLE_NODE:[0-9]+]], ![[INTERLEAVE_NODE:[0-9]+]]} +// CHECK: ![[PA_NODE]] = !{!"llvm.loop.parallel_accesses", ![[GROUP_NODE:[0-9]+]]} +// CHECK: ![[GROUP_NODE]] = distinct !{} +// CHECK: ![[UNROLL_DISABLE_NODE]] = !{!"llvm.loop.unroll.disable", i1 true} +// CHECK: ![[LICM_DISABLE_NODE]] = !{!"llvm.licm.disable", i1 true} +// CHECK: ![[INTERLEAVE_NODE]] = !{!"llvm.loop.interleave.count", i32 1}