diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -408,7 +408,7 @@
 // AndIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative]> {
+def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative, Idempotent]> {
   let summary = "integer binary and";
   let description = [{
     The `andi` operation takes two operands and returns one result, each of
@@ -436,7 +436,7 @@
 // OrIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative]> {
+def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative, Idempotent]> {
   let summary = "integer binary or";
   let description = [{
     The `ori` operation takes two operands and returns one result, each of these
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1945,7 +1945,7 @@
   NativeOpTrait<"ResultsBroadcastableShape">;
 // X op Y == Y op X
 def Commutative : NativeOpTrait<"IsCommutative">;
-// op op X == op X
+// op op X == op X (unary) / X op X == X (binary)
 def Idempotent : NativeOpTrait<"IsIdempotent">;
 // op op X == X
 def Involution : NativeOpTrait<"IsInvolution">;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1090,15 +1090,17 @@
 };
 
 /// This class adds property that the operation is idempotent.
-/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
+/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
+/// or a binary operation "g" that satisfies g(x, x) = x.
 template <typename ConcreteType>
 class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
     static_assert(ConcreteType::template hasTrait<OneResult>(),
                   "expected operation to produce one result");
-    static_assert(ConcreteType::template hasTrait<OneOperand>(),
-                  "expected operation to take one operand");
+    static_assert(ConcreteType::template hasTrait<OneOperand>() ||
+                      ConcreteType::template hasTrait<NOperands<2>::Impl>(),
+                  "expected operation to take one or two operands");
     static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
                   "expected operation to preserve type");
     // Idempotent requires the operation to be side effect free as well
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -494,9 +494,6 @@
   APInt intValue;
   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
     return getLhs();
-  /// and(x, x) -> x
-  if (getLhs() == getRhs())
-    return getRhs();
 
   return constFoldBinaryOp<IntegerAttr>(operands,
                                         [](APInt a, APInt b) { return a & b; });
@@ -510,9 +507,6 @@
   /// or(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
-  /// or(x, x) -> x
-  if (getLhs() == getRhs())
-    return getRhs();
   /// or(x, <all ones>) -> <all ones>
   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
     if (rhsAttr.getValue().isAllOnes())
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -629,9 +629,14 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
-  auto *argumentOp = op->getOperand(0).getDefiningOp();
-  if (argumentOp && op->getName() == argumentOp->getName()) {
-    // Replace the outer operation output with the inner operation.
+  if (op->getNumOperands() == 1) {
+    auto *argumentOp = op->getOperand(0).getDefiningOp();
+    if (argumentOp && op->getName() == argumentOp->getName()) {
+      // Replace the outer operation output with the inner operation.
+      return op->getOperand(0);
+    }
+  } else if (op->getOperand(0) == op->getOperand(1)) {
+    // Binary case (IsIdempotent asserts exactly two operands): g(x, x) -> x.
     return op->getOperand(0);
   }
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1001,6 +1001,13 @@
   let results = (outs I32);
 }
 
+def TestIdempotentTraitBinaryOp
+ : TEST_Op<"op_idempotent_trait_binary",
+           [SameOperandsAndResultType, NoSideEffect, Idempotent]> {
+  let arguments = (ins I32:$op1, I32:$op2);
+  let results = (outs I32);
+}
+
 def TestInvolutionTraitNoOperationFolderOp
  : TEST_Op<"op_involution_trait_no_operation_fold",
            [SameOperandsAndResultType, NoSideEffect, Involution]> {
diff --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir
--- a/mlir/test/mlir-tblgen/trait.mlir
+++ b/mlir/test/mlir-tblgen/trait.mlir
@@ -93,3 +93,20 @@
   // CHECK: return [[IDEMPOTENT]]
   return %2: i32
 }
+
+// CHECK-LABEL: func @testBinaryIdempotent
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testBinaryIdempotent(%arg0 : i32) -> i32 {
+  %0 = "test.op_idempotent_trait_binary"(%arg0, %arg0) : (i32, i32) -> i32
+  // CHECK: return [[ARG0]]
+  return %0: i32
+}
+
+// CHECK-LABEL: func @testBinaryIdempotentNoFold
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
+func @testBinaryIdempotentNoFold(%arg0 : i32, %arg1 : i32) -> i32 {
+  // CHECK: [[RES:%.+]] = "test.op_idempotent_trait_binary"([[ARG0]], [[ARG1]])
+  %0 = "test.op_idempotent_trait_binary"(%arg0, %arg1) : (i32, i32) -> i32
+  // CHECK: return [[RES]]
+  return %0: i32
+}