diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1724,6 +1724,8 @@ def Idempotent : NativeOpTrait<"IsIdempotent">; // op op X == X def Involution : NativeOpTrait<"IsInvolution">; +// op op' X == op' op X for every per-element unary op'. +def ElementPreserving : NativeOpTrait<"IsElementPreserving">; // Op behaves like a constant. def ConstantLike : NativeOpTrait<"ConstantLike">; // Op behaves like a function. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -395,6 +395,7 @@ LogicalResult verifyNOperands(Operation *op, unsigned numOperands); LogicalResult verifyIsIdempotent(Operation *op); LogicalResult verifyIsInvolution(Operation *op); +LogicalResult verifyIsElementPreserving(Operation *op); LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands); LogicalResult verifyOperandsAreFloatLike(Operation *op); LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op); @@ -1057,6 +1058,26 @@ } }; +/// This class adds the property that the operation is element-preserving, +/// i.e. it is a unary operation "f" (one operand, one result) such that +/// f(g(x)) == g(f(x)) for every per-element unary op g. +template +class IsElementPreserving + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + static_assert(ConcreteType::template hasTrait(), + "expected operation to produce one result"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to take one operand"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to preserve type"); + // FIXME: ElementPreserving also requires the operation to be side-effect + // free, but that requirement is not verified here yet. 
+ return impl::verifyIsElementPreserving(op); + } +}; + /// This class verifies that all operands of the specified op have a float type, /// a vector thereof, or a tensor thereof. template diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -231,6 +231,50 @@ } }; +/// TraitRewritePattern is a wrapper around RewritePattern that allows for +/// matching and rewriting against an op that matches a trait +template