diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -1,10 +1,5 @@ add_subdirectory(Transforms) -set(LLVM_TARGET_DEFINITIONS LLVMAttrDefs.td) -mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls) -mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs) -add_public_tablegen_target(MLIRLLVMAttrsIncGen) - set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMOps.h.inc -gen-op-decls) mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs) @@ -12,6 +7,10 @@ mlir_tablegen(LLVMOpsDialect.cpp.inc -gen-dialect-defs) mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls + -attrdefs-dialect=llvm) +mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs + -attrdefs-dialect=llvm) add_public_tablegen_target(MLIRLLVMOpsIncGen) set(LLVM_TARGET_DEFINITIONS LLVMIntrinsicOps.td) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -15,17 +15,6 @@ // All of the attributes will extend this class. class LLVM_Attr : AttrDef; -// The "FastMath" flags associated with floating point LLVM instructions. -def FastmathFlagsAttr : LLVM_Attr<"FMF"> { - let mnemonic = "fastmath"; - - // List of type parameters. - let parameters = (ins - "FastmathFlags":$flags - ); - let hasCustomAssemblyFormat = 1; -} - // Attribute definition for the LLVM Linkage enum. def LinkageAttr : LLVM_Attr<"Linkage"> { let mnemonic = "linkage"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -13,7 +13,9 @@ #ifndef LLVMIR_OPS #define LLVMIR_OPS +include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/IR/EnumAttr.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" @@ -21,6 +23,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +def FMFnone : I32BitEnumAttrCaseNone<"none">; def FMFnnan : I32BitEnumAttrCaseBit<"nnan", 0>; def FMFninf : I32BitEnumAttrCaseBit<"ninf", 1>; def FMFnsz : I32BitEnumAttrCaseBit<"nsz", 2>; @@ -34,22 +37,18 @@ def FastmathFlags : I32BitEnumAttr< "FastmathFlags", "LLVM fastmath flags", - [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast - ]> { + [FMFnone, FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, + FMFreassoc, FMFfast]> { let separator = ", "; let cppNamespace = "::mlir::LLVM"; + let genSpecializedAttr = 0; let printBitEnumPrimaryGroups = 1; } -def LLVM_FMFAttr : DialectAttr< - LLVM_Dialect, - CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">, - "LLVM fastmath flags"> { - let storageType = "::mlir::LLVM::FMFAttr"; - let returnType = "::mlir::LLVM::FastmathFlags"; - let convertFromStorage = "$_self.getFlags()"; - let constBuilderCall = - "::mlir::LLVM::FMFAttr::get($_builder.getContext(), $0)"; +// The "FastMath" flags associated with floating point LLVM instructions. +def LLVM_FastmathFlagsAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; } def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 1>; @@ -229,7 +228,8 @@ list traits = []> : LLVM_ArithmeticOpBase], traits)> { - dag fmfArg = (ins DefaultValuedAttr:$fastmathFlags); + dag fmfArg = ( + ins DefaultValuedAttr:$fastmathFlags); let arguments = !con(commonArgs, fmfArg); } @@ -239,7 +239,9 @@ LLVM_Op], traits)>, LLVM_Builder<"$res = builder.Create" # instName # "($operand);"> { - let arguments = (ins type:$operand, DefaultValuedAttr:$fastmathFlags); + let arguments = ( + ins type:$operand, + DefaultValuedAttr:$fastmathFlags); let results = (outs type:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$operand custom(attr-dict) `:` type($res)"; @@ -354,7 +356,8 @@ let arguments = (ins FCmpPredicate:$predicate, LLVM_ScalarOrVectorOf:$lhs, LLVM_ScalarOrVectorOf:$rhs, - DefaultValuedAttr:$fastmathFlags); + DefaultValuedAttr:$fastmathFlags); let builders = [ OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)> ]; @@ -747,7 +750,8 @@ let arguments = (ins OptionalAttr:$callee, Variadic, - DefaultValuedAttr:$fastmathFlags); + DefaultValuedAttr:$fastmathFlags); let results = (outs Optional:$result); let builders = [ diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -72,7 +72,7 @@ Value real = complexStruct.real(rewriter, op.getLoc()); Value imag = complexStruct.imaginary(rewriter, op.getLoc()); - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value sqNorm = rewriter.create( loc, rewriter.create(loc, real, real, fmf), rewriter.create(loc, imag, imag, fmf), fmf); @@ -180,7 +180,7 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = @@ -208,7 +208,7 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -253,7 +253,7 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -290,7 +290,7 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -54,7 +54,8 @@ SmallVector filteredAttrs( llvm::make_filter_range(attrs, [&](NamedAttribute attr) { if (attr.getName() == "fastmathFlags") { - auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); + auto defAttr = + FastmathFlagsAttr::get(attr.getValue().getContext(), {}); return defAttr != attr.getValue(); } return true; @@ -2563,7 +2564,7 @@ //===----------------------------------------------------------------------===// void LLVMDialect::initialize() { - addAttributes(); + addAttributes(); // clang-format off addTypeshasTrait(); } -void FMFAttr::print(AsmPrinter &printer) const { - printer << "<"; - printer << stringifyFastmathFlags(this->getFlags()); - printer << ">"; -} - -Attribute FMFAttr::parse(AsmParser &parser, Type type) { - if (failed(parser.parseLess())) - return {}; - - FastmathFlags flags = {}; - if (failed(parser.parseOptionalGreater())) { - auto parseFlags = [&]() -> ParseResult { - StringRef elemName; - if (failed(parser.parseKeyword(&elemName))) - return failure(); - - auto elem = symbolizeFastmathFlags(elemName); - if (!elem) - return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") - << elemName; - - flags = flags | *elem; - return success(); - }; - if (failed(parser.parseCommaSeparatedList(parseFlags)) || - parser.parseGreater()) - return {}; - } - - return FMFAttr::get(parser.getContext(), flags); -} - void LinkageAttr::print(AsmPrinter &printer) const { printer << "<"; if (static_cast(getLinkage()) <= getMaxEnumValForLinkage()) diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -477,12 +477,12 @@ %7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath} : (i32) -> !llvm.struct<(i32, f64, i32)> // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32 - %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : f32 + %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 // CHECK: {{.*}} = llvm.fneg %arg0 : f32 - %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32 + %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : f32 return } diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1666,7 +1666,7 @@ // CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}}) // CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}}) // CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}}) - %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (f32) -> (f32) + %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> (f32) %9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> (f32) %10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> (f32) %11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> (f32)