diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1724,6 +1724,8 @@
 def Idempotent : NativeOpTrait<"IsIdempotent">;
 // op op X == X
 def Involution : NativeOpTrait<"IsInvolution">;
+// op op' X == op' op X for all per-element unary op'.
+def ValuePreserving : NativeOpTrait<"IsValuePreserving">;
 // Op behaves like a constant.
 def ConstantLike : NativeOpTrait<"ConstantLike">;
 // Op behaves like a function.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -395,6 +395,7 @@
 LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
 LogicalResult verifyIsIdempotent(Operation *op);
 LogicalResult verifyIsInvolution(Operation *op);
+LogicalResult verifyIsValuePreserving(Operation *op);
 LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
 LogicalResult verifyOperandsAreFloatLike(Operation *op);
 LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
@@ -1057,6 +1058,25 @@
   }
 };
 
+/// This class adds the property that the operation is value preserving.
+/// This means a unary to unary operation "f" that commutes with every
+/// per-element unary operation "g", i.e. f(g(x)) = g(f(x)).
+template <typename ConcreteType>
+class IsValuePreserving : public TraitBase<ConcreteType, IsValuePreserving> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(ConcreteType::template hasTrait<OneResult>(),
+                  "expected operation to produce one result");
+    static_assert(ConcreteType::template hasTrait<OneOperand>(),
+                  "expected operation to take one operand");
+    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
+                  "expected operation to preserve type");
+    // ValuePreserving requires the operation to be side effect free as well
+    // but currently this check is under a FIXME and is not actually done.
+    return impl::verifyIsValuePreserving(op);
+  }
+};
+
 /// This class verifies that all operands of the specified op have a float type,
 /// a vector thereof, or a tensor thereof.
 template <typename ConcreteType>
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -680,7 +680,18 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
-  auto *argumentOp = op->getOperand(0).getDefiningOp();
+  Operation *argumentOp = op;
+  // Skip over value preserving ops to see if we can find a match for the op
+  // to fold. Value preserving ops are assumed to commute with `op`, so for an
+  // idempotent f and value preserving g: f(g(f(x))) == g(f(f(x))) == g(f(x)),
+  // which means `op` can be replaced by its own operand.
+  do {
+    // Idempotent and value preserving ops all have one arg.
+    assert(argumentOp->getNumOperands() == 1);
+    argumentOp = argumentOp->getOperand(0).getDefiningOp();
+  } while (argumentOp && op->getName() != argumentOp->getName() &&
+           argumentOp->hasTrait<OpTrait::IsValuePreserving>());
+
   if (argumentOp && op->getName() == argumentOp->getName()) {
     // Replace the outer operation output with the inner operation.
     return op->getOperand(0);
@@ -756,6 +767,14 @@
   return success();
 }
 
+LogicalResult OpTrait::impl::verifyIsValuePreserving(Operation *op) {
+  // FIXME: Add back check for no side effects on operation.
+  // Currently adding it would cause the shared library build
+  // to fail since there would be a dependency of IR on SideEffectInterfaces
+  // which is cyclical.
+  return success();
+}
+
 LogicalResult
 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) {
   for (auto opType : op->getOperandTypes()) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -870,6 +870,20 @@
   let hasFolder = 1;
 }
 
+def TestValuePreservingTraitOp
+    : TEST_Op<"op_value_preserving_trait",
+              [SameOperandsAndResultType, NoSideEffect, ValuePreserving]> {
+  let arguments = (ins I32:$op1);
+  let results = (outs I32);
+}
+
+def TestNoValuePreservingTraitOp
+    : TEST_Op<"op_no_value_preserving_trait",
+              [SameOperandsAndResultType, NoSideEffect]> {
+  let arguments = (ins I32:$op1);
+  let results = (outs I32);
+}
+
 def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> {
   let arguments = (ins I32);
   let results = (outs I32);
diff --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir
--- a/mlir/test/mlir-tblgen/trait.mlir
+++ b/mlir/test/mlir-tblgen/trait.mlir
@@ -93,3 +93,52 @@
   // CHECK: return [[IDEMPOTENT]]
   return %2: i32
 }
+
+//===----------------------------------------------------------------------===//
+// Test that folding idempotent ops mixed with value preserving ops
+// works correctly
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @testSingleValuePreserving
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testSingleValuePreserving(%arg0 : i32) -> i32 {
+  // CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+  %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+  // CHECK: [[VALUEPRES:%.+]] = "test.op_value_preserving_trait"([[IDEMPOTENT]])
+  %1 = "test.op_value_preserving_trait"(%0) : (i32) -> i32
+  %2 = "test.op_idempotent_trait"(%1) : (i32) -> i32
+  // CHECK: return [[VALUEPRES]]
+  return %2: i32
+}
+
+// CHECK-LABEL: func @testChainValuePreserving
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testChainValuePreserving(%arg0 : i32) -> i32 {
+  // CHECK: [[IDEMPOTENT1:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+  %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+  // CHECK: [[VALUEPRES1:%.+]] = "test.op_value_preserving_trait"([[IDEMPOTENT1]])
+  %1 = "test.op_value_preserving_trait"(%0) : (i32) -> i32
+  // CHECK: [[VALUEPRES2:%.+]] = "test.op_value_preserving_trait"([[VALUEPRES1]])
+  %2 = "test.op_value_preserving_trait"(%1) : (i32) -> i32
+  // CHECK: [[VALUEPRES3:%.+]] = "test.op_value_preserving_trait"([[VALUEPRES2]])
+  %3 = "test.op_value_preserving_trait"(%2) : (i32) -> i32
+  %4 = "test.op_idempotent_trait"(%3) : (i32) -> i32
+  // CHECK: return [[VALUEPRES3]]
+  return %4: i32
+}
+
+// CHECK-LABEL: func @testNoValuePreserving
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testNoValuePreserving(%arg0 : i32) -> i32 {
+  // CHECK: [[IDEMPOTENT1:%.+]] = "test.op_idempotent_trait"([[ARG0]])
+  %0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
+  // CHECK: [[VALUEPRES:%.+]] = "test.op_value_preserving_trait"([[IDEMPOTENT1]])
+  %1 = "test.op_value_preserving_trait"(%0) : (i32) -> i32
+  // CHECK: [[NOVALUEPRES:%.+]] = "test.op_no_value_preserving_trait"([[VALUEPRES]])
+  %2 = "test.op_no_value_preserving_trait"(%1) : (i32) -> i32
+  // The op without ValuePreserving stops the search, so this must not fold.
+  // CHECK: [[IDEMPOTENT2:%.+]] = "test.op_idempotent_trait"([[NOVALUEPRES]])
+  %3 = "test.op_idempotent_trait"(%2) : (i32) -> i32
+  // CHECK: return [[IDEMPOTENT2]]
+  return %3: i32
+}