NOTE(review): This patch was recovered from a copy in which newlines and
angle-bracket template arguments had been stripped. Line breaks and the
leading whitespace of context lines were reconstructed; verify them against
the target files. The template argument of `value.isa<...>()` on the first
Ops.cpp context line could not be recovered from the damaged copy and must
be restored from the target file before `git apply` will accept that hunk.

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -751,6 +751,7 @@
     %a = maxsi %b, %c : i64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -775,6 +776,7 @@
     %a = maxui %b, %c : i64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -824,6 +826,7 @@
     %a = minsi %b, %c : i64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -848,6 +851,7 @@
     %a = minui %b, %c : i64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -939,6 +939,142 @@
   return value.isa();
 }
 
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `maxsi`:
+///   maxsi(x, x)       -> x
+///   maxsi(x, MAX_INT) -> MAX_INT (signed maximum of the operand width)
+///   maxsi(x, MIN_INT) -> x
+/// Otherwise defers to constFoldBinaryOp to fold constant operands. Only the
+/// right-hand operand is checked for the MAX_INT/MIN_INT special cases.
+OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // maxsi(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  // Match the right-hand constant once and reuse it for both checks.
+  APInt intValue;
+  if (matchPattern(rhs(), m_ConstantInt(&intValue))) {
+    // maxsi(x,MAX_INT) -> MAX_INT
+    if (intValue.isMaxSignedValue())
+      return rhs();
+    // maxsi(x, MIN_INT) -> x
+    if (intValue.isMinSignedValue())
+      return lhs();
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](const APInt &a, const APInt &b) {
+                                          return llvm::APIntOps::smax(a, b);
+                                        });
+}
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `maxui`:
+///   maxui(x, x)        -> x
+///   maxui(x, all-ones) -> all-ones (the unsigned maximum)
+///   maxui(x, 0)        -> x
+/// Otherwise defers to constFoldBinaryOp to fold constant operands. Only the
+/// right-hand operand is checked for the all-ones/zero special cases.
+OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // maxui(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  // Match the right-hand constant once and reuse it for both checks.
+  APInt intValue;
+  if (matchPattern(rhs(), m_ConstantInt(&intValue))) {
+    // maxui(x,MAX_INT) -> MAX_INT
+    if (intValue.isMaxValue())
+      return rhs();
+    // maxui(x, MIN_INT) -> x
+    if (intValue.isMinValue())
+      return lhs();
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](const APInt &a, const APInt &b) {
+                                          return llvm::APIntOps::umax(a, b);
+                                        });
+}
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `minsi`:
+///   minsi(x, x)       -> x
+///   minsi(x, MIN_INT) -> MIN_INT (signed minimum of the operand width)
+///   minsi(x, MAX_INT) -> x
+/// Otherwise defers to constFoldBinaryOp to fold constant operands. Only the
+/// right-hand operand is checked for the MIN_INT/MAX_INT special cases.
+OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // minsi(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  // Match the right-hand constant once and reuse it for both checks.
+  APInt intValue;
+  if (matchPattern(rhs(), m_ConstantInt(&intValue))) {
+    // minsi(x,MIN_INT) -> MIN_INT
+    if (intValue.isMinSignedValue())
+      return rhs();
+    // minsi(x, MAX_INT) -> x
+    if (intValue.isMaxSignedValue())
+      return lhs();
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](const APInt &a, const APInt &b) {
+                                          return llvm::APIntOps::smin(a, b);
+                                        });
+}
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `minui`:
+///   minui(x, x)        -> x
+///   minui(x, 0)        -> 0
+///   minui(x, all-ones) -> x (all-ones is the unsigned maximum)
+/// Otherwise defers to constFoldBinaryOp to fold constant operands. Only the
+/// right-hand operand is checked for the zero/all-ones special cases.
+OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // minui(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  // Match the right-hand constant once and reuse it for both checks.
+  APInt intValue;
+  if (matchPattern(rhs(), m_ConstantInt(&intValue))) {
+    // minui(x,MIN_INT) -> MIN_INT
+    if (intValue.isMinValue())
+      return rhs();
+    // minui(x, MAX_INT) -> x
+    if (intValue.isMaxValue())
+      return lhs();
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](const APInt &a, const APInt &b) {
+                                          return llvm::APIntOps::umin(a, b);
+                                        });
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -60,3 +60,67 @@
   %res = select %arg0, %false, %true : i1
   return %res : i1
 }
+
+// CHECK-LABEL: test_maxsi
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127
+// CHECK: %[[X:.+]] = maxsi %arg0, %[[C0]]
+// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
+func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) {
+  %maxIntCst = arith.constant 127 : i8
+  %minIntCst = arith.constant -128 : i8
+  %c0 = arith.constant 42 : i8
+  %0 = maxsi %arg0, %arg0 : i8
+  %1 = maxsi %arg0, %maxIntCst : i8
+  %2 = maxsi %arg0, %minIntCst : i8
+  %3 = maxsi %arg0, %c0 : i8
+  return %0, %1, %2, %3 : i8, i8, i8, i8
+}
+
+// CHECK-LABEL: test_maxui
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1
+// CHECK: %[[X:.+]] = maxui %arg0, %[[C0]]
+// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
+func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) {
+  %maxIntCst = arith.constant 255 : i8
+  %minIntCst = arith.constant 0 : i8
+  %c0 = arith.constant 42 : i8
+  %0 = maxui %arg0, %arg0 : i8
+  %1 = maxui %arg0, %maxIntCst : i8
+  %2 = maxui %arg0, %minIntCst : i8
+  %3 = maxui %arg0, %c0 : i8
+  return %0, %1, %2, %3 : i8, i8, i8, i8
+}
+
+// CHECK-LABEL: test_minsi
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128
+// CHECK: %[[X:.+]] = minsi %arg0, %[[C0]]
+// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
+func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) {
+  %maxIntCst = arith.constant 127 : i8
+  %minIntCst = arith.constant -128 : i8
+  %c0 = arith.constant 42 : i8
+  %0 = minsi %arg0, %arg0 : i8
+  %1 = minsi %arg0, %maxIntCst : i8
+  %2 = minsi %arg0, %minIntCst : i8
+  %3 = minsi %arg0, %c0 : i8
+  return %0, %1, %2, %3 : i8, i8, i8, i8
+}
+
+// CHECK-LABEL: test_minui
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0
+// CHECK: %[[X:.+]] = minui %arg0, %[[C0]]
+// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
+func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
+  %maxIntCst = arith.constant 255 : i8
+  %minIntCst = arith.constant 0 : i8
+  %c0 = arith.constant 42 : i8
+  %0 = minui %arg0, %arg0 : i8
+  %1 = minui %arg0, %maxIntCst : i8
+  %2 = minui %arg0, %minIntCst : i8
+  %3 = minui %arg0, %c0 : i8
+  return %0, %1, %2, %3 : i8, i8, i8, i8
+}