diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -178,20 +178,27 @@ namespace { struct TestLinalgGreedyFusion - : public PassWrapper { - void runOnFunction() override { + : public PassWrapper> { + void runOnOperation() override { MLIRContext *context = &getContext(); OwningRewritePatternList patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); patterns.insert(context); FrozenRewritePatternList frozenPatterns(std::move(patterns)); - while (succeeded(fuseLinalgOpsGreedily(getFunction()))) { - (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns); + bool repeat = true; + while (repeat) { + repeat = false; + for (FuncOp funcOp : getOperation().getOps()) { + if (succeeded(fuseLinalgOpsGreedily(funcOp))) { + repeat = true; + (void)applyPatternsAndFoldGreedily(funcOp, frozenPatterns); + } + } PassManager pm(context); pm.addPass(createLoopInvariantCodeMotionPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - LogicalResult res = pm.run(getFunction()->getParentOfType()); + LogicalResult res = pm.run(getOperation()); if (failed(res)) this->signalPassFailure(); }