diff --git a/mlir/include/mlir/Dialect/Math/IR/MathBase.td b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
--- a/mlir/include/mlir/Dialect/Math/IR/MathBase.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
@@ -15,6 +15,7 @@
     The math dialect is intended to hold mathematical operations on integer and
     floating type beyond simple arithmetics.
   }];
+  let hasConstantMaterializer = 1;
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
 }
 #endif // MATH_BASE
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -195,6 +195,7 @@
     %x = math.ceil %y : tensor<4x?xf8>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -649,6 +650,7 @@
     %y = math.log2 %x : f64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -6,7 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/Builders.h"
 
 using namespace mlir;
 using namespace mlir::math;
@@ -17,3 +19,72 @@
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// CeilOp folder
+//===----------------------------------------------------------------------===//
+
+/// Folds `math.ceil` of a scalar FloatAttr constant by rounding the value
+/// toward positive infinity. Operands that are not constant, or whose constant
+/// is not a FloatAttr, are left unfolded.
+OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
+  auto constOperand = operands.front();
+  if (!constOperand)
+    return {};
+
+  auto attr = constOperand.dyn_cast<FloatAttr>();
+  if (!attr)
+    return {};
+
+  APFloat sourceVal = attr.getValue();
+  // The returned opStatus is deliberately ignored: changing the value is the
+  // whole point of ceil.
+  sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive);
+
+  return FloatAttr::get(getType(), sourceVal);
+}
+
+//===----------------------------------------------------------------------===//
+// Log2Op folder
+//===----------------------------------------------------------------------===//
+
+/// Folds `math.log2` of a non-negative f32/f64 FloatAttr constant using the
+/// host `log2f`/`log2`. Negative inputs, other float widths, and operands that
+/// are not FloatAttr constants are left unfolded.
+OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
+  auto constOperand = operands.front();
+  if (!constOperand)
+    return {};
+
+  auto attr = constOperand.dyn_cast<FloatAttr>();
+  if (!attr)
+    return {};
+
+  auto floatType = getType().cast<FloatType>();
+  APFloat apf = attr.getValue();
+
+  // log2 of a negative number is NaN; leave the op unfolded in that case.
+  // isNegative() is also true for -0.0 and for NaNs with the sign bit set, so
+  // those inputs are conservatively left unfolded as well.
+  if (apf.isNegative())
+    return {};
+
+  if (floatType.getWidth() == 64)
+    return FloatAttr::get(getType(), log2(apf.convertToDouble()));
+
+  // Read the single-precision payload with convertToFloat() and evaluate it
+  // with log2f, instead of routing an IEEEsingle APFloat through double.
+  if (floatType.getWidth() == 32)
+    return FloatAttr::get(getType(), log2f(apf.convertToFloat()));
+
+  // Other widths (e.g. f16, bf16, f80, f128) are not evaluated here.
+  return {};
+}
+
+/// Materializes a folded attribute as an `arith.constant` of the requested
+/// type; enabled by `hasConstantMaterializer` on the Math dialect.
+Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
+                                                  Attribute value, Type type,
+                                                  Location loc) {
+  return builder.create<arith::ConstantOp>(loc, value, type);
+}
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: @ceil_fold
+// CHECK: %[[cst:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @ceil_fold() -> f32 {
+  %c = arith.constant 0.3 : f32
+  %r = math.ceil %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @ceil_fold2
+// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @ceil_fold2() -> f32 {
+  %c = arith.constant 2.0 : f32
+  %r = math.ceil %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @ceil_fold_neg
+// CHECK: %[[cst:.+]] = arith.constant -1.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @ceil_fold_neg() -> f32 {
+  %c = arith.constant -1.5 : f32
+  %r = math.ceil %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @log2_fold
+// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @log2_fold() -> f32 {
+  %c = arith.constant 4.0 : f32
+  %r = math.log2 %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @log2_fold2
+// CHECK: %[[cst:.+]] = arith.constant 0xFF800000 : f32
+// CHECK: return %[[cst]]
+func @log2_fold2() -> f32 {
+  %c = arith.constant 0.0 : f32
+  %r = math.log2 %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @log2_nofold2
+// CHECK: %[[cst:.+]] = arith.constant -1.000000e+00 : f32
+// CHECK: %[[res:.+]] = math.log2 %[[cst]] : f32
+// CHECK: return %[[res]]
+func @log2_nofold2() -> f32 {
+  %c = arith.constant -1.0 : f32
+  %r = math.log2 %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @log2_fold_64
+// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f64
+// CHECK: return %[[cst]]
+func @log2_fold_64() -> f64 {
+  %c = arith.constant 4.0 : f64
+  %r = math.log2 %c : f64
+  return %r : f64
+}
+
+// CHECK-LABEL: @log2_fold2_64
+// CHECK: %[[cst:.+]] = arith.constant 0xFFF0000000000000 : f64
+// CHECK: return %[[cst]]
+func @log2_fold2_64() -> f64 {
+  %c = arith.constant 0.0 : f64
+  %r = math.log2 %c : f64
+  return %r : f64
+}
+
+// CHECK-LABEL: @log2_nofold2_64
+// CHECK: %[[cst:.+]] = arith.constant -1.000000e+00 : f64
+// CHECK: %[[res:.+]] = math.log2 %[[cst]] : f64
+// CHECK: return %[[res]]
+func @log2_nofold2_64() -> f64 {
+  %c = arith.constant -1.0 : f64
+  %r = math.log2 %c : f64
+  return %r : f64
+}