diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -26,4 +26,39 @@ ); } +def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> { + let mnemonic = "loopopts"; + + let description = [{ + This attribute encapsulates "loop options". It is meant to decorate + branches that are "latches" (loop backedges) and maps to the `!llvm.loop` + metadata: https://llvm.org/docs/LangRef.html#llvm-loop + It stores the options as (option, value) pairs in an array sorted by option + and exposes APIs to retrieve the value of each option with a stronger type + (bool, for example). + }]; + + // List of attribute parameters. + let parameters = ( + ins + ArrayRefParameter<"std::pair", "">:$options + ); + + let extraClassDeclaration = [{ + using OptionValuePair = std::pair; + using OptionsArray = ArrayRef>; + Optional disableUnroll(); + Optional disableLICM(); + Optional interleaveCount(); + }]; + + let builders = [ + /// Build the LoopOptions Attribute from a sorted array of individual options.
+ AttrBuilder<(ins "ArrayRef>":$sortedOptions)>, + /// Build the LoopOptions Attribute from a builder, sorting its options. + AttrBuilder<(ins "LoopOptionsAttrBuilder &":$optionBuilders)> + ]; + let skipDefaultBuilders = 1; +} + #endif // LLVMIR_ATTRDEFS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -30,8 +30,6 @@ #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc" #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc" -#define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" namespace llvm { class Type; @@ -45,48 +43,18 @@ namespace mlir { namespace LLVM { class LLVMDialect; +class LoopOptionsAttrBuilder; namespace detail { struct LLVMTypeStorage; struct LLVMDialectImpl; -struct LoopOptionAttrStorage; } // namespace detail - -/// An attribute that specifies LLVM loop codegen options. -class LoopOptionAttr - : public Attribute::AttrBase { -public: - using Base::Base; - - /// Specifies the llvm.loop.unroll.disable metadata. - static LoopOptionAttr getDisableUnroll(MLIRContext *context, - bool disable = true); - - /// Specifies the llvm.licm.disable metadata. - static LoopOptionAttr getDisableLICM(MLIRContext *context, - bool disable = true); - - /// Specifies the llvm.loop.interleave.count metadata. - static LoopOptionAttr getInterleaveCount(MLIRContext *context, int32_t count); - - /// Returns the loop option, e.g. parallel_access. - LoopOptionCase getCase() const; - - /// Returns if the loop option is activated. Only valid for boolean options. - bool getBool() const; - - /// Returns the integer value associated with a loop option. Only valid for - /// integer options.
- int32_t getInt() const; - - void print(DialectAsmPrinter &p) const; - static Attribute parse(DialectAsmParser &parser); -}; - } // namespace LLVM } // namespace mlir +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" + ///// Ops ///// #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOps.h.inc" @@ -106,6 +74,44 @@ /// function confirms that the Operation has the desired properties. bool satisfiesLLVMModule(Operation *op); +/// Builder class for LoopOptionsAttr. This helper class allows progressively +/// building a LoopOptionsAttr one option at a time, and paying the price of +/// attribute creation only once all the options are in place. +class LoopOptionsAttrBuilder { +public: + /// Construct an empty builder. + LoopOptionsAttrBuilder() = default; + + /// Construct a builder with an initial list of options from an existing + /// LoopOptionsAttr. + LoopOptionsAttrBuilder(LoopOptionsAttr attr); + + /// Set the `disable_licm` option to the provided value. If no value + /// is provided the option is deleted. + LoopOptionsAttrBuilder &setDisableLICM(Optional value) { + return setOption(LoopOptionCase::disable_licm, value); + } + + /// Set the `interleave_count` option to the provided value. If no value + /// is provided the option is deleted. + LoopOptionsAttrBuilder &setInterleaveCount(Optional count) { + return setOption(LoopOptionCase::interleave_count, count); + } + + /// Set the `disable_unroll` option to the provided value. If no value + /// is provided the option is deleted.
+ LoopOptionsAttrBuilder &setDisableUnroll(Optional value) { + return setOption(LoopOptionCase::disable_unroll, value); + } + +private: + template + LoopOptionsAttrBuilder &setOption(LoopOptionCase tag, Optional value); + + friend class LoopOptionsAttr; + SmallVector options; +}; + } // end namespace LLVM } // end namespace mlir diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -197,6 +197,13 @@ static Optional findDuplicate(SmallVectorImpl &array, bool isSorted); + /// Return the specified attribute if present and is an instance of + /// `AttrClass`, null otherwise. + template + AttrClass getAs(StringRef name) { + return get(name).dyn_cast_or_null(); + } + private: /// Return empty dictionary. static DictionaryAttr getEmpty(MLIRContext *context); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -41,33 +41,6 @@ #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" -namespace mlir { -namespace LLVM { -namespace detail { - -struct LoopOptionAttrStorage : public AttributeStorage { - using KeyTy = std::pair; - - explicit LoopOptionAttrStorage(uint64_t option, int32_t value) - : option(option), value(value) {} - - bool operator==(const KeyTy &key) const { - return key == KeyTy(option, value); - } - - static LoopOptionAttrStorage * - construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) { - return new (allocator.allocate()) - LoopOptionAttrStorage(key.first, key.second); - } - - uint64_t option; - int32_t value; - }; -} // namespace detail -} // namespace LLVM -} // namespace mlir - static auto processFMFAttr(ArrayRef attrs) { SmallVector filteredAttrs( llvm::make_filter_range(attrs, [&](NamedAttribute attr) { @@ -2198,7 +2171,7
@@ //===----------------------------------------------------------------------===// void LLVMDialect::initialize() { - addAttributes(); + addAttributes(); // clang-format off addTypes loopOptions = loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); - if (loopOptions.hasValue()) { - auto options = loopOptions->second.dyn_cast(); - if (!options) - return op->emitOpError() - << "expected '" << LLVMDialect::getLoopOptionsAttrName() - << "' to be an array attribute"; - if (!llvm::all_of(options, [](Attribute option) { - return option.isa(); - })) - return op->emitOpError() << "invalid loop options list " << options; - } + if (loopOptions.hasValue() && !loopOptions->second.isa()) + return op->emitOpError() + << "expected '" << LLVMDialect::getLoopOptionsAttrName() + << "' to be a `loopopts` attribute"; } // If the data layout attribute is present, it must use the LLVM data layout @@ -2427,107 +2393,146 @@ return FMFAttr::get(parser.getBuilder().getContext(), flags); } -LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context, - bool disable) { - auto option = LoopOptionCase::disable_unroll; - return Base::get(context, static_cast(option), - static_cast(disable)); +LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) + : options(attr.getOptions().begin(), attr.getOptions().end()) {} + +/// Set the option `tag` to `value`, or remove it when `value` is empty. New +/// entries are appended unsorted: LoopOptionsAttr::get sorts them before the +/// attribute is created. +template +LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, + Optional value) { + auto option = llvm::find_if( + options, [tag](auto option) { return option.first == tag; }); + if (option != options.end()) { + if (value.hasValue()) + option->second = *value; + else + options.erase(option); + } else if (value.hasValue()) { + options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); + } + return *this; } -LoopOptionAttr LoopOptionAttr::getDisableLICM(MLIRContext 
*context, - bool disable) { - auto option = LoopOptionCase::disable_licm; - return Base::get(context, static_cast(option), - static_cast(disable)); +/// Returns the value of `option` in the sorted `options` array, cast to `T`, +/// or None if the option is not set. +template +static Optional
+getOption(ArrayRef> options, + LoopOptionCase option) { + auto it = + lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { + return optionPair.first < option; + }); + // `it` is the first entry not less than `option`; it may be another key. + if (it == options.end() || it->first != option) + return {}; + return static_cast(it->second); } -LoopOptionAttr LoopOptionAttr::getInterleaveCount(MLIRContext *context, - int32_t count) { - auto option = LoopOptionCase::interleave_count; - return Base::get(context, static_cast(option), - static_cast(count)); +/// Returns the `disable_unroll` option value, or None if it is not set. +Optional LoopOptionsAttr::disableUnroll() { + return getOption(getOptions(), LoopOptionCase::disable_unroll); } -LoopOptionCase LoopOptionAttr::getCase() const { - return static_cast(getImpl()->option); +/// Returns the `disable_licm` option value, or None if it is not set. +Optional LoopOptionsAttr::disableLICM() { + return getOption(getOptions(), LoopOptionCase::disable_licm); } -bool LoopOptionAttr::getBool() const { - LoopOptionCase option = getCase(); - (void)option; - assert(option == LoopOptionCase::disable_licm || - option == LoopOptionCase::disable_unroll && - "expected a boolean loop option"); - return static_cast(getImpl()->value); +/// Returns the `interleave_count` option value, or None if it is not set. +Optional LoopOptionsAttr::interleaveCount() { + return getOption(getOptions(), LoopOptionCase::interleave_count); } -int32_t LoopOptionAttr::getInt() const { - LoopOptionCase option = getCase(); - (void)option; - assert(option == LoopOptionCase::interleave_count && - "expected an integer loop option"); - return getImpl()->value; +/// Build the LoopOptions Attribute from a sorted array of individual options. +LoopOptionsAttr LoopOptionsAttr::get( + MLIRContext *context, + ArrayRef> sortedOptions) { + assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && + "LoopOptionsAttr ctor expects a sorted options array"); + return Base::get(context, 
sortedOptions); } -void LoopOptionAttr::print(DialectAsmPrinter &printer) const { - printer << "loopopt<" << stringifyEnum(getCase()) << " = "; - switch (getCase()) { - case LoopOptionCase::disable_licm: - case LoopOptionCase::disable_unroll: - printer << (getBool() ?
"true" : "false"); - break; - case LoopOptionCase::interleave_count: - printer << getInt(); - break; - } +/// Build the LoopOptions Attribute from a builder, sorting its options first. +LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, + LoopOptionsAttrBuilder &optionBuilders) { + llvm::sort(optionBuilders.options, llvm::less_first()); + return Base::get(context, optionBuilders.options); +} + +void LoopOptionsAttr::print(DialectAsmPrinter &printer) const { + printer << getMnemonic() << "<"; + llvm::interleaveComma(getOptions(), printer, [&](auto option) { + // Print `name = value` so that the output round-trips through `parse`. + printer << stringifyEnum(option.first) << " = "; + switch (option.first) { + case LoopOptionCase::disable_licm: + case LoopOptionCase::disable_unroll: + printer << (option.second ? "true" : "false"); + break; + case LoopOptionCase::interleave_count: + printer << option.second; + break; + } + }); printer << ">"; } -Attribute LoopOptionAttr::parse(DialectAsmParser &parser) { +Attribute LoopOptionsAttr::parse(MLIRContext *context, DialectAsmParser &parser, + Type type) { if (failed(parser.parseLess())) return {}; - StringRef optionName; - if (failed(parser.parseKeyword(&optionName))) - return {}; - - auto option = symbolizeLoopOptionCase(optionName); - if (!option) { - parser.emitError(parser.getNameLoc(), "unknown loop option: ") - << optionName; - return {}; - } - - if (failed(parser.parseEqual())) - return {}; + SmallVector> options; + llvm::SmallDenseSet seenOptions; + do { + StringRef optionName; + if (parser.parseKeyword(&optionName)) + return {}; - int32_t value; - switch (*option) { - case LoopOptionCase::disable_licm: - case LoopOptionCase::disable_unroll: - if (succeeded(parser.parseOptionalKeyword("true"))) - value = 1; - else if (succeeded(parser.parseOptionalKeyword("false"))) - value = 0; - else { - parser.emitError(parser.getNameLoc(), - "expected boolean value 'true' or 'false'"); + auto option = symbolizeLoopOptionCase(optionName); + if 
(!option) { + parser.emitError(parser.getNameLoc(), "unknown loop option: ") + << optionName; return {};
} - break; - case LoopOptionCase::interleave_count: - if (failed(parser.parseInteger(value))) { - parser.emitError(parser.getNameLoc(), "expected integer value"); + if (!seenOptions.insert(*option).second) { + parser.emitError(parser.getNameLoc(), "loop option present twice"); return {}; } - break; - } + if (failed(parser.parseEqual())) + return {}; + int64_t value; + switch (*option) { + case LoopOptionCase::disable_licm: + case LoopOptionCase::disable_unroll: + if (succeeded(parser.parseOptionalKeyword("true"))) + value = 1; + else if (succeeded(parser.parseOptionalKeyword("false"))) + value = 0; + else { + parser.emitError(parser.getNameLoc(), + "expected boolean value 'true' or 'false'"); + return {}; + } + break; + case LoopOptionCase::interleave_count: + if (failed(parser.parseInteger(value))) { + parser.emitError(parser.getNameLoc(), "expected integer value"); + return {}; + } + break; + } + options.push_back(std::make_pair(*option, value)); + } while (succeeded(parser.parseOptionalComma())); if (failed(parser.parseGreater())) return {}; - return Base::get(parser.getBuilder().getContext(), - static_cast(*option), value); + llvm::sort(options, llvm::less_first()); + return get(parser.getBuilder().getContext(), options); } Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser, @@ -2543,9 +2548,6 @@ generatedAttributeParser(getContext(), parser, attrKind, type)) return attr; - if (attrKind == "loopopt") - return LoopOptionAttr::parse(parser); - parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind; return {}; } @@ -2553,8 +2555,5 @@ void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { if (succeeded(generatedAttributePrinter(attr, os))) return; - if (auto lopt = attr.dyn_cast()) - lopt.print(os); - else - llvm_unreachable("Unknown attribute type"); + llvm_unreachable("Unknown attribute type"); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -171,26 +171,27 @@ /// Returns an LLVM metadata node corresponding to a loop option. This metadata /// is attached to an llvm.loop node. static llvm::MDNode *getLoopOptionMetadata(llvm::LLVMContext &ctx, - LoopOptionAttr option) { + LoopOptionCase option, + int64_t value) { StringRef name; - llvm::Constant *value = nullptr; - switch (option.getCase()) { + llvm::Constant *cstValue = nullptr; + switch (option) { case LoopOptionCase::disable_licm: name = "llvm.licm.disable"; - value = llvm::ConstantInt::getBool(ctx, option.getBool()); + cstValue = llvm::ConstantInt::getBool(ctx, value); break; case LoopOptionCase::disable_unroll: name = "llvm.loop.unroll.disable"; - value = llvm::ConstantInt::getBool(ctx, option.getBool()); + cstValue = llvm::ConstantInt::getBool(ctx, value); break; case LoopOptionCase::interleave_count: name = "llvm.loop.interleave.count"; - value = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, /*NumBits=*/32), - option.getInt()); + cstValue = llvm::ConstantInt::get( + llvm::IntegerType::get(ctx, /*NumBits=*/32), value); break; } return llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), - llvm::ConstantAsMetadata::get(value)}); + llvm::ConstantAsMetadata::get(cstValue)}); } static void setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst, @@ -222,13 +223,11 @@ loopOptions.push_back(llvm::MDNode::get(ctx, parallelAccess)); } - auto loopOptionsAttr = - loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); - if (loopOptionsAttr.hasValue()) { - for (LoopOptionAttr loopOption : - loopOptionsAttr->second.cast() - .getAsRange()) - loopOptions.push_back(getLoopOptionMetadata(ctx, loopOption)); + if (auto loopOptionsAttr = loopAttr.getAs( + LLVMDialect::getLoopOptionsAttrName())) { + for (auto option :
loopOptionsAttr.getOptions()) + loopOptions.push_back( + getLoopOptionMetadata(ctx, option.first, option.second)); } // Create loop options and set the first operand to itself. diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -779,7 +779,7 @@ module { llvm.func @loopOptions() { - // expected-error@below {{expected 'options' to be an array attribute}} + // expected-error@below {{expected 'options' to be a `loopopts` attribute}} llvm.br ^bb4 {llvm.loop = {options = "name"}} ^bb4: llvm.return @@ -790,8 +790,21 @@ module { llvm.func @loopOptions() { - // expected-error@below {{invalid loop options list}} - llvm.br ^bb4 {llvm.loop = {options = ["name"]}} + // expected-error@+2 {{unknown loop option: name}} + // expected-error@below {{Unknown attribute type: loopopts}} + llvm.br ^bb4 {llvm.loop = {options = #llvm.loopopts}} + ^bb4: + llvm.return + } +} + +// ----- + +module { + llvm.func @loopOptions() { + // expected-error@+2 {{loop option present twice}} + // expected-error@below {{Unknown attribute type: loopopts}} + llvm.br ^bb4 {llvm.loop = {options = #llvm.loopopts}} ^bb4: llvm.return } diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1480,13 +1480,13 @@ ^bb3(%1: i32): %2 = llvm.icmp "slt" %1, %arg1 : i32 // CHECK: br i1 {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt, #llvm.loopopt, #llvm.loopopt]}} + llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = #llvm.loopopts}} ^bb4: %3 = llvm.add %1, %arg2 : i32 // CHECK: = load i32, i32* %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE:[0-9]+]] %5 = llvm.load %4 { access_groups = [@metadata::@group1,
@metadata::@group2] } : !llvm.ptr // CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]] - llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt, #llvm.loopopt, #llvm.loopopt]}} + llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = #llvm.loopopts}} ^bb5: llvm.return }