[mlir] Add an Idempotent op trait.

An operation "f" with this trait is a unary, type-preserving op that satisfies
f(f(x)) == f(x). The trait's folder rewrites f(f(x)) to f(x). The patch also
corrects the IsInvolution doc comment (an involution satisfies f(f(x)) == x)
and adds a test op plus FileCheck tests for one, two and three applications.

NOTE(review): foldIdempotent matches the inner op by operation name only; ops
whose attributes parametrize the function would need an attribute check too.

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1720,6 +1720,8 @@
   NativeOpTrait<"ResultsBroadcastableShape">;
 // X op Y == Y op X
 def Commutative : NativeOpTrait<"IsCommutative">;
+// op op X == op X
+def Idempotent : NativeOpTrait<"IsIdempotent">;
 // op op X == X
 def Involution : NativeOpTrait<"IsInvolution">;
 // Op behaves like a constant.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -388,10 +388,12 @@
 // corresponding trait classes. This avoids them being template
 // instantiated/duplicated.
 namespace impl {
+OpFoldResult foldIdempotent(Operation *op);
 OpFoldResult foldInvolution(Operation *op);
 LogicalResult verifyZeroOperands(Operation *op);
 LogicalResult verifyOneOperand(Operation *op);
 LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
+LogicalResult verifyIsIdempotent(Operation *op);
 LogicalResult verifyIsInvolution(Operation *op);
 LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
 LogicalResult verifyOperandsAreFloatLike(Operation *op);
@@ -1012,7 +1014,7 @@
 };
 
 /// This class adds property that the operation is an involution.
-/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
+/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
 template <typename ConcreteType>
 class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
 public:
@@ -1033,6 +1035,28 @@
   }
 };
 
+/// This class adds the property that the operation is idempotent.
+/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
+template <typename ConcreteType>
+class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(ConcreteType::template hasTrait<OneResult>(),
+                  "expected operation to produce one result");
+    static_assert(ConcreteType::template hasTrait<OneOperand>(),
+                  "expected operation to take one operand");
+    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
+                  "expected operation to preserve type");
+    // Idempotence also requires the operation to be side effect free, but
+    // that check is currently disabled (see impl::verifyIsIdempotent).
+    return impl::verifyIsIdempotent(op);
+  }
+
+  static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
+    return impl::foldIdempotent(op);
+  }
+};
+
 /// This class verifies that all operands of the specified op have a float type,
 /// a vector thereof, or a tensor thereof.
 template <typename ConcreteType>
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -679,6 +679,16 @@
 // Op Trait implementations
 //===----------------------------------------------------------------------===//
 
+OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
+  auto *argumentOp = op->getOperand(0).getDefiningOp();
+  if (argumentOp && op->getName() == argumentOp->getName()) {
+    // f(f(x)) == f(x): our operand already is the inner op's result.
+    return op->getOperand(0);
+  }
+
+  return {};
+}
+
 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
   auto *argumentOp = op->getOperand(0).getDefiningOp();
   if (argumentOp && op->getName() == argumentOp->getName()) {
@@ -730,6 +740,14 @@
   return type;
 }
 
+LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) {
+  // FIXME: Add back check for no side effects on operation.
+  // Currently adding it would cause the shared library build
+  // to fail since there would be a dependency of IR on SideEffectInterfaces
+  // which is cyclical.
+  return success();
+}
+
 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) {
   // FIXME: Add back check for no side effects on operation.
   // Currently adding it would cause the shared library build
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -798,6 +798,13 @@
   let results = (outs I32);
 }
 
+def TestIdempotentTraitOp
+    : TEST_Op<"op_idempotent_trait",
+              [SameOperandsAndResultType, NoSideEffect, Idempotent]> {
+  let arguments = (ins I32:$op1);
+  let results = (outs I32);
+}
+
 def TestInvolutionTraitNoOperationFolderOp
     : TEST_Op<"op_involution_trait_no_operation_fold",
               [SameOperandsAndResultType, NoSideEffect, Involution]> {
diff --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir
--- a/mlir/test/mlir-tblgen/trait.mlir
+++ b/mlir/test/mlir-tblgen/trait.mlir
@@ -59,3 +59,37 @@
   // CHECK: return [[OP]]
   return %1: i32
 }
+
+//===----------------------------------------------------------------------===//
+// Test that idempotent ops fold f(f(x)) to f(x)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @testSingleIdempotent
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testSingleIdempotent(%arg0 : i32) -> i32 {
+  // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+  %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+  // CHECK: return [[IDEMPOTENT]]
+  return %0: i32
+}
+
+// CHECK-LABEL: func @testDoubleIdempotent
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testDoubleIdempotent(%arg0: i32) -> i32 {
+  // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+  %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+  %1 = "test.op_idempotent_trait"(%0) : (i32) -> i32
+  // CHECK: return [[IDEMPOTENT]]
+  return %1: i32
+}
+
+// CHECK-LABEL: func @testTripleIdempotent
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testTripleIdempotent(%arg0: i32) -> i32 {
+  // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+  %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+  %1 = "test.op_idempotent_trait"(%0) : (i32) -> i32
+  %2 = "test.op_idempotent_trait"(%1) : (i32) -> i32
+  // CHECK: return [[IDEMPOTENT]]
+  return %2: i32
+}