diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -96,6 +96,8 @@
 struct SCFFuseProducerOfSliceResult {
   OpResult origProducer;       // Original untiled producer.
   Value tiledAndFusedProducer; // Tile and fused producer value.
+  // Operations generated by tiling and fusing the producer.
+  SmallVector<Operation *> tiledOps;
 };
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -604,7 +604,8 @@
     }
   }
   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
-                                           tileAndFuseResult->tiledValues[0]};
+                                           tileAndFuseResult->tiledValues[0],
+                                           tileAndFuseResult->tiledOps};
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.
@@ -612,7 +613,8 @@
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
     MutableArrayRef<scf::ForOp> loops) {
-  auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
+  auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
+      fusedProducerInfo;
   SmallVector<Value> initValues;
   FailureOr<Value> initValue = tensor::getOrCreateDestination(
       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
@@ -623,8 +625,15 @@
         yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
                          resultOffsets, resultSizes, loops);
   }
-  if (auto dstStyleProducer =
-          fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
+  for (auto tileAndFusedOp : tileAndFusedOps) {
+    auto dstStyleProducer =
+        dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
+    if (!dstStyleProducer)
+      continue;
+    // getDpsInitOperand asserts on an out-of-range index; skip tiled ops
+    // that have no init operand for the fused result's position.
+    if (fusableProducer.getResultNumber() >= dstStyleProducer.getNumDpsInits())
+      continue;
     Value dstValue =
         dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
             ->get();