NOTE(review): this patch was reconstructed from a copy in which line breaks had
been collapsed and angle-bracket template arguments had been dropped. Hunk line
counts were re-derived from the hunk headers. The names `ConstantOp`,
`FoldingPattern` and `TestNamedPatternRule`, as well as the exact indentation of
context lines, are inferred rather than visible in that copy -- confirm them
against the target tree before applying.

diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -233,6 +233,14 @@
     if (auto *constOp =
             tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
                                    res.getType(), op->getLoc())) {
+      // Ensure that this constant dominates the operation it replaces. This
+      // may not happen automatically if the operation being folded was
+      // inserted before the constant within the insertion block.
+      if (constOp->getBlock() == op->getBlock() &&
+          !constOp->isBeforeInBlock(op)) {
+        constOp->moveBefore(op);
+      }
+
       results.push_back(constOp->getResult(0));
       continue;
     }
diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -10,3 +10,15 @@
   %0 = "test.op_in_place_fold_anchor"(%c42) : (i32) -> (i32)
   return %0 : i32
 }
+
+func @test_fold_before_previously_folded_op() -> (i32, i32) {
+  // When folding, two constants will be generated and uniqued. Check that the
+  // uniqued constant properly dominates both uses.
+  // CHECK: %[[CST:.+]] = constant true
+  // CHECK-NEXT: "test.cast"(%[[CST]]) : (i1) -> i32
+  // CHECK-NEXT: "test.cast"(%[[CST]]) : (i1) -> i32
+
+  %0 = "test.cast"() {test_fold_before_previously_folded_op} : () -> (i32)
+  %1 = "test.cast"() {test_fold_before_previously_folded_op} : () -> (i32)
+  return %0, %1 : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -99,6 +99,32 @@
   }
 };
 
+/// This pattern creates a foldable operation at the entry point of the block.
+/// This tests the situation where the operation folder will need to replace an
+/// operation with a previously created constant that does not initially
+/// dominate the operation to replace.
+struct FolderInsertBeforePreviouslyFoldedConstantPattern
+    : public OpRewritePattern<TestCastOp> {
+public:
+  using OpRewritePattern<TestCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TestCastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only fire on ops explicitly tagged for this test.
+    if (!op->hasAttr("test_fold_before_previously_folded_op"))
+      return failure();
+    // Insert at the block start so that the new ops can end up before a
+    // constant created earlier (the non-dominating case this test targets).
+    rewriter.setInsertionPointToStart(op->getBlock());
+
+    auto constOp =
+        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getBoolAttr(true));
+    rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
+                                            Value(constOp));
+    return success();
+  }
+};
+
 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
   StringRef getArgument() const final { return "test-patterns"; }
   StringRef getDescription() const final { return "Run test dialect patterns"; }
@@ -107,7 +133,9 @@
     populateWithGenerated(patterns);
 
     // Verify named pattern is generated with expected name.
-    patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext());
+    patterns.add<FoldingPattern, TestNamedPatternRule,
+                 FolderInsertBeforePreviouslyFoldedConstantPattern>(
+        &getContext());
 
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }