[arith] Add a `set-fastmath` pass that sets the fastmath attribute on every
operation implementing ArithFastMathInterface.

- ArithOps.td: adds a hasFlags(FastMathFlags f) helper to Arith_FastMathAttr
  that returns true when every bit of `f` is set in the attribute's value. It
  also rewraps the DefaultValuedAttr `$fastmath` argument of the one-operand
  and two-operand FloatLike op classes.
- ArithOpsInterfaces.td: adds a setFastMathFlagsAttr(FastMathFlagsAttr)
  interface method. Its default implementation forwards to the concrete op's
  setFastmathAttr().
- Passes.h: includes Arith.h (for FastMathFlags) and declares
  createArithSetFastMathPass(FastMathFlags fm = FastMathFlags::none).
- Passes.td: defines ArithSetFastMath (`set-fastmath`) with a list option
  `value` (member `flags`) holding comma-separated fastmath flag names.
- CMakeLists.txt / SetFastMath.cpp: build and implement the pass.
- set-fastmath.mlir: lit test for the pass.

NOTE(review): this copy of the patch has lost its line breaks. It also appears
to have had angle-bracketed template arguments stripped, e.g.:
  - `EnumAttr {`
  - `DeclareOpInterfaceMethods]`
  - `cast(this->getOperation())`
  - a bare `std::unique_ptr` return type
  - `fastmath : f32` in the test CHECK lines
As a result, the ArithOps.td `-`/`+` DefaultValuedAttr lines read identically
here. The patch will not apply as-is; regenerate it from the source branch.
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -23,6 +23,11 @@ def Arith_FastMathAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; + let extraClassDeclaration = [{ + bool hasFlags(FastMathFlags f) const { + return (f == (getValue() & f)); + } + }]; } // Base class for Arith dialect ops. Ops in this dialect have no side @@ -69,7 +74,8 @@ !listconcat([DeclareOpInterfaceMethods], traits)>, Arguments<(ins FloatLike:$operand, - DefaultValuedAttr:$fastmath)>, + DefaultValuedAttr:$fastmath)>, Results<(outs FloatLike:$result)> { let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)? attr-dict `:` type($result) }]; @@ -81,7 +87,8 @@ !listconcat([DeclareOpInterfaceMethods], traits)>, Arguments<(ins FloatLike:$lhs, FloatLike:$rhs, - DefaultValuedAttr:$fastmath)>, + DefaultValuedAttr:$fastmath)>, Results<(outs FloatLike:$result)> { let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)? 
attr-dict `:` type($result) }]; diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td @@ -34,6 +34,17 @@ return op.getFastmathAttr(); }] >, + InterfaceMethod< + /*desc=*/ "Sets the FastMathFlagsAttr attribute for the operation", + /*returnType=*/ "void", + /*methodName=*/ "setFastMathFlagsAttr", + /*args=*/ (ins "FastMathFlagsAttr":$value), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + ConcreteOp op = cast(this->getOperation()); + op.setFastmathAttr(value); + }] + >, StaticInterfaceMethod< /*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute for the operation}], @@ -45,7 +56,6 @@ return "fastmath"; }] > - ]; } diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_ARITH_TRANSFORMS_PASSES_H_ #define MLIR_DIALECT_ARITH_TRANSFORMS_PASSES_H_ +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -44,6 +45,11 @@ /// equivalent. std::unique_ptr createArithUnsignedWhenEquivalentPass(); +/// Create a pass to set the fastmath value for operations that support the +/// fastmath interface.
+std::unique_ptr +createArithSetFastMathPass(FastMathFlags fm = FastMathFlags::none); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -69,4 +69,17 @@ let dependentDialects = ["vector::VectorDialect"]; } +def ArithSetFastMath : Pass<"set-fastmath"> { + let summary = "Set the fastmath attribute for all fastmath operations"; + let description = [{ + This pass sets the arith dialect fastmath attribute for all operations that + support the ArithFastMathInterface. + }]; + let options = [ + ListOption<"flags", "value", "std::string", + "Comma separated list of fastmath flag bits"> + ]; + let constructor = "mlir::arith::createArithSetFastMathPass()"; +} + #endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ Bufferize.cpp EmulateWideInt.cpp ExpandOps.cpp + SetFastMath.cpp UnsignedWhenEquivalent.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Arith/Transforms/SetFastMath.cpp b/mlir/lib/Dialect/Arith/Transforms/SetFastMath.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/SetFastMath.cpp @@ -0,0 +1,61 @@ +//===- SetFastMath.cpp - Set fastmath flags globally ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" + +namespace mlir::arith { +#define GEN_PASS_DEF_ARITHSETFASTMATH +#include "mlir/Dialect/Arith/Transforms/Passes.h.inc" +} // namespace mlir::arith + +using namespace mlir; + +namespace { + +// Pass to set the fastmath attribute for ALL operations that implement the +// arith::ArithFastMathInterface. +// Currently, this pass overwrites the fastmath attribute with the value +// provided on the command line (or to the pass constructor). +// TODO: Provide a way to perform a bitwise operation on the source and +// destination fastmath values (e.g. 'or'). +struct ArithSetFastMathPass + : public arith::impl::ArithSetFastMathBase { + explicit ArithSetFastMathPass(arith::FastMathFlags fm) : fastmathValue(fm) {} + LogicalResult initialize(MLIRContext *context) override { + // Parse the comma-delimited string and set the corresponding fastmath + // bit(s). 
+ for (auto &f : flags) { + if (auto enumOpt = arith::symbolizeFastMathFlags(f)) { + fastmathValue = (fastmathValue | enumOpt.value()); + } else + return failure(); + } + return success(); + } + void runOnOperation() override { + // Create an attribute with the desired fastmath value + auto fmfAttr = arith::FastMathFlagsAttr::get(getOperation()->getContext(), + fastmathValue); + getOperation()->walk([&](Operation *op) { + auto fmI = dyn_cast(op); + if (fmI) { + fmI.setFastMathFlagsAttr(fmfAttr); + } + }); + } + // FastMath enum value to be applied + arith::FastMathFlags fastmathValue; +}; + +} // namespace + +std::unique_ptr +mlir::arith::createArithSetFastMathPass(mlir::arith::FastMathFlags fm) { + return std::make_unique(fm); +} diff --git a/mlir/test/Dialect/Arith/set-fastmath.mlir b/mlir/test/Dialect/Arith/set-fastmath.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Arith/set-fastmath.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s --set-fastmath="value=nsz,nnan,ninf" | FileCheck %s + +// CHECK-LABEL: @fastmath +func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) { +// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.subf %arg0, %arg1 fastmath : f32 + %0 = arith.addf %arg0, %arg1 : f32 + %1 = arith.subf %arg0, %arg1 : f32 + return +}