diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -653,6 +653,7 @@
     %a = arith.maxf %b, %c : f64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -696,6 +697,7 @@
     %a = arith.minf %b, %c : f64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -91,6 +91,32 @@
   }
 };
 
+/// The matcher that matches a constant scalar / vector splat / tensor splat
+/// float operation and binds the constant float value.
+struct constant_float_op_binder {
+  FloatAttr::ValueType *bind_value;
+
+  /// Creates a matcher instance that binds the value to bv if match succeeds.
+  constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
+
+  bool match(Operation *op) {
+    Attribute attr;
+    if (!constant_op_binder<Attribute>(&attr).match(op))
+      return false;
+    auto type = op->getResult(0).getType();
+
+    if (type.isa<FloatType>())
+      return attr_value_binder<FloatAttr>(bind_value).match(attr);
+    if (type.isa<VectorType, RankedTensorType>()) {
+      if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+        return attr_value_binder<FloatAttr>(bind_value)
+            .match(splatAttr.getSplatValue<Attribute>());
+      }
+    }
+    return false;
+  }
+};
+
 /// The matcher that matches a constant scalar / vector splat / tensor splat
 /// integer operation and binds the constant integer value.
 struct constant_int_op_binder {
@@ -276,6 +302,13 @@
   return const_cast<Pattern &>(pattern).match(op);
 }
 
+/// Matches a constant holding a scalar/vector/tensor float (splat) and
+/// writes the float value to bind_value.
+inline detail::constant_float_op_binder
+m_ConstantFloat(FloatAttr::ValueType *bind_value) {
+  return detail::constant_float_op_binder(bind_value);
+}
+
 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
 /// writes the integer value to bind_value.
 inline detail::constant_int_op_binder
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -194,12 +194,12 @@
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
 
-  // add(sub(a, b), b) -> a
+  // addi(sub(a, b), b) -> a
   if (auto sub = getLhs().getDefiningOp<SubIOp>())
     if (getRhs() == sub.getRhs())
       return sub.getLhs();
 
-  // add(b, sub(a, b)) -> a
+  // addi(b, sub(a, b)) -> a
   if (auto sub = getRhs().getDefiningOp<SubIOp>())
     if (getLhs() == sub.getRhs())
       return sub.getLhs();
@@ -575,7 +575,33 @@
 // AddFOp
 //===----------------------------------------------------------------------===//
 
+// Returns whether 'value' is a scalar constant, or a vector/tensor splat
+// constant, whose float value is negative zero (-0.0).
+static bool isNegZeroFloat(Value value) {
+  APFloat floatValue(0.0f);
+  return matchPattern(value, m_ConstantFloat(&floatValue)) &&
+         floatValue.isNegZero();
+}
+
+// Returns whether 'value' is a scalar constant, or a vector/tensor splat
+// constant, whose float value is positive zero (+0.0).
+static bool isPosZeroFloat(Value value) {
+  APFloat floatValue(0.0f);
+  return matchPattern(value, m_ConstantFloat(&floatValue)) &&
+         floatValue.isPosZero();
+}
+
 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
+  // Only -0.0 is an additive identity for every x: with x == -0.0,
+  // x + (+0.0) == +0.0, so addf(x, +0.0) must not be folded to x.
+  // addf(x, -0) -> x
+  if (isNegZeroFloat(getRhs()))
+    return getLhs();
+
+  // addf(-0, x) -> x
+  if (isNegZeroFloat(getLhs()))
+    return getRhs();
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
 }
@@ -585,10 +611,33 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
+  // Only +0.0 is a right identity of subtraction for every x: with
+  // x == -0.0, x - (-0.0) == +0.0, so subf(x, -0.0) must not be folded to x.
+  // subf(x, +0) -> x
+  if (isPosZeroFloat(getRhs()))
+    return getLhs();
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// MaxFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "maxf takes two operands");
+
+  // maxf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // Canonicalizing a constant LHS to the RHS is left to the Commutative trait.
+  return constFoldBinaryOp<FloatAttr>(
+      operands,
+      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+}
+
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -643,6 +692,23 @@
                                         });
 }
 
+//===----------------------------------------------------------------------===//
+// MinFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "minf takes two operands");
+
+  // minf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // Canonicalizing a constant LHS to the RHS is left to the Commutative trait.
+  return constFoldBinaryOp<FloatAttr>(
+      operands,
+      [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
+}
+
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
@@ -701,7 +767,23 @@
 // MulFOp
 //===----------------------------------------------------------------------===//
 
+// Returns whether 'value' is a scalar constant, or a vector/tensor splat
+// constant, whose float value is exactly 1.0 (and not, e.g., -1.0).
+static bool isOneFloat(Value value) {
+  APFloat floatValue(0.0f);
+  return matchPattern(value, m_ConstantFloat(&floatValue)) &&
+         floatValue.isExactlyValue(1.0);
+}
+
 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
+  // mulf(x, 1) -> x
+  if (isOneFloat(getRhs()))
+    return getLhs();
+
+  // mulf(1, x) -> x
+  if (isOneFloat(getLhs()))
+    return getRhs();
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
 }
@@ -711,6 +793,10 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
+  // divf(x, 1) -> x
+  if (isOneFloat(getRhs()))
+    return getLhs();
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
 }
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -619,6 +619,116 @@
 
 // -----
 
+// CHECK-LABEL: @test_minf(
+func @test_minf(%arg0 : f32) -> (f32, f32) {
+  // CHECK-NEXT: %[[C0:.+]] = arith.constant
+  // CHECK-NEXT: %[[X:.+]] = arith.minf %arg0, %[[C0]]
+  // CHECK-NEXT: return %[[X]], %arg0
+  %c0 = arith.constant 0.0 : f32
+  %0 = arith.minf %c0, %arg0 : f32
+  %1 = arith.minf %arg0, %arg0 : f32
+  return %0, %1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_maxf(
+func @test_maxf(%arg0 : f32) -> (f32, f32) {
+  // CHECK-NEXT: %[[C0:.+]] = arith.constant
+  // CHECK-NEXT: %[[X:.+]] = arith.maxf %arg0, %[[C0]]
+  // CHECK-NEXT: return %[[X]], %arg0
+  %c0 = arith.constant 0.0 : f32
+  %0 = arith.maxf %c0, %arg0 : f32
+  %1 = arith.maxf %arg0, %arg0 : f32
+  return %0, %1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_addf(
+func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
+  // CHECK-DAG: %[[CP0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2.000000e+00 : f32
+  // CHECK: %[[X:.+]] = arith.addf %arg0, %[[CP0]] : f32
+  // CHECK: return %arg0, %arg0, %[[C2]], %[[X]]
+  %cn0 = arith.constant -0.0 : f32
+  %cp0 = arith.constant 0.0 : f32
+  %c1 = arith.constant 1.0 : f32
+  %0 = arith.addf %arg0, %cn0 : f32
+  %1 = arith.addf %cn0, %arg0 : f32
+  %2 = arith.addf %c1, %c1 : f32
+  // Must not fold: (-0.0) + (+0.0) == +0.0.
+  %3 = arith.addf %arg0, %cp0 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_subf(
+func @test_subf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG: %[[CN0:.+]] = arith.constant -0.000000e+00 : f32
+  // CHECK-DAG: %[[CM1:.+]] = arith.constant -1.000000e+00 : f32
+  // CHECK: %[[X:.+]] = arith.subf %arg0, %[[CN0]] : f32
+  // CHECK: return %arg0, %[[CM1]], %[[X]]
+  %cp0 = arith.constant 0.0 : f32
+  %cn0 = arith.constant -0.0 : f32
+  %c1 = arith.constant 1.0 : f32
+  %0 = arith.subf %arg0, %cp0 : f32
+  %1 = arith.subf %cp0, %c1 : f32
+  // Must not fold: (-0.0) - (-0.0) == +0.0.
+  %2 = arith.subf %arg0, %cn0 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulf(
+func @test_mulf(%arg0 : f32) -> (f32, f32, f32, f32) {
+  // CHECK-DAG: %[[C4:.+]] = arith.constant 4.000000e+00 : f32
+  // CHECK-DAG: %[[CM1:.+]] = arith.constant -1.000000e+00 : f32
+  // CHECK: %[[X:.+]] = arith.mulf %arg0, %[[CM1]] : f32
+  // CHECK: return %arg0, %arg0, %[[C4]], %[[X]]
+  %c1 = arith.constant 1.0 : f32
+  %c2 = arith.constant 2.0 : f32
+  %cm1 = arith.constant -1.0 : f32
+  %0 = arith.mulf %arg0, %c1 : f32
+  %1 = arith.mulf %c1, %arg0 : f32
+  %2 = arith.mulf %c2, %c2 : f32
+  // Must not fold: -1.0 is its own inverse but is not the identity.
+  %3 = arith.mulf %arg0, %cm1 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_divf(
+func @test_divf(%arg0 : f32) -> (f32, f32) {
+  // CHECK-NEXT: %[[C5:.+]] = arith.constant 5.000000e-01
+  // CHECK-NEXT: return %arg0, %[[C5]]
+  %c1 = arith.constant 1.0 : f32
+  %c2 = arith.constant 2.0 : f32
+  %0 = arith.divf %arg0, %c1 : f32
+  %1 = arith.divf %c1, %c2 : f32
+  return %0, %1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_cmpf(
+func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
+// CHECK-DAG: %[[T:.*]] = arith.constant true
+// CHECK-DAG: %[[F:.*]] = arith.constant false
+// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
+  %nan = arith.constant 0x7fffffff : f32
+  %0 = arith.cmpf olt, %nan, %arg0 : f32
+  %1 = arith.cmpf olt, %arg0, %nan : f32
+  %2 = arith.cmpf ugt, %nan, %arg0 : f32
+  %3 = arith.cmpf ugt, %arg0, %nan : f32
+  return %0, %1, %2, %3 : i1, i1, i1, i1
+}
+
+// -----
+
 // CHECK-LABEL: @constant_FPtoUI(
 func @constant_FPtoUI() -> i32 {
   // CHECK: %[[C0:.+]] = arith.constant 2 : i32
@@ -678,30 +788,3 @@
   %res = arith.sitofp %c0 : i32 to f32
   return %res : f32
 }
-
-// -----
-// CHECK-LABEL: @constant_MinMax(
-func @constant_MinMax(%arg0 : f32) -> f32 {
-  // CHECK: %[[const:.+]] = arith.constant
-  // CHECK: %[[min:.+]] = arith.minf %arg0, %[[const]] : f32
-  // CHECK: %[[res:.+]] = arith.maxf %[[min]], %[[const]] : f32
-  // CHECK: return %[[res]]
-  %const = arith.constant 0.0 : f32
-  %min = arith.minf %const, %arg0 : f32
-  %res = arith.maxf %const, %min : f32
-  return %res : f32
-}
-
-// -----
-// CHECK-LABEL: @cmpf_nan(
-func @cmpf_nan(%arg0 : f32) -> (i1, i1, i1, i1) {
-// CHECK-DAG: %[[T:.*]] = arith.constant true
-// CHECK-DAG: %[[F:.*]] = arith.constant false
-// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
-  %nan = arith.constant 0x7fffffff : f32
-  %0 = arith.cmpf olt, %nan, %arg0 : f32
-  %1 = arith.cmpf olt, %arg0, %nan : f32
-  %2 = arith.cmpf ugt, %nan, %arg0 : f32
-  %3 = arith.cmpf ugt, %arg0, %nan : f32
-  return %0, %1, %2, %3 : i1, i1, i1, i1
-}