diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -42,7 +42,9 @@ /// Simplify the operations within the given regions. bool simplify(MutableArrayRef<Region> regions); - /// Add the given operation to the worklist. + /// Add the given operation and all of its ancestors that are nested within + /// the scope to the worklist. Ops outside of the scope are not added. + /// Drivers may override this; e.g., some only add the op itself. virtual void addToWorklist(Operation *op); /// Pop the next operation from the worklist. @@ -56,6 +58,9 @@ void finalizeRootUpdate(Operation *op) override; protected: + /// Add the given operation to the worklist. + void addSingleOpToWorklist(Operation *op); + // Implement the hook for inserting operations, and make sure that newly // inserted ops are added to the worklist for processing. void notifyOperationInserted(Operation *op) override; @@ -101,6 +106,10 @@ GreedyRewriteConfig config; private: + /// Only ops within this scope are simplified. This is set at the beginning + /// of `simplify()` to the current scope the rewriter operates on. + DenseSet<Region *> scope; + #ifndef NDEBUG /// A logger used to emit information during the application process. llvm::ScopedPrinter logger{llvm::dbgs()}; @@ -119,6 +128,9 @@ } bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) { + for (Region &r : regions) + scope.insert(&r); + #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -306,6 +318,24 @@ } void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { + // Gather potential ancestors while looking for a "scope" parent region. + SmallVector<Operation *, 8> ancestors; + ancestors.push_back(op); + while (Region *region = op->getParentRegion()) { + if (scope.contains(region)) { + // The gathered ops (the op and its ancestors) are all within the scope.
+ for (Operation *ancestor : ancestors) + addSingleOpToWorklist(ancestor); + break; + } + op = region->getParentOp(); + if (!op) + break; + ancestors.push_back(op); + } +} + +void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { // Check to see if the worklist already contains this op. if (worklistMap.count(op)) return; @@ -540,7 +570,8 @@ /// This is a specialized GreedyPatternRewriteDriver to apply patterns and /// perform folding for a supplied set of ops. It repeatedly simplifies while /// restricting the rewrites to only the provided set of ops or optionally -/// to those directly affected by it (result users or operand providers). +/// to those directly affected by it (result users or operand providers). Parent +/// ops are not considered. class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver(MLIRContext *ctx, @@ -553,7 +584,7 @@ void addToWorklist(Operation *op) override { if (!strictMode || strictModeFilteredOps.contains(op)) - GreedyPatternRewriteDriver::addToWorklist(op); + GreedyPatternRewriteDriver::addSingleOpToWorklist(op); } private: diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -22,13 +22,13 @@ // CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32 // CHECK: return %[[RESULT]] -// ---- +// ----- // CHECK-LABEL: func @ctlz func.func @ctlz(%arg: i32) -> i32 { - // CHECK: %[[C0:.+]] = arith.constant 0 : i32 - // CHECK: %[[C32:.+]] = arith.constant 32 : i32 - // CHECK: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32 + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32 // CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]]) // CHECK: %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]] // 
CHECK: scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]] diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -417,7 +417,7 @@ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32 // Prologue: // CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref // CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[L0]], %[[CSTF]] : f32 @@ -426,19 +426,22 @@ // CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]] // CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]], // CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) { -// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[ADDARG]] : f32 +// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADDARG]], %[[CSTF]] : f32 +// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[MUL0]] : f32 // CHECK-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index // CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref -// CHECK-NEXT: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32 +// CHECK-NEXT: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32 // CHECK-NEXT: } // Epilogue: -// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[R]]#1 : f32 -// CHECK-NEXT: return %[[ADD2]] : f32 +// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[R]]#1, %[[CSTF]] : f32 +// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[MUL1]] : f32 +// CHECK-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD2]], %[[CSTF]] : f32 +// CHECK-NEXT: return %[[MUL2]] : f32 func.func @backedge_different_stage(%A: memref) -> f32 { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - %cf = arith.constant 1.0 : f32 + %cf = 
arith.constant 2.0 : f32 %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) { %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32 @@ -455,7 +458,7 @@ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32 // Prologue: // CHECK: %[[L0:.*]] = scf.execute_region // CHECK-NEXT: memref.load %[[A]][%[[C0]]] : memref @@ -467,23 +470,26 @@ // CHECK: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]] // CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]], // CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) { +// CHECK: %[[MUL0:.*]] = arith.mulf %[[ADDARG]], %[[CSTF]] : f32 // CHECK: %[[ADD1:.*]] = scf.execute_region -// CHECK-NEXT: arith.addf %[[LARG]], %[[ADDARG]] : f32 +// CHECK-NEXT: arith.addf %[[LARG]], %[[MUL0]] : f32 // CHECK: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index // CHECK: %[[L2:.*]] = scf.execute_region // CHECK-NEXT: memref.load %[[A]][%[[IV2]]] : memref -// CHECK: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32 +// CHECK: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32 // CHECK-NEXT: } // Epilogue: +// CHECK: %[[MUL1:.*]] = arith.mulf %[[R]]#1, %[[CSTF]] : f32 // CHECK: %[[ADD2:.*]] = scf.execute_region -// CHECK-NEXT: arith.addf %[[R]]#2, %[[R]]#1 : f32 -// CHECK: return %[[ADD2]] : f32 +// CHECK-NEXT: arith.addf %[[R]]#2, %[[MUL1]] : f32 +// CHECK: %[[MUL2:.*]] = arith.mulf %[[ADD2]], %[[CSTF]] : f32 +// CHECK: return %[[MUL2]] : f32 func.func @region_backedge_different_stage(%A: memref) -> f32 { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : 
index - %cf = arith.constant 1.0 : f32 + %cf = arith.constant 2.0 : f32 %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) { %A_elem = scf.execute_region -> f32 { %A_elem1 = memref.load %A[%i0] : memref @@ -507,7 +513,7 @@ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32 // Prologue: // CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref // Kernel: @@ -515,18 +521,20 @@ // CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]], // CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) { // CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[LARG]], %[[C]] : f32 +// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CSTF]] : f32 // CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index // CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref -// CHECK-NEXT: scf.yield %[[ADD0]], %[[L2]] : f32, f32 +// CHECK-NEXT: scf.yield %[[MUL0]], %[[L2]] : f32, f32 // CHECK-NEXT: } // Epilogue: // CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32 -// CHECK-NEXT: return %[[ADD1]] : f32 +// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CSTF]] : f32 +// CHECK-NEXT: return %[[MUL1]] : f32 func.func @backedge_same_stage(%A: memref) -> f32 { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - %cf = arith.constant 1.0 : f32 + %cf = arith.constant 2.0 : f32 %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) { %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32 @@ -538,7 +546,7 @@ // ----- -// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref, %[[ARG1:.+]]: memref, %[[ARG2:.+]]: 
memref) { +// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref, %[[ARG1:.+]]: memref, %[[ARG2:.+]]: memref, %[[CF:.*]]: f32) { // CHECK: %[[C0:.+]] = arith.constant 0 : // CHECK: %[[C3:.+]] = arith.constant 3 : // CHECK: %[[C1:.+]] = arith.constant 1 : @@ -590,11 +598,10 @@ __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 2 } -func.func @pipeline_op_with_region(%A: memref, %B: memref, %result: memref) { +func.func @pipeline_op_with_region(%A: memref, %B: memref, %result: memref, %cf: f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - %cf = arith.constant 1.0 : f32 %a_buf = memref.alloc() : memref<2x8xf32> %b_buf = memref.alloc() : memref<2x8xf32> scf.for %i0 = %c0 to %c4 step %c1 { diff --git a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir --- a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir +++ b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -test-patterns="max-iterations=1" | FileCheck %s +// RUN: mlir-opt %s -test-patterns="max-iterations=1" \ +// RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s // CHECK-LABEL: func @add_to_worklist_after_inplace_update() func.func @add_to_worklist_after_inplace_update() { @@ -10,3 +11,16 @@ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () return } + +// ----- + +// CHECK-LABEL: func @add_ancestors_to_worklist() +func.func @add_ancestors_to_worklist() { + // CHECK: "foo.maybe_eligible_op"() {eligible} : () -> index + // CHECK-NEXT: "test.one_region_op"() + "test.one_region_op"() ({ + %0 = "foo.maybe_eligible_op" () : () -> (index) + "foo.yield"(%0) : (index) -> () + }) {hoist_eligible_ops}: () -> () + return +} diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -140,6 +140,8 @@ auto attrCycle = 
op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker); if (attrCycle && attrStage) { + // TODO: Index can be out-of-bounds if ops of the loop body disappear + // due to folding. schedule[attrCycle.getInt()] = std::make_pair(op, unsigned(attrStage.getInt())); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -167,6 +167,41 @@ } }; +/// This pattern adds an "eligible" attribute to "foo.maybe_eligible_op". +struct MakeOpEligible : public RewritePattern { + MakeOpEligible(MLIRContext *context) + : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->hasAttr("eligible")) + return failure(); + rewriter.updateRootInPlace( + op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); + return success(); + } +}; + +/// This pattern hoists eligible ops out of a "test.one_region_op" that carries +/// the "hoist_eligible_ops" attribute. +struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(test::OneRegionOp op, + PatternRewriter &rewriter) const override { + // Only apply to ops that opted in; bail out on empty bodies. + if (!op->hasAttr("hoist_eligible_ops") || op.getRegion().empty()) + return failure(); + Operation *terminator = op.getRegion().front().getTerminator(); + if (terminator->getNumOperands() == 0) + return failure(); + // The yielded value may be a block argument, which has no defining op. + Operation *toBeHoisted = terminator->getOperand(0).getDefiningOp(); + if (!toBeHoisted || toBeHoisted->getParentOp() != op) + return failure(); + if (!toBeHoisted->hasAttr("eligible")) + return failure(); + // Hoisting removes an op from the region of `op`, i.e., `op` is modified. + // Notify the rewriter so that the driver is aware of the change. 
+ rewriter.updateRootInPlace(op, [&]() { toBeHoisted->moveBefore(op); }); + return success(); + } +}; + struct TestPatternDriver : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) @@ -183,7 +218,8 @@ // Verify named pattern is generated with expected name. patterns.add(&getContext()); + FolderCommutativeOp2WithConstant, HoistEligibleOps, + MakeOpEligible>(&getContext());  // Additional patterns for testing the GreedyPatternRewriteDriver. patterns.insert>(&getContext());