diff --git a/flang/include/flang/Optimizer/Dialect/FIRDialect.td b/flang/include/flang/Optimizer/Dialect/FIRDialect.td --- a/flang/include/flang/Optimizer/Dialect/FIRDialect.td +++ b/flang/include/flang/Optimizer/Dialect/FIRDialect.td @@ -26,6 +26,11 @@ let cppNamespace = "::fir"; let useDefaultTypePrinterParser = 0; let useDefaultAttributePrinterParser = 0; + let dependentDialects = [ + // Arith dialect provides FastMathFlagsAttr + // supported by some FIR operations. + "arith::ArithDialect" + ]; } #endif // FORTRAN_DIALECT_FIR_DIALECT diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -14,6 +14,8 @@ #ifndef FORTRAN_DIALECT_FIR_OPS #define FORTRAN_DIALECT_FIR_OPS +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "flang/Optimizer/Dialect/FIRDialect.td" include "flang/Optimizer/Dialect/FIRTypes.td" include "flang/Optimizer/Dialect/FIRAttr.td" @@ -2266,7 +2268,8 @@ // Procedure call operations //===----------------------------------------------------------------------===// -def fir_CallOp : fir_Op<"call", [CallOpInterface]> { +def fir_CallOp : fir_Op<"call", + [CallOpInterface, DeclareOpInterfaceMethods]> { let summary = "call a procedure"; let description = [{ @@ -2283,7 +2286,9 @@ let arguments = (ins OptionalAttr:$callee, - Variadic:$args + Variadic:$args, + DefaultValuedAttr:$fastmath ); let results = (outs Variadic); diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -19,6 +19,7 @@ #include "flang/Optimizer/Support/InternalNames.h" #include "flang/Optimizer/Support/TypeCode.h" #include "flang/Semantics/runtime-type-info.h" +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" @@ -699,8 +700,11 @@ llvm::SmallVector resultTys; for (auto r : call.getResults()) resultTys.push_back(convertType(r.getType())); + // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr. + mlir::arith::AttrConvertFastMathToLLVM + attrConvert(call); rewriter.replaceOpWithNewOp( - call, resultTys, adaptor.getOperands(), call->getAttrs()); + call, resultTys, adaptor.getOperands(), attrConvert.getAttrs()); return mlir::success(); } }; diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -655,8 +655,18 @@ else p << getOperand(0); p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')'; - p.printOptionalAttrDict((*this)->getAttrs(), - {fir::CallOp::getCalleeAttrNameStr()}); + + // Print 'fastmath<...>' (if it has non-default value) before + // any other attributes. + mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr(); + if (fmfAttr.getValue() != mlir::arith::FastMathFlags::none) { + p << ' ' << mlir::arith::FastMathFlagsAttr::getMnemonic(); + p.printStrippedAttrOrType(fmfAttr); + } + + p.printOptionalAttrDict( + (*this)->getAttrs(), + {fir::CallOp::getCalleeAttrNameStr(), getFastmathAttrName()}); auto resultTypes{getResultTypes()}; llvm::SmallVector argTypes( llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1)); @@ -678,8 +688,18 @@ return mlir::failure(); mlir::Type type; - if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttrDict(attrs) || parser.parseColon() || + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren)) + return mlir::failure(); + + // Parse 'fastmath<...>', if present. + mlir::arith::FastMathFlagsAttr fmfAttr; + llvm::StringRef fmfAttrName = getFastmathAttrName(result.name); + if (mlir::succeeded(parser.parseOptionalKeyword(fmfAttrName))) + if (parser.parseCustomAttributeWithFallback(fmfAttr, mlir::Type{}, + fmfAttrName, attrs)) + return mlir::failure(); + + if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() || parser.parseType(type)) return mlir::failure(); diff --git a/flang/test/Fir/fir-fast-math.fir b/flang/test/Fir/fir-fast-math.fir new file mode 100644 --- /dev/null +++ b/flang/test/Fir/fir-fast-math.fir @@ -0,0 +1,20 @@ +// RUN: fir-opt %s | fir-opt | FileCheck %s + +// CHECK-LABEL: @test_callop +func.func @test_callop(%arg0 : f32) { + // CHECK: fir.call @callee() : () -> () + fir.call @callee() fastmath : () -> () + // CHECK: fir.call @callee() : () -> () + fir.call @callee() {fastmath = #arith.fastmath} : () -> () + // CHECK: fir.call @callee() fastmath : () -> () + fir.call @callee() fastmath : () -> () + // CHECK: fir.call @callee() fastmath : () -> () + fir.call @callee() {fastmath = #arith.fastmath} : () -> () + // CHECK: fir.call @callee() fastmath : () -> () + fir.call @callee() fastmath : () -> () + // CHECK: fir.call @callee() fastmath : () -> () + fir.call @callee() {fastmath = #arith.fastmath} : () -> () + return +} + +func.func private @callee()