diff --git a/mlir/test/Transforms/test-operation-folder-commutative.mlir b/mlir/test/Transforms/test-operation-folder-commutative.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-operation-folder-commutative.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt --pass-pipeline="func.func(test-patterns)" %s | FileCheck %s + +// CHECK-LABEL: func @test_reorder_constants_and_match +func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) { + // CHECK: %[[CST:.+]] = arith.constant 43 + %cst = arith.constant 43 : i32 + // CHECK: return %[[CST]] + %y = "test.op_commutative2"(%cst, %arg0) : (i32, i32) -> i32 + %x = "test.op_commutative2"(%y, %arg0) : (i32, i32) -> i32 + return %x : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1089,6 +1089,11 @@ let results = (outs I32); } +def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> { + let arguments = (ins I32:$op1, I32:$op2); + let results = (outs I32); +} + def TestIdempotentTraitOp : TEST_Op<"op_idempotent_trait", [SameOperandsAndResultType, NoSideEffect, Idempotent]> { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -124,18 +124,41 @@ } }; +/// This pattern matches test.op_commutative2 whose first operand is another +/// test.op_commutative2 with a constant as its right operand, and folds it by +/// replacing it with that constant. It checks that patterns are applied after +/// the Commutative trait has moved constant operands to the right.
+struct FolderCommutativeOp2WithConstant + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TestCommutative2Op op, + PatternRewriter &rewriter) const override { + auto operand = + dyn_cast_or_null(op->getOperand(0).getDefiningOp()); + if (!operand) + return failure(); + Attribute constInput; + if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) + return failure(); + rewriter.replaceOp(op, operand->getOperand(1)); + return success(); + } +}; + struct TestPatternDriver : public PassWrapper> { StringRef getArgument() const final { return "test-patterns"; } StringRef getDescription() const final { return "Run test dialect patterns"; } void runOnOperation() override { mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); // Verify named pattern is generated with expected name. patterns.add( - &getContext()); + FolderInsertBeforePreviouslyFoldedConstantPattern, + FolderCommutativeOp2WithConstant>(&getContext()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); }