Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
Show All 14 Lines | |||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" | #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" | ||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" | ||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" | #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" | ||||
#include "mlir/Dialect/Bufferization/Transforms/Passes.h" | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" | ||||
using namespace mlir; | using namespace mlir; | ||||
using namespace mlir::bufferization; | using namespace mlir::bufferization; | ||||
/// Resolve all operands that are also used inside of repetitive regions of the | |||||
/// same op. Such cases are not fully supported by One-Shot Bufferize. | |||||
/// | |||||
/// E.g.: | |||||
/// %r = scf.for ... iter_args(%t = %tensor) -> tensor<?xf32> { | |||||
/// "some_use"(%tensor) | |||||
/// ... | |||||
/// } | |||||
/// | |||||
/// Is converted to: | |||||
/// %tensor_copy = bufferization.alloc_tensor copy(%tensor) | |||||
/// %r = scf.for ... iter_args(%t = %tensor) -> tensor<?xf32> { | |||||
/// "some_use"(%tensor_copy) | |||||
/// ... | |||||
/// } | |||||
static void | |||||
resolveUsesInRepetitiveRegions(Operation *op, | |||||
const BufferizationOptions &options) { | |||||
IRRewriter rewriter(op->getContext()); | |||||
AnalysisState state(options); | |||||
// Look for repetitive ops (loops). | |||||
op->walk([&](RegionBranchOpInterface regionBranchOp) { | |||||
// Skip non-bufferizable ops. | |||||
auto bufferizableOp = options.dynCastBufferizableOp(regionBranchOp); | |||||
if (!bufferizableOp) | |||||
return WalkResult::advance(); | |||||
// Find all operands that are also used inside of a repetitve region of this | |||||
// op. | |||||
for (OpOperand &opOperand : regionBranchOp->getOpOperands()) { | |||||
Value operand = opOperand.get(); | |||||
// Skip non-tensor operands. | |||||
if (!operand.getType().isa<TensorType>()) | |||||
continue; | |||||
// Skip operands that do not bufferize to memory writes. | |||||
if (!bufferizableOp.bufferizesToMemoryWrite(opOperand, state)) | |||||
continue; | |||||
// Gather all uses inside repetitive regions. | |||||
SmallVector<OpOperand *> usesInsideRegion; | |||||
for (OpOperand &use : operand.getUses()) { | |||||
Operation *owner = use.getOwner(); | |||||
if (!regionBranchOp->isProperAncestor(owner)) | |||||
continue; | |||||
for (Region &r : regionBranchOp->getRegions()) { | |||||
if (r.findAncestorOpInRegion(*owner) && | |||||
regionBranchOp.isRepetitiveRegion(r.getRegionNumber())) { | |||||
jreiffers: You could stop iterating after this. Or consider using llvm::any_of, that way you can get rid… | |||||
usesInsideRegion.push_back(&use); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
// Nothing to do if the operand is not used inside a repetitive region. | |||||
if (usesInsideRegion.empty()) | |||||
continue; | |||||
// Insert a tensor copy and replace all uses inside of repetitive regions. | |||||
rewriter.setInsertionPoint(regionBranchOp); | |||||
auto tensorCopy = rewriter.create<AllocTensorOp>( | |||||
regionBranchOp->getLoc(), operand.getType().cast<TensorType>(), | |||||
/*dynamicSizes=*/ValueRange(), | |||||
/*copy=*/operand, /*memory_space=*/IntegerAttr()); | |||||
for (OpOperand *use : usesInsideRegion) | |||||
use->set(tensorCopy); | |||||
} | |||||
return WalkResult::advance(); | |||||
}); | |||||
} | |||||
LogicalResult mlir::bufferization::insertTensorCopies( | LogicalResult mlir::bufferization::insertTensorCopies( | ||||
Operation *op, const OneShotBufferizationOptions &options) { | Operation *op, const OneShotBufferizationOptions &options) { | ||||
// Preprocessing: Resolve currently unsupported bufferization cases. | |||||
resolveUsesInRepetitiveRegions(op, options); | |||||
OneShotAnalysisState state(op, options); | OneShotAnalysisState state(op, options); | ||||
// Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize | // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize | ||||
// analysis depending on whether function boundary bufferization is enabled or | // analysis depending on whether function boundary bufferization is enabled or | ||||
// not. | // not. | ||||
if (options.bufferizeFunctionBoundaries) { | if (options.bufferizeFunctionBoundaries) { | ||||
if (failed(analyzeModuleOp(cast<ModuleOp>(op), state))) | if (failed(analyzeModuleOp(cast<ModuleOp>(op), state))) | ||||
return failure(); | return failure(); | ||||
} else { | } else { | ||||
▲ Show 20 Lines • Show All 95 Lines • Show Last 20 Lines |
You could stop iterating after this. Or consider using llvm::any_of, that way you can get rid of the continue above as well.
You might also get rid of usesInsideRegion completely, if you initialize tensorCopy lazily:
Not 100% sure it's nicer though.