diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -585,6 +585,7 @@
     %x = arith.negf %y : tensor<4x?xf8>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -69,6 +69,45 @@
   }
   return {};
 }
+
+/// Performs constant folding `calculate` with element-wise behavior on the
+/// single attribute in `operands` and returns the result if possible.
+/// An `AttrElementT` scalar is folded directly, a splat is folded once on its
+/// splat value, and any other `ElementsAttr` is folded element by element.
+/// Returns a null attribute when the operand is null (i.e. not a constant)
+/// or is of an unsupported attribute kind.
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType,
+          class CalculationT = function_ref<ElementValueT(ElementValueT)>>
+Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+                           const CalculationT &calculate) {
+  assert(operands.size() == 1 && "unary op takes one operand");
+  if (!operands[0])
+    return {};
+
+  if (auto op = operands[0].dyn_cast<AttrElementT>())
+    return AttrElementT::get(op.getType(), calculate(op.getValue()));
+
+  if (auto op = operands[0].dyn_cast<SplatElementsAttr>()) {
+    // The operand is a splat, so we can avoid expanding the values out and
+    // just fold based on the splat value.
+    auto elementResult = calculate(op.getSplatValue<ElementValueT>());
+    return DenseElementsAttr::get(op.getType(), elementResult);
+  }
+
+  if (auto op = operands[0].dyn_cast<ElementsAttr>()) {
+    // The operand is ElementsAttr-derived; perform an element-wise fold by
+    // expanding the values.
+    auto opIt = op.value_begin<ElementValueT>();
+    SmallVector<ElementValueT, 4> elementResults;
+    elementResults.reserve(op.getNumElements());
+    for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt)
+      elementResults.push_back(calculate(*opIt));
+    return DenseElementsAttr::get(op.getType(), elementResults);
+  }
+  return {};
+}
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_COMMONFOLDERS_H
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -573,6 +573,15 @@
   patterns.add<XOrINotCmpI>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// NegFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldUnaryOp<FloatAttr>(operands,
+                                     [](const APFloat &a) { return -a; });
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1166,3 +1166,25 @@
   %r = arith.shrsi %c1, %cm32 : i64
   return %r : i64
 }
+
+// -----
+
+// CHECK-LABEL: @test_negf(
+// CHECK: %[[res:.+]] = arith.constant -2.0
+// CHECK: return %[[res]]
+func @test_negf() -> (f32) {
+  %c = arith.constant 2.0 : f32
+  %0 = arith.negf %c : f32
+  return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_negf_splat(
+// CHECK: %[[res:.+]] = arith.constant dense<-2.0
+// CHECK: return %[[res]]
+func @test_negf_splat() -> (vector<4xf32>) {
+  %c = arith.constant dense<2.0> : vector<4xf32>
+  %0 = arith.negf %c : vector<4xf32>
+  return %0 : vector<4xf32>
+}