Index: mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -10,6 +10,8 @@ add_mlir_doc(LLVMOps -gen-op-doc LLVMOps Dialects/) +add_mlir_interface(LLVMOpsInterfaces) + set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions) Index: mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -29,6 +29,7 @@ #include "llvm/IR/Type.h" #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc" +#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc" namespace llvm { class Type; @@ -46,8 +47,23 @@ namespace detail { struct LLVMTypeStorage; struct LLVMDialectImpl; +struct BitmaskEnumStorage; } // namespace detail +/// An attribute that specifies LLVM instruction fastmath flags +class FMFAttr : public Attribute::AttrBase { +public: + using Base::Base; + + static FMFAttr get(FastmathFlags flags, MLIRContext *context); + + FastmathFlags getFlags() const; + + void print(DialectAsmPrinter &p) const; + static Attribute parse(DialectAsmParser &parser); +}; + } // namespace LLVM } // namespace mlir Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -14,10 +14,39 @@ #define LLVMIR_OPS include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +def FMFnnan : BitEnumAttrCase<"nnan", 0x1>; +def FMFninf : BitEnumAttrCase<"ninf", 0x2>; +def FMFnsz : BitEnumAttrCase<"nsz", 0x4>; +def FMFarcp : BitEnumAttrCase<"arcp", 0x8>; +def FMFcontract : BitEnumAttrCase<"contract", 0x10>; +def FMFafn : BitEnumAttrCase<"afn", 0x20>; +def FMFreassoc : BitEnumAttrCase<"reassoc", 0x40>; +def FMFfast : BitEnumAttrCase<"fast", 0x80>; + +def FastmathFlags : BitEnumAttr< + "FastmathFlags", + "LLVM fastmath flags", + [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast + ]> { + let cppNamespace = "::mlir::LLVM"; +} + +def LLVM_FMFAttr : DialectAttr< + LLVM_Dialect, + CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">, + "LLVM fastmath flags"> { + let storageType = "::mlir::LLVM::FMFAttr"; + let returnType = "::mlir::LLVM::FastmathFlags"; + let convertFromStorage = "$_self.getFlags()"; + let constBuilderCall = + "::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())"; +} + class LLVM_Builder { string llvmBuilder = builder; } @@ -78,29 +107,38 @@ LLVM_Op, LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> { - let arguments = (ins LLVM_ScalarOrVectorOf:$lhs, - LLVM_ScalarOrVectorOf:$rhs); + dag commonArgs = (ins LLVM_ScalarOrVectorOf:$lhs, + LLVM_ScalarOrVectorOf:$rhs); let results = (outs LLVM_ScalarOrVectorOf:$res); let builders = [LLVM_OneResultOpBuilder]; - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($res)"; } class LLVM_IntArithmeticOp traits = []> : - LLVM_ArithmeticOpBase; + LLVM_ArithmeticOpBase { + let arguments = commonArgs; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($res)"; +} class LLVM_FloatArithmeticOp traits = []> : - LLVM_ArithmeticOpBase; + LLVM_ArithmeticOpBase], traits)> { + dag fmfArg = (ins DefaultValuedAttr:$fastmathFlags); + let arguments = !con(commonArgs, fmfArg); + let parser = [{ return parseFloadBinOp(parser, result); }]; + let printer = [{ printFloadBinOp(p, *this); }]; +} // Class for arithmetic unary operations. -class LLVM_UnaryArithmeticOp traits = []> : LLVM_Op, + !listconcat([NoSideEffect, SameOperandsAndResultType, DeclareOpInterfaceMethods], traits)>, LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> { - let arguments = (ins type:$operand); + let arguments = (ins type:$operand, DefaultValuedAttr:$fastmathFlags); let results = (outs type:$res); let builders = [LLVM_OneResultOpBuilder]; - let assemblyFormat = "$operand attr-dict `:` type($res)"; + let parser = [{ return parseFloadUnOp(parser, result); }]; + let printer = [{ printFloadUnOp(p, *this); }]; } // Integer binary operations. @@ -187,19 +225,23 @@ } // Other integer operations. -def LLVM_FCmpOp : LLVM_Op<"fcmp", [NoSideEffect]> { +def LLVM_FCmpOp : LLVM_Op<"fcmp", [ + NoSideEffect, DeclareOpInterfaceMethods]> { let arguments = (ins FCmpPredicate:$predicate, LLVM_ScalarOrVectorOf:$lhs, - LLVM_ScalarOrVectorOf:$rhs); + LLVM_ScalarOrVectorOf:$rhs, + DefaultValuedAttr:$fastmathFlags); let results = (outs LLVM_ScalarOrVectorOf:$res); let llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; let builders = [ - OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs), + OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs, + CArg<"FastmathFlags", "{}">:$fmf), [{ build($_builder, $_state, LLVMIntegerType::get(lhs.getType().getContext(), 1), - $_builder.getI64IntegerAttr(static_cast(predicate)), lhs, rhs); + $_builder.getI64IntegerAttr(static_cast(predicate)), lhs, rhs, + ::mlir::LLVM::FMFAttr::get(fmf, $_builder.getContext())); }]>]; let parser = [{ return parseCmpOp(parser, result); }]; let printer = [{ printFCmpOp(p, *this); }]; @@ -211,8 +253,8 @@ def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">; def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">; def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">; -def LLVM_FNegOp : LLVM_UnaryArithmeticOp, - "fneg", "CreateFNeg">; +def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp< + LLVM_ScalarOrVectorOf, "fneg", "CreateFNeg">; // Common code definition that is used to verify and set the alignment attribute // of LLVM ops that accept such an attribute. @@ -406,7 +448,8 @@ let printer = [{ printLandingpadOp(p, *this); }]; } -def LLVM_CallOp : LLVM_Op<"call"> { +def LLVM_CallOp : LLVM_Op<"call", + [DeclareOpInterfaceMethods]> { let summary = "Call to an LLVM function."; let description = [{ @@ -437,7 +480,8 @@ ``` }]; let arguments = (ins OptionalAttr:$callee, - Variadic); + Variadic, + DefaultValuedAttr:$fastmathFlags); let results = (outs Variadic); let builders = [ OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands, Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td @@ -0,0 +1,30 @@ +//===-- LLVMOpsInterfaces.td - LLVM op interfaces ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the LLVM IR interfaces definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_OPS_INTERFACES +#define LLVM_OPS_INTERFACES + +include "mlir/IR/OpBase.td" + +def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { + let description = [{ + Access to op fastmath flags. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", "fastmathFlags">, + ]; +} + +#endif // LLVM_OPS_INTERFACES Index: mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp =================================================================== --- mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -834,7 +834,8 @@ rewriter.template replaceOpWithNewOp( operation, dstType, rewriter.getI64IntegerAttr(static_cast(predicate)), - operation.operand1(), operation.operand2()); + operation.operand1(), operation.operand2(), + LLVM::FMFAttr::get({}, operation.getContext())); return success(); } }; Index: mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp =================================================================== --- mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1849,10 +1849,11 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. + auto fmf = LLVM::FMFAttr::get({}, op.getContext()); Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real()); + rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = - rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag()); + rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -1876,10 +1877,11 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. + auto fmf = LLVM::FMFAttr::get({}, op.getContext()); Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real()); + rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = - rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag()); + rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -3177,11 +3179,12 @@ ConversionPatternRewriter &rewriter) const override { CmpFOpAdaptor transformed(operands); + auto fmf = LLVM::FMFAttr::get({}, cmpfOp.getContext()); rewriter.replaceOpWithNewOp( cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpfOp.getPredicate()))), - transformed.lhs(), transformed.rhs()); + transformed.lhs(), transformed.rhs(), fmf); return success(); } Index: mlir/lib/Dialect/LLVMIR/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -10,6 +10,7 @@ DEPENDS MLIRLLVMOpsIncGen + MLIRLLVMOpsInterfacesIncGen MLIROpenMPOpsIncGen intrinsics_gen Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -36,6 +36,117 @@ static constexpr const char kNonTemporalAttrName[] = "nontemporal"; #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" +#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" + +namespace mlir { +namespace LLVM { +namespace detail { +struct BitmaskEnumStorage : public AttributeStorage { + using KeyTy = uint64_t; + + BitmaskEnumStorage(KeyTy val) : value(val) {} + + bool operator==(const KeyTy &key) const { return value == key; } + + static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + BitmaskEnumStorage(key); + } + + KeyTy value = 0; +}; +} // namespace detail +} // namespace LLVM +} // namespace mlir + +static auto processFMFAttr(ArrayRef attrs) { + SmallVector filteredAttrs( + llvm::make_filter_range(attrs, [&](NamedAttribute attr) { + if (attr.first == "fastmathFlags") { + auto defAttr = FMFAttr::get({}, attr.second.getContext()); + return defAttr != attr.second; + } + return true; + })); + return filteredAttrs; +} + +template +static void printFloadUnOp(OpAsmPrinter &p, Op &op) { + p << op.getOperationName() << " " << op.getOperand(); + p.printOptionalAttrDict(processFMFAttr(op.getAttrs())); + p << " : " << op.res().getType(); +} + +ParseResult parseFloadUnOp(OpAsmParser &parser, OperationState &result) { + // TODO: copypasted from tbgenerated code as we cannont use ""let printer" and + // "let assemblyFormat" together + OpAsmParser::OperandType operandRawOperands[1]; + ArrayRef operandOperands(operandRawOperands); + llvm::SMLoc operandOperandsLoc; + mlir::Type resRawTypes[1]; + ArrayRef resTypes(resRawTypes); + + operandOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(operandRawOperands[0])) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + if (parser.parseColon()) + return failure(); + + if (parser.parseType(resRawTypes[0])) + return failure(); + result.addTypes(resTypes); + if (parser.resolveOperands(operandOperands, resTypes[0], result.operands)) + return failure(); + return success(); +} + +template +static void printFloadBinOp(OpAsmPrinter &p, Op &op) { + p << op.getOperationName() << " " << op.getOperand(0) << ", " + << op.getOperand(1); + p.printOptionalAttrDict(processFMFAttr(op.getAttrs())); + p << " : " << op.res().getType(); +} + +ParseResult parseFloadBinOp(OpAsmParser &parser, OperationState &result) { + // TODO: copypasted from tbgenerated code as we cannont use ""let printer" and + // "let assemblyFormat" together + OpAsmParser::OperandType lhsRawOperands[1]; + ArrayRef<::mlir::OpAsmParser::OperandType> lhsOperands(lhsRawOperands); + llvm::SMLoc lhsOperandsLoc; + OpAsmParser::OperandType rhsRawOperands[1]; + ArrayRef<::mlir::OpAsmParser::OperandType> rhsOperands(rhsRawOperands); + llvm::SMLoc rhsOperandsLoc; + Type resRawTypes[1]; + ArrayRef<::mlir::Type> resTypes(resRawTypes); + + lhsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(lhsRawOperands[0])) + return failure(); + if (parser.parseComma()) + return failure(); + + rhsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(rhsRawOperands[0])) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + if (parser.parseColon()) + return failure(); + + if (parser.parseType(resRawTypes[0])) + return failure(); + result.addTypes(resTypes); + if (parser.resolveOperands(lhsOperands, resTypes[0], result.operands)) + return failure(); + if (parser.resolveOperands(rhsOperands, resTypes[0], result.operands)) + return failure(); + return success(); +} //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::CmpOp. @@ -50,7 +161,7 @@ static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) << "\" " << op.getOperand(0) << ", " << op.getOperand(1); - p.printOptionalAttrDict(op.getAttrs(), {"predicate"}); + p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"predicate"}); p << " : " << op.lhs().getType(); } @@ -775,7 +886,7 @@ auto args = op.getOperands().drop_front(isDirect ? 0 : 1); p << '(' << args << ')'; - p.printOptionalAttrDict(op.getAttrs(), {"callee"}); + p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"callee"}); // Reconstruct the function MLIR function type from operand and result types. p << " : " @@ -2048,6 +2159,8 @@ //===----------------------------------------------------------------------===// void LLVMDialect::initialize() { + addAttributes(); + // clang-format off addTypeshasTrait() && op->hasTrait(); } + +FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) { + return Base::get(context, static_cast(flags)); +} + +FastmathFlags FMFAttr::getFlags() const { + return static_cast(getImpl()->value); +} + +static constexpr const FastmathFlags FastmathFlagsList[] = { + // clang-format off + FastmathFlags::nnan, + FastmathFlags::ninf, + FastmathFlags::nsz, + FastmathFlags::arcp, + FastmathFlags::contract, + FastmathFlags::afn, + FastmathFlags::reassoc, + FastmathFlags::fast, + // clang-format on +}; + +void FMFAttr::print(DialectAsmPrinter &printer) const { + printer << "fastmath<"; + auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) { + return bitEnumContains(getFlags(), flag); + }); + llvm::interleaveComma(flags, printer, + [&](auto flag) { printer << stringifyEnum(flag); }); + printer << ">"; +} + +Attribute FMFAttr::parse(DialectAsmParser &parser) { + if (failed(parser.parseLess())) + return {}; + + FastmathFlags flags = {}; + if (failed(parser.parseOptionalGreater())) { + do { + StringRef elemName; + if (failed(parser.parseKeyword(&elemName))) + return {}; + + auto elem = symbolizeFastmathFlags(elemName); + if (!elem) { + parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") + << elemName; + return {}; + } + + flags = flags | *elem; + } while (succeeded(parser.parseOptionalComma())); + + if (failed(parser.parseGreater())) + return {}; + } + + return FMFAttr::get(flags, parser.getBuilder().getContext()); +} + +Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + if (type) { + parser.emitError(parser.getNameLoc(), "unexpected type"); + return {}; + } + StringRef attrKind; + if (parser.parseKeyword(&attrKind)) + return {}; + + if (attrKind == "fastmath") + return FMFAttr::parse(parser); + + parser.emitError(parser.getNameLoc(), "Unknown attrribute type: ") + << attrKind; + return {}; +} + +void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { + if (auto fmf = attr.dyn_cast()) + fmf.print(os); + else + llvm_unreachable("Unknown attribute type"); +} Index: mlir/lib/Target/LLVMIR/ModuleTranslation.cpp =================================================================== --- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -710,6 +710,29 @@ }); } +static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { + using llvmFMF = llvm::FastMathFlags; + using FuncT = void (llvmFMF::*)(bool); + const std::pair handlers[] = { + // clang-format off + {FastmathFlags::nnan, &llvmFMF::setNoNaNs}, + {FastmathFlags::ninf, &llvmFMF::setNoInfs}, + {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros}, + {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal}, + {FastmathFlags::contract, &llvmFMF::setAllowContract}, + {FastmathFlags::afn, &llvmFMF::setApproxFunc}, + {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, + {FastmathFlags::fast, &llvmFMF::setFast}, + // clang-format on + }; + llvm::FastMathFlags ret; + auto fmf = op.fastmathFlags(); + for (auto it : handlers) + if (bitEnumContains(fmf, it.first)) + (ret.*(it.second))(true); + return ret; +} + /// Given a single MLIR operation, create the corresponding LLVM IR operation /// using the `builder`. LLVM IR Builder does not have a generic interface so /// this has to be a long chain of `if`s calling different functions with a @@ -724,6 +747,10 @@ return position; }; + llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder); + if (auto fmf = dyn_cast(opInst)) + builder.setFastMathFlags(getFastmathFlags(fmf)); + #include "mlir/Dialect/LLVMIR/LLVMConversions.inc" // Emit function calls. If the "callee" attribute is present, this is a Index: mlir/test/Dialect/LLVMIR/roundtrip.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -387,3 +387,35 @@ llvm.return } + +// CHECK-LABEL: @fastmathFlags +func @fastmathFlags(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32) { +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float +// CHECK: {{.*}} = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float +// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float +// CHECK: {{.*}} = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float +// CHECK: {{.*}} = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %0 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %1 = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %2 = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %3 = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %4 = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %5 = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)> + %7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)> + +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : !llvm.float + %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : !llvm.float +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = llvm.fneg %arg0 : !llvm.float + %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : !llvm.float + return +} Index: mlir/test/Target/llvmir.mlir =================================================================== --- mlir/test/Target/llvmir.mlir +++ mlir/test/Target/llvmir.mlir @@ -1359,6 +1359,49 @@ } // ----- +llvm.func @fastmathFlagsFunc(!llvm.float) -> !llvm.float + +// CHECK-LABEL: @fastmathFlags +llvm.func @fastmathFlags(%arg0: !llvm.float) { +// CHECK: {{.*}} = fadd nnan ninf float {{.*}}, {{.*}} +// CHECK: {{.*}} = fsub nnan ninf float {{.*}}, {{.*}} +// CHECK: {{.*}} = fmul nnan ninf float {{.*}}, {{.*}} +// CHECK: {{.*}} = fdiv nnan ninf float {{.*}}, {{.*}} +// CHECK: {{.*}} = frem nnan ninf float {{.*}}, {{.*}} + %0 = llvm.fadd %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %1 = llvm.fsub %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %2 = llvm.fmul %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %3 = llvm.fdiv %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %4 = llvm.frem %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = fcmp nnan ninf oeq {{.*}}, {{.*}} + %5 = llvm.fcmp "oeq" %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = fneg nnan ninf float {{.*}} + %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = call float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call nnan float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call ninf float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call nsz float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call arcp float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call contract float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}}) + %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (!llvm.float) -> (!llvm.float) + %9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %12 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %13 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %14 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %15 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %16 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + llvm.return +} + +// ----- // CHECK-LABEL: @switch_args llvm.func @switch_args(%arg0: !llvm.i32) {