diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -26,4 +26,26 @@ ); } +// Loop codegen options, stored under the `options` key of `llvm.loop`. +def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> { + let mnemonic = "loopopts"; + + // (option, value) pairs sorted by option; parsed booleans are stored as 0/1. + let parameters = ( + ins + ArrayRefParameter<"std::pair", "">:$options + ); + + let extraClassDeclaration = [{ + Optional disableUnroll(); + Optional disableLICM(); + Optional interleaveCount(); + }]; + + let builders = [ // get() asserts that `options` is sorted by option. + AttrBuilder<(ins "ArrayRef>":$options)> + ]; + let skipDefaultBuilders = 1; +} + #endif // LLVMIR_ATTRDEFS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -49,41 +49,7 @@ namespace detail { struct LLVMTypeStorage; struct LLVMDialectImpl; -struct LoopOptionAttrStorage; } // namespace detail - -/// An attribute that specifies LLVM loop codegen options. -class LoopOptionAttr - : public Attribute::AttrBase { -public: - using Base::Base; - - /// Specifies the llvm.loop.unroll.disable metadata. - static LoopOptionAttr getDisableUnroll(MLIRContext *context, - bool disable = true); - - /// Specifies the llvm.licm.disable metadata. - static LoopOptionAttr getDisableLICM(MLIRContext *context, - bool disable = true); - - /// Specifies the llvm.loop.interleave.count metadata. - static LoopOptionAttr getInterleaveCount(MLIRContext *context, int32_t count); - - /// Returns the loop option, e.g. parallel_access. - LoopOptionCase getCase() const; - - /// Returns if the loop option is activated. Only valid for boolean options. - bool getBool() const; - - /// Returns the integer value associated with a loop option. Only valid for - /// integer options.
- int32_t getInt() const; - - void print(DialectAsmPrinter &p) const; - static Attribute parse(DialectAsmParser &parser); -}; - } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -41,33 +41,6 @@ #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" -namespace mlir { -namespace LLVM { -namespace detail { - -struct LoopOptionAttrStorage : public AttributeStorage { - using KeyTy = std::pair; - - explicit LoopOptionAttrStorage(uint64_t option, int32_t value) - : option(option), value(value) {} - - bool operator==(const KeyTy &key) const { - return key == KeyTy(option, value); - } - - static LoopOptionAttrStorage * - construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) { - return new (allocator.allocate()) - LoopOptionAttrStorage(key.first, key.second); - } - - uint64_t option; - int32_t value; -}; -} // namespace detail -} // namespace LLVM -} // namespace mlir - static auto processFMFAttr(ArrayRef attrs) { SmallVector filteredAttrs( llvm::make_filter_range(attrs, [&](NamedAttribute attr) { @@ -2198,7 +2171,7 @@ //===----------------------------------------------------------------------===// void LLVMDialect::initialize() { - addAttributes(); + addAttributes(); // clang-format off addTypes loopOptions = loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); - if (loopOptions.hasValue()) { - auto options = loopOptions->second.dyn_cast(); - if (!options) - return op->emitOpError() - << "expected '" << LLVMDialect::getLoopOptionsAttrName() - << "' to be an array attribute"; - if (!llvm::all_of(options, [](Attribute option) { - return option.isa(); - })) - return op->emitOpError() << "invalid loop options list " << options; - } + if (loopOptions.hasValue() && !loopOptions->second.isa()) // `options`, when present, must be a loopopts attribute. + return op->emitOpError() + << "expected '" <<
LLVMDialect::getLoopOptionsAttrName() + << "' to be a `loopopts` attribute"; } // If the data layout attribute is present, it must use the LLVM data layout @@ -2427,107 +2393,119 @@ return FMFAttr::get(parser.getBuilder().getContext(), flags); } -LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context, - bool disable) { - auto option = LoopOptionCase::disable_unroll; - return Base::get(context, static_cast<uint64_t>(option), - static_cast<int32_t>(disable)); -} - -LoopOptionAttr LoopOptionAttr::getDisableLICM(MLIRContext *context, - bool disable) { - auto option = LoopOptionCase::disable_licm; - return Base::get(context, static_cast<uint64_t>(option), - static_cast<int32_t>(disable)); +/// Returns the value stored for `option` in `options`, cast to T, or an empty +/// Optional if it is absent. `options` must be sorted by option. +template <typename T> +static Optional<T> +getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, + LoopOptionCase option) { + // lower_bound needs a strict-weak-ordering predicate, and the element it + // returns may belong to a later option, so check for an exact match. + auto it = llvm::lower_bound( + options, option, [](const auto &entry, LoopOptionCase needle) { + return entry.first < needle; + }); + if (it == options.end() || it->first != option) + return {}; + return static_cast<T>(it->second); } -LoopOptionAttr LoopOptionAttr::getInterleaveCount(MLIRContext *context, - int32_t count) { - auto option = LoopOptionCase::interleave_count; - return Base::get(context, static_cast<uint64_t>(option), - static_cast<int32_t>(count)); +Optional<bool> LoopOptionsAttr::disableUnroll() { + return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); } -LoopOptionCase LoopOptionAttr::getCase() const { - return static_cast<LoopOptionCase>(getImpl()->option); +Optional<bool> LoopOptionsAttr::disableLICM() { + return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); } -bool LoopOptionAttr::getBool() const { - LoopOptionCase option = getCase(); - (void)option; - assert(option == LoopOptionCase::disable_licm || - option == LoopOptionCase::disable_unroll && - "expected a boolean loop option"); - return static_cast<bool>(getImpl()->value); +Optional<int64_t> LoopOptionsAttr::interleaveCount() { + return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); } -int32_t LoopOptionAttr::getInt() const { - LoopOptionCase option = getCase(); - (void)option; - assert(option ==
LoopOptionCase::interleave_count && - "expected an integer loop option"); - return getImpl()->value; +/// Builds the attribute; `options` must already be sorted by option. +LoopOptionsAttr +LoopOptionsAttr::get(MLIRContext *context, + ArrayRef<std::pair<LoopOptionCase, int64_t>> options) { + assert(llvm::is_sorted(options, llvm::less_first())); + return Base::get(context, options); } -void LoopOptionAttr::print(DialectAsmPrinter &printer) const { - printer << "loopopt<" << stringifyEnum(getCase()) << " = "; - switch (getCase()) { - case LoopOptionCase::disable_licm: - case LoopOptionCase::disable_unroll: - printer << (getBool() ? "true" : "false"); - break; - case LoopOptionCase::interleave_count: - printer << getInt(); - break; - } +void LoopOptionsAttr::print(DialectAsmPrinter &printer) const { + printer << getMnemonic() << "<"; + llvm::interleaveComma(getOptions(), printer, [&](auto option) { + // Print `name = value` so that parse() can read the output back. + printer << stringifyEnum(option.first) << " = "; + switch (option.first) { + case LoopOptionCase::disable_licm: + case LoopOptionCase::disable_unroll: + printer << (option.second ? "true" : "false"); + break; + case LoopOptionCase::interleave_count: + printer << option.second; + break; + } + }); printer << ">"; } -Attribute LoopOptionAttr::parse(DialectAsmParser &parser) { +Attribute LoopOptionsAttr::parse(MLIRContext *context, DialectAsmParser &parser, + Type type) { if (failed(parser.parseLess())) return {}; - StringRef optionName; - if (failed(parser.parseKeyword(&optionName))) - return {}; - - auto option = symbolizeLoopOptionCase(optionName); - if (!option) { - parser.emitError(parser.getNameLoc(), "unknown loop option: ") - << optionName; - return {}; - } - - if (failed(parser.parseEqual())) - return {}; + SmallVector<std::pair<LoopOptionCase, int64_t>> options; + llvm::SmallDenseSet<LoopOptionCase> seenOptions; + do { + StringRef optionName; + if (failed(parser.parseKeyword(&optionName))) + return {}; - int32_t value; - switch (*option) { - case LoopOptionCase::disable_licm: - case LoopOptionCase::disable_unroll: - if (succeeded(parser.parseOptionalKeyword("true"))) - value = 1; - else if (succeeded(parser.parseOptionalKeyword("false"))) - value = 0; - else { -
parser.emitError(parser.getNameLoc(), - "expected boolean value 'true' or 'false'"); + auto option = symbolizeLoopOptionCase(optionName); + if (!option) { + parser.emitError(parser.getNameLoc(), "unknown loop option: ") + << optionName; return {}; } - break; - case LoopOptionCase::interleave_count: - if (failed(parser.parseInteger(value))) { - parser.emitError(parser.getNameLoc(), "expected integer value"); + if (!seenOptions.insert(*option).second) { + parser.emitError(parser.getNameLoc(), "loop option present twice"); return {}; } - break; - } + if (failed(parser.parseEqual())) + return {}; + int64_t value; + switch (*option) { + case LoopOptionCase::disable_licm: + case LoopOptionCase::disable_unroll: + if (succeeded(parser.parseOptionalKeyword("true"))) + value = 1; + else if (succeeded(parser.parseOptionalKeyword("false"))) + value = 0; + else { + parser.emitError(parser.getNameLoc(), + "expected boolean value 'true' or 'false'"); + return {}; + } + break; + case LoopOptionCase::interleave_count: { + // Parse into an int32_t: the count is lowered to an i32 constant, so + // wider values must be rejected here instead of being truncated later. + int32_t count; + if (failed(parser.parseInteger(count))) { + parser.emitError(parser.getNameLoc(), "expected integer value"); + return {}; + } + value = count; + break; + } + } + options.push_back(std::make_pair(*option, value)); + } while (succeeded(parser.parseOptionalComma())); if (failed(parser.parseGreater())) return {}; - return Base::get(parser.getBuilder().getContext(), - static_cast<uint64_t>(*option), value); + llvm::sort(options, llvm::less_first()); + return get(parser.getBuilder().getContext(), options); } Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser, @@ -2543,9 +2521,6 @@ generatedAttributeParser(getContext(), parser, attrKind, type)) return attr; - if (attrKind == "loopopt") - return LoopOptionAttr::parse(parser); - parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind; return {}; } @@ -2553,7 +2528,7 @@ void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { if (succeeded(generatedAttributePrinter(attr, os))) return; - if (auto lopt = 
attr.dyn_cast()) + if (auto lopt = attr.dyn_cast()) lopt.print(os); else llvm_unreachable("Unknown attribute type"); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -171,26 +171,27 @@ /// Returns an LLVM metadata node corresponding to a loop option. This metadata /// is attached to an llvm.loop node. static llvm::MDNode *getLoopOptionMetadata(llvm::LLVMContext &ctx, - LoopOptionAttr option) { + LoopOptionCase option, // For boolean options any nonzero `value` means true. + int64_t value) { StringRef name; - llvm::Constant *value = nullptr; - switch (option.getCase()) { + llvm::Constant *cstValue = nullptr; + switch (option) { case LoopOptionCase::disable_licm: name = "llvm.licm.disable"; - value = llvm::ConstantInt::getBool(ctx, option.getBool()); + cstValue = llvm::ConstantInt::getBool(ctx, value); break; case LoopOptionCase::disable_unroll: name = "llvm.loop.unroll.disable"; - value = llvm::ConstantInt::getBool(ctx, option.getBool()); + cstValue = llvm::ConstantInt::getBool(ctx, value); break; case LoopOptionCase::interleave_count: name = "llvm.loop.interleave.count"; - value = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, /*NumBits=*/32), - option.getInt()); + cstValue = llvm::ConstantInt::get( // NOTE(review): the int64_t count is emitted as an i32 constant; values outside the i32 range would not be preserved — confirm callers bound it. + llvm::IntegerType::get(ctx, /*NumBits=*/32), value); break; } return llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), - llvm::ConstantAsMetadata::get(value)}); + llvm::ConstantAsMetadata::get(cstValue)}); } static void setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst, @@ -222,13 +223,12 @@ loopOptions.push_back(llvm::MDNode::get(ctx, parallelAccess)); } - auto loopOptionsAttr = - loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); - if (loopOptionsAttr.hasValue()) { - for (LoopOptionAttr loopOption : - loopOptionsAttr->second.cast() - .getAsRange()) - 
loopOptions.push_back(getLoopOptionMetadata(ctx, loopOption)); + if (LoopOptionsAttr loopOptionsAttr = + loopAttr.get(LLVMDialect::getLoopOptionsAttrName()) + .dyn_cast_or_null()) { // Null if `options` is absent or is not a loopopts attribute. + for (auto option : loopOptionsAttr.getOptions()) // One metadata node per (option, value) pair. + loopOptions.push_back( + getLoopOptionMetadata(ctx, option.first, option.second)); } // Create loop options and set the first operand to itself. diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -779,7 +779,7 @@ module { llvm.func @loopOptions() { - // expected-error@below {{expected 'options' to be an array attribute}} + // expected-error@below {{expected 'options' to be a `loopopts` attribute}} llvm.br ^bb4 {llvm.loop = {options = "name"}} ^bb4: llvm.return @@ -790,8 +790,21 @@ module { llvm.func @loopOptions() { - // expected-error@below {{invalid loop options list}} - llvm.br ^bb4 {llvm.loop = {options = ["name"]}} + // expected-error@+2 {{unknown loop option: name}} + // expected-error@below {{Unknown attribute type: loopopts}} + llvm.br ^bb4 {llvm.loop = {options = #llvm.loopopts}} + ^bb4: + llvm.return + } +} + +// ----- + +module { + llvm.func @loopOptions() { + // expected-error@+2 {{loop option present twice}} + // expected-error@below {{Unknown attribute type: loopopts}} + llvm.br ^bb4 {llvm.loop = {options = #llvm.loopopts}} ^bb4: llvm.return } diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1480,13 +1480,13 @@ ^bb3(%1: i32): %2 = llvm.icmp "slt" %1, %arg1 : i32 // CHECK: br i1 {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt, #llvm.loopopt, #llvm.loopopt]}} + llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = 
#llvm.loopopts}} ^bb4: %3 = llvm.add %1, %arg2 : i32 // CHECK: = load i32, i32* %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE:[0-9]+]] %5 = llvm.load %4 { access_groups = [@metadata::@group1, @metadata::@group2] } : !llvm.ptr // CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]] - llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt, #llvm.loopopt, #llvm.loopopt]}} + llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = #llvm.loopopts}} // Must carry the same loop options as the cond_br above: both branches are matched to LOOP_NODE. ^bb5: llvm.return }