diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -653,6 +653,7 @@ %a = arith.maxf %b, %c : f64 ``` }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// @@ -696,6 +697,7 @@ %a = arith.minf %b, %c : f64 ``` }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -91,6 +91,43 @@ } }; +/// The matcher that matches a constant scalar / vector splat / tensor splat +/// float operation and binds the constant (splat element) float value. +struct constant_float_op_binder { + FloatAttr::ValueType *bind_value; + + /// Creates a matcher instance that binds the value to bv if match succeeds. + constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {} + + bool match(Operation *op) { + Attribute attr; + if (!constant_op_binder(&attr).match(op)) + return false; + auto type = op->getResult(0).getType(); + + if (type.isa()) + return attr_value_binder(bind_value).match(attr); + if (type.isa()) { + if (auto splatAttr = attr.dyn_cast()) { + return attr_value_binder(bind_value) + .match(splatAttr.getSplatValue()); + } + } + return false; // Unsupported result type or non-splat constant. + } +}; + +/// The matcher that matches a constant scalar / vector splat / tensor splat +/// float value for which `predicate` returns true.
+struct constant_float_predicate_matcher { + bool (*predicate)(const APFloat &); // Run only on matched constants. + + bool match(Operation *op) { + APFloat value(APFloat::Bogus()); // Placeholder; set by a successful bind. + return constant_float_op_binder(&value).match(op) && predicate(value); + } +}; + /// The matcher that matches a constant scalar / vector splat / tensor splat /// integer operation and binds the constant integer value. struct constant_int_op_binder { @@ -118,22 +155,13 @@ }; /// The matcher that matches a given target constant scalar / vector splat / -/// tensor splat integer value. -template -struct constant_int_value_matcher { - bool match(Operation *op) { - APInt value; - return constant_int_op_binder(&value).match(op) && TargetValue == value; - } -}; +/// tensor splat integer value for which `predicate` returns true. +struct constant_int_predicate_matcher { + bool (*predicate)(const APInt &); // Run only on matched constants. -/// The matcher that matches anything except the given target constant scalar / -/// vector splat / tensor splat integer value. -template -struct constant_int_not_value_matcher { bool match(Operation *op) { APInt value; - return constant_int_op_binder(&value).match(op) && TargetNotValue != value; + return constant_int_op_binder(&value).match(op) && predicate(value); } }; @@ -239,26 +267,65 @@ return detail::constant_op_binder(bind_value); } -/// Matches a constant scalar / vector splat / tensor splat integer one. -inline detail::constant_int_value_matcher<1> m_One() { - return detail::constant_int_value_matcher<1>(); +/// Matches a constant scalar / vector splat / tensor splat float (both positive +/// and negative) zero. +inline detail::constant_float_predicate_matcher m_AnyZeroFloat() { + return {[](const APFloat &value) { return value.isZero(); }}; } -/// Matches the given OpClass. -template -inline detail::op_matcher m_Op() { - return detail::op_matcher(); +/// Matches a constant scalar / vector splat / tensor splat float positive zero.
+inline detail::constant_float_predicate_matcher m_PosZeroFloat() { + return {[](const APFloat &value) { return value.isPosZero(); }}; +} + +/// Matches a constant scalar / vector splat / tensor splat float negative zero. +inline detail::constant_float_predicate_matcher m_NegZeroFloat() { + return {[](const APFloat &value) { return value.isNegZero(); }}; +} + +/// Matches a constant scalar / vector splat / tensor splat float one (1.0). +inline detail::constant_float_predicate_matcher m_OneFloat() { + return {[](const APFloat &value) { + return APFloat(value.getSemantics(), 1) == value; + }}; +} + +/// Matches a constant scalar / vector splat / tensor splat float positive +/// infinity. +inline detail::constant_float_predicate_matcher m_PosInfFloat() { + return {[](const APFloat &value) { + return !value.isNegative() && value.isInfinity(); + }}; +} + +/// Matches a constant scalar / vector splat / tensor splat float negative +/// infinity. +inline detail::constant_float_predicate_matcher m_NegInfFloat() { + return {[](const APFloat &value) { + return value.isNegative() && value.isInfinity(); + }}; } /// Matches a constant scalar / vector splat / tensor splat integer zero. -inline detail::constant_int_value_matcher<0> m_Zero() { - return detail::constant_int_value_matcher<0>(); +inline detail::constant_int_predicate_matcher m_Zero() { + return {[](const APInt &value) { return 0 == value; }}; } /// Matches a constant scalar / vector splat / tensor splat integer that is any /// non-zero value. -inline detail::constant_int_not_value_matcher<0> m_NonZero() { - return detail::constant_int_not_value_matcher<0>(); +inline detail::constant_int_predicate_matcher m_NonZero() { + return {[](const APInt &value) { return 0 != value; }}; +} + +/// Matches a constant scalar / vector splat / tensor splat integer one. +inline detail::constant_int_predicate_matcher m_One() { + return {[](const APInt &value) { return 1 == value; }}; +} + +/// Matches the given OpClass.
+template +inline detail::op_matcher m_Op() { + return detail::op_matcher(); } /// Entry point for matching a pattern over a Value. @@ -276,6 +343,13 @@ return const_cast(pattern).match(op); } +/// Matches a constant holding a scalar/vector/tensor float (splat) and +/// writes the float value to bind_value. +inline detail::constant_float_op_binder +m_ConstantFloat(FloatAttr::ValueType *bind_value) { + return detail::constant_float_op_binder(bind_value); +} + /// Matches a constant holding a scalar/vector/tensor integer (splat) and /// writes the integer value to bind_value. inline detail::constant_int_op_binder diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -194,12 +194,12 @@ if (matchPattern(getRhs(), m_Zero())) return getLhs(); - // add(sub(a, b), b) -> a + // addi(subi(a, b), b) -> a if (auto sub = getLhs().getDefiningOp()) if (getRhs() == sub.getRhs()) return sub.getLhs(); - // add(b, sub(a, b)) -> a + // addi(b, subi(a, b)) -> a if (auto sub = getRhs().getDefiningOp()) if (getLhs() == sub.getRhs()) return sub.getLhs(); @@ -576,6 +576,14 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::AddFOp::fold(ArrayRef operands) { + // addf(x, -0) -> x. Not valid for +0: (-0.0) + (+0.0) == +0.0. + if (matchPattern(getRhs(), m_NegZeroFloat())) + return getLhs(); + + // addf(-0, x) -> x, the commuted form of the fold above. + if (matchPattern(getLhs(), m_NegZeroFloat())) + return getRhs(); + return constFoldBinaryOp( operands, [](const APFloat &a, const APFloat &b) { return a + b; }); } @@ -585,10 +593,34 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::SubFOp::fold(ArrayRef operands) { + // subf(x, +0) -> x. Not valid for -0: (-0.0) - (-0.0) == +0.0. + if (matchPattern(getRhs(), m_PosZeroFloat())) + return getLhs(); + return constFoldBinaryOp( operands, [](const APFloat &a, const APFloat &b) { return a - b; }); }
+//===----------------------------------------------------------------------===// +// MaxFOp +//===----------------------------------------------------------------------===// + +OpFoldResult arith::MaxFOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "maxf takes two operands"); + + // maxf(x,x) -> x + if (getLhs() == getRhs()) + return getRhs(); + + // maxf(x, -inf) -> x: -inf is the identity element of maxf. + if (matchPattern(getRhs(), m_NegInfFloat())) + return getLhs(); + + return constFoldBinaryOp( + operands, + [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); +} + //===----------------------------------------------------------------------===// // MaxSIOp //===----------------------------------------------------------------------===// @@ -643,6 +675,26 @@ }); } +//===----------------------------------------------------------------------===// +// MinFOp +//===----------------------------------------------------------------------===// + +OpFoldResult arith::MinFOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "minf takes two operands"); + + // minf(x,x) -> x + if (getLhs() == getRhs()) + return getRhs(); + + // minf(x, +inf) -> x: +inf is the identity element of minf. + if (matchPattern(getRhs(), m_PosInfFloat())) + return getLhs(); + + return constFoldBinaryOp( + operands, + [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); +} + //===----------------------------------------------------------------------===// // MinSIOp //===----------------------------------------------------------------------===// @@ -702,6 +754,14 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::MulFOp::fold(ArrayRef operands) { + // mulf(x, 1) -> x + if (matchPattern(getRhs(), m_OneFloat())) + return getLhs(); + + // mulf(1, x) -> x + if (matchPattern(getLhs(), m_OneFloat())) + return getRhs(); + return constFoldBinaryOp( operands, [](const APFloat &a, const APFloat &b) { return a * b; }); }
@@ -711,6 +771,10 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::DivFOp::fold(ArrayRef operands) { + // divf(x, 1) -> x + if (matchPattern(getRhs(), m_OneFloat())) + return getLhs(); + return constFoldBinaryOp( operands, [](const APFloat &a, const APFloat &b) { return a / b; }); } diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -619,6 +619,113 @@ // ----- +// CHECK-LABEL: @test_minf( +func @test_minf(%arg0 : f32) -> (f32, f32, f32) { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 + // CHECK-NEXT: %[[X:.+]] = arith.minf %arg0, %[[C0]] + // CHECK-NEXT: return %[[X]], %arg0, %arg0 + %c0 = arith.constant 0.0 : f32 + %inf = arith.constant 0x7F800000 : f32 + %0 = arith.minf %c0, %arg0 : f32 + %1 = arith.minf %arg0, %arg0 : f32 + %2 = arith.minf %inf, %arg0 : f32 + return %0, %1, %2 : f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: @test_maxf( +func @test_maxf(%arg0 : f32) -> (f32, f32, f32) { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 + // CHECK-NEXT: %[[X:.+]] = arith.maxf %arg0, %[[C0]] + // CHECK-NEXT: return %[[X]], %arg0, %arg0 + %c0 = arith.constant 0.0 : f32 + %-inf = arith.constant 0xFF800000 : f32 + %0 = arith.maxf %c0, %arg0 : f32 + %1 = arith.maxf %arg0, %arg0 : f32 + %2 = arith.maxf %-inf, %arg0 : f32 + return %0, %1, %2 : f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: @test_addf( +func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) { + // CHECK-DAG: %[[C2:.+]] = arith.constant 2.0 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 + // CHECK-NEXT: %[[X:.+]] = arith.addf %arg0, %[[C0]] + // CHECK-NEXT: return %[[X]], %arg0, %arg0, %[[C2]] + %c0 = arith.constant 0.0 : f32 + %c-0 = arith.constant -0.0 : f32 + %c1 = arith.constant 1.0 : f32 + %0 = arith.addf %arg0, %c0 : f32 + %1 =
arith.addf %arg0, %c-0 : f32 + %2 = arith.addf %c-0, %arg0 : f32 + %3 = arith.addf %c1, %c1 : f32 + return %0, %1, %2, %3 : f32, f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: @test_subf( +func @test_subf(%arg0 : f16) -> (f16, f16, f16) { + // CHECK-DAG: %[[C1:.+]] = arith.constant -1.0 + // CHECK-DAG: %[[C0:.+]] = arith.constant -0.0 + // CHECK-NEXT: %[[X:.+]] = arith.subf %arg0, %[[C0]] + // CHECK-NEXT: return %arg0, %[[X]], %[[C1]] + %c0 = arith.constant 0.0 : f16 + %c-0 = arith.constant -0.0 : f16 + %c1 = arith.constant 1.0 : f16 + %0 = arith.subf %arg0, %c0 : f16 + %1 = arith.subf %arg0, %c-0 : f16 + %2 = arith.subf %c0, %c1 : f16 + return %0, %1, %2 : f16, f16, f16 +} + +// ----- + +// CHECK-LABEL: @test_mulf( +func @test_mulf(%arg0 : f32) -> (f32, f32, f32) { + // CHECK-NEXT: %[[C4:.+]] = arith.constant 4.0 + // CHECK-NEXT: return %arg0, %arg0, %[[C4]] + %c1 = arith.constant 1.0 : f32 + %c2 = arith.constant 2.0 : f32 + %0 = arith.mulf %arg0, %c1 : f32 + %1 = arith.mulf %c1, %arg0 : f32 + %2 = arith.mulf %c2, %c2 : f32 + return %0, %1, %2 : f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: @test_divf( +func @test_divf(%arg0 : f64) -> (f64, f64) { + // CHECK-NEXT: %[[C5:.+]] = arith.constant 5.000000e-01 + // CHECK-NEXT: return %arg0, %[[C5]] + %c1 = arith.constant 1.0 : f64 + %c2 = arith.constant 2.0 : f64 + %0 = arith.divf %arg0, %c1 : f64 + %1 = arith.divf %c1, %c2 : f64 + return %0, %1 : f64, f64 +} + +// ----- + +// CHECK-LABEL: @test_cmpf( +func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) { +// CHECK-DAG: %[[T:.*]] = arith.constant true +// CHECK-DAG: %[[F:.*]] = arith.constant false +// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]] + %nan = arith.constant 0x7fffffff : f32 + %0 = arith.cmpf olt, %nan, %arg0 : f32 + %1 = arith.cmpf olt, %arg0, %nan : f32 + %2 = arith.cmpf ugt, %nan, %arg0 : f32 + %3 = arith.cmpf ugt, %arg0, %nan : f32 + return %0, %1, %2, %3 : i1, i1, i1, i1 +} + +// ----- + // CHECK-LABEL: @constant_FPtoUI( func @constant_FPtoUI() -> i32 {
// CHECK: %[[C0:.+]] = arith.constant 2 : i32 @@ -678,30 +785,3 @@ %res = arith.sitofp %c0 : i32 to f32 return %res : f32 } - -// ----- -// CHECK-LABEL: @constant_MinMax( -func @constant_MinMax(%arg0 : f32) -> f32 { - // CHECK: %[[const:.+]] = arith.constant - // CHECK: %[[min:.+]] = arith.minf %arg0, %[[const]] : f32 - // CHECK: %[[res:.+]] = arith.maxf %[[min]], %[[const]] : f32 - // CHECK: return %[[res]] - %const = arith.constant 0.0 : f32 - %min = arith.minf %const, %arg0 : f32 - %res = arith.maxf %const, %min : f32 - return %res : f32 -} - -// ----- -// CHECK-LABEL: @cmpf_nan( -func @cmpf_nan(%arg0 : f32) -> (i1, i1, i1, i1) { -// CHECK-DAG: %[[T:.*]] = arith.constant true -// CHECK-DAG: %[[F:.*]] = arith.constant false -// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]] - %nan = arith.constant 0x7fffffff : f32 - %0 = arith.cmpf olt, %nan, %arg0 : f32 - %1 = arith.cmpf olt, %arg0, %nan : f32 - %2 = arith.cmpf ugt, %nan, %arg0 : f32 - %3 = arith.cmpf ugt, %arg0, %nan : f32 - return %0, %1, %2, %3 : i1, i1, i1, i1 -} diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -878,7 +878,6 @@ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> - // CHECK: mulf {{.*}} : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant 1.0 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -224,17 +224,14 @@ // CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]] // CHECK-SAME: step %[[C1]]
iter_args(%[[C:.*]] = %[[CSTF]], // CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) { -// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADDARG]] : f32 -// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[MUL0]] : f32 +// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[ADDARG]] : f32 // CHECK-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index // CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref -// CHECK-NEXT: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32 +// CHECK-NEXT: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32 // CHECK-NEXT: } // Epilogue: -// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[R]]#1 : f32 -// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[MUL1]] : f32 -// CHECK-NEXT: %[[MUL2:.*]] = arith.mulf %[[CSTF]], %[[ADD2]] : f32 -// CHECK-NEXT: return %[[MUL2]] : f32 +// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[R]]#1 : f32 +// CHECK-NEXT: return %[[ADD2]] : f32 func @backedge_different_stage(%A: memref) -> f32 { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -264,15 +261,13 @@ // CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]], // CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) { // CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[LARG]], %[[C]] : f32 -// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADD0]] : f32 // CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index // CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref -// CHECK-NEXT: scf.yield %[[MUL0]], %[[L2]] : f32, f32 +// CHECK-NEXT: scf.yield %[[ADD0]], %[[L2]] : f32, f32 // CHECK-NEXT: } // Epilogue: // CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32 -// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[ADD1]] : f32 -// CHECK-NEXT: return %[[MUL1]] : f32 +// CHECK-NEXT: return %[[ADD1]] : f32 func @backedge_same_stage(%A: memref) -> f32 { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index