mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
// ... first 451 lines elided ...
mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
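  // The first `split` tile sizes below are zero; presumably those outer loop
  // dimensions were already tiled by the preceding (elided) steps, so only the
  // remaining inner dimensions are tiled here before fusing the producers of
  // the input operands.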
  innerTileSizes.append(split, 0);
  innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
    return failure();
  fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
  return tileLoopNest;
}

namespace {
struct LinalgTileAndFuseTensorOps
    : public LinalgTileAndFuseTensorOpsBase<LinalgTileAndFuseTensorOps> {

  void notifyFailure(StringRef message) {
    llvm::errs() << " - LinalgTileAndFuseTensorOps: " << message << "\n";
    signalPassFailure();
  }

  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    OpBuilder b(funcOp.getContext());

    // Heuristic to find a good operation to tile and start fusion. Walk all
    // operations and select the one with the maximal backward slice of fusion
    // candidates.
    LinalgOp rootOp = nullptr;
    int64_t numFusionCandidates = -1;
    funcOp.walk([&](LinalgOp linalgOp) {
      SetVector<Operation *> backwardSlice;
      getBackwardSlice(linalgOp, &backwardSlice);
      int64_t backwardSliceSize = count_if(
          backwardSlice, [](Operation *op) { return isa<LinalgOp>(op); });
      if (backwardSliceSize > numFusionCandidates) {
        rootOp = linalgOp;
        numFusionCandidates = backwardSliceSize;
      }
    });
    if (!rootOp)
      return notifyFailure("expect to find a root operation");

    // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
    if (tileSizes.size() < rootOp.getNumLoops())
      return notifyFailure("expect #tile sizes >= #loops");

    // Check `tileInterchange` contains no entries or as many as `tileSizes`.
    if (!tileInterchange.empty() &&
        tileInterchange.size() != tileSizes.size()) {
      return notifyFailure(
          "expect the number of tile sizes and interchange dims to match");
    }

    // Copy the `tileSizes` and `tileInterchange` prefixes needed to tile
    // `rootOp` or use the identity interchange if `tileInterchange` is empty.
    SmallVector<int64_t> rootTileSizes(
        tileSizes.begin(), tileSizes.begin() + rootOp.getNumLoops());
    SmallVector<int64_t> rootInterchange =
        tileInterchange.empty()
            ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
            : SmallVector<int64_t>(tileInterchange.begin(),
                                   tileInterchange.begin() +
                                       rootOp.getNumLoops());

    // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
    // It has to be a permutation since the tiling cannot tile the same loop
    // dimension multiple times.
    if (!isPermutation(rootInterchange))
      return notifyFailure(
          "expect the tile interchange permutes the root loops");

    // Tile `rootOp` and fuse its producers.
    FailureOr<TileLoopNest> tileLoopNest =
        tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange);
    if (failed(tileLoopNest))
      return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");

    // Replace all uses of the tiled loop operation.
    rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTileAndFuseTensorOpsPass() {
  return std::make_unique<LinalgTileAndFuseTensorOps>();
}