diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -1,5 +1,10 @@ add_subdirectory(Transforms) +set(LLVM_TARGET_DEFINITIONS LLVMAttrDefs.td) +mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRLLVMAttrsIncGen) + set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMOps.h.inc -gen-op-decls) mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -0,0 +1,29 @@ +//===-- LLVMAttrDefs.td - LLVM Attributes definition file --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVMIR_ATTRDEFS +#define LLVMIR_ATTRDEFS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + + +// All of the attributes will extend this class. +class LLVM_Attr : AttrDef; + +// The "FastMath" flags associated with floating point LLVM instructions. +def FastmathFlagsAttr : LLVM_Attr<"FMF"> { + let mnemonic = "fastmath"; + + // List of type parameters. + let parameters = ( + ins + "FastmathFlags":$flags + ); +} + +#endif // LLVMIR_ATTRDEFS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -30,6 +30,8 @@ #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc" #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" namespace llvm { class Type; @@ -47,24 +49,9 @@ namespace detail { struct LLVMTypeStorage; struct LLVMDialectImpl; -struct BitmaskEnumStorage; struct LoopOptionAttrStorage; } // namespace detail -/// An attribute that specifies LLVM instruction fastmath flags. -class FMFAttr : public Attribute::AttrBase { -public: - using Base::Base; - - static FMFAttr get(FastmathFlags flags, MLIRContext *context); - - FastmathFlags getFlags() const; - - void print(DialectAsmPrinter &p) const; - static Attribute parse(DialectAsmParser &parser); -}; - /// An attribute that specifies LLVM loop codegen options. class LoopOptionAttr : public Attribute::AttrBase; @@ -249,7 +249,7 @@ [{ build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1), $_builder.getI64IntegerAttr(static_cast(predicate)), lhs, rhs, - ::mlir::LLVM::FMFAttr::get(fmf, $_builder.getContext())); + ::mlir::LLVM::FMFAttr::get($_builder.getContext(), fmf)); }]>]; let parser = [{ return parseCmpOp(parser, result); }]; let printer = [{ printFCmpOp(p, *this); }]; diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -30,7 +30,7 @@ Value real = complexStruct.real(rewriter, op.getLoc()); Value imag = complexStruct.imaginary(rewriter, op.getLoc()); - auto fmf = LLVM::FMFAttr::get({}, op.getContext()); + auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); Value sqNorm = rewriter.create( loc, rewriter.create(loc, real, real, fmf), rewriter.create(loc, imag, imag, fmf), fmf); @@ -133,7 +133,7 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get({}, op.getContext()); + auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = @@ -161,7 +161,7 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get({}, op.getContext()); + auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -206,7 +206,7 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get({}, op.getContext()); + auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -243,7 +243,7 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. - auto fmf = LLVM::FMFAttr::get({}, op.getContext()); + auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -829,7 +829,7 @@ operation, dstType, rewriter.getI64IntegerAttr(static_cast(predicate)), operation.operand1(), operation.operand2(), - LLVM::FMFAttr::get({}, operation.getContext())); + LLVM::FMFAttr::get(operation.getContext(), {})); return success(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2981,7 +2981,7 @@ ConversionPatternRewriter &rewriter) const override { CmpFOpAdaptor transformed(operands); - auto fmf = LLVM::FMFAttr::get({}, cmpfOp.getContext()); + auto fmf = LLVM::FMFAttr::get(cmpfOp.getContext(), {}); rewriter.replaceOpWithNewOp( cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/MLIRContext.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/AsmParser/Parser.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" @@ -37,25 +38,12 @@ #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" namespace mlir { namespace LLVM { namespace detail { -struct BitmaskEnumStorage : public AttributeStorage { - using KeyTy = uint64_t; - - BitmaskEnumStorage(KeyTy val) : value(val) {} - - bool operator==(const KeyTy &key) const { return value == key; } - - static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator, - const KeyTy &key) { - return new (allocator.allocate()) - BitmaskEnumStorage(key); - } - - KeyTy value = 0; -}; struct LoopOptionAttrStorage : public AttributeStorage { using KeyTy = std::pair; @@ -84,7 +72,7 @@ SmallVector filteredAttrs( llvm::make_filter_range(attrs, [&](NamedAttribute attr) { if (attr.first == "fastmathFlags") { - auto defAttr = FMFAttr::get({}, attr.second.getContext()); + auto defAttr = FMFAttr::get(attr.second.getContext(), {}); return defAttr != attr.second; } return true; @@ -2387,14 +2375,6 @@ op->hasTrait(); } -FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) { - return Base::get(context, static_cast(flags)); -} - -FastmathFlags FMFAttr::getFlags() const { - return static_cast(getImpl()->value); -} - static constexpr const FastmathFlags FastmathFlagsList[] = { // clang-format off FastmathFlags::nnan, @@ -2418,7 +2398,8 @@ printer << ">"; } -Attribute FMFAttr::parse(DialectAsmParser &parser) { +Attribute FMFAttr::parse(MLIRContext *context, DialectAsmParser &parser, + Type type) { if (failed(parser.parseLess())) return {}; @@ -2443,7 +2424,7 @@ return {}; } - return FMFAttr::get(flags, parser.getBuilder().getContext()); + return FMFAttr::get(parser.getBuilder().getContext(), flags); } LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context, @@ -2558,9 +2539,9 @@ StringRef attrKind; if (parser.parseKeyword(&attrKind)) return {}; - - if (attrKind == "fastmath") - return FMFAttr::parse(parser); + if (auto attr = + generatedAttributeParser(getContext(), parser, attrKind, type)) + return attr; if (attrKind == "loopopt") return LoopOptionAttr::parse(parser); @@ -2570,9 +2551,9 @@ } void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { - if (auto fmf = attr.dyn_cast()) - fmf.print(os); - else if (auto lopt = attr.dyn_cast()) + if (succeeded(generatedAttributePrinter(attr, os))) + return; + if (auto lopt = attr.dyn_cast()) lopt.print(os); else llvm_unreachable("Unknown attribute type");