Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
Show First 20 Lines • Show All 155 Lines • ▼ Show 20 Lines | WalkResult result = op->walk([&](Operation *op) { | ||||
if (failed(bufferizableOp.resolveConflicts(rewriter, state))) | if (failed(bufferizableOp.resolveConflicts(rewriter, state))) | ||||
return WalkResult::interrupt(); | return WalkResult::interrupt(); | ||||
return WalkResult::advance(); | return WalkResult::advance(); | ||||
}); | }); | ||||
return failure(result.wasInterrupted()); | return failure(result.wasInterrupted()); | ||||
} | } | ||||
namespace { | |||||
struct TensorCopyInsertionPass | |||||
: public bufferization::impl::TensorCopyInsertionBase< | |||||
TensorCopyInsertionPass> { | |||||
TensorCopyInsertionPass() : options(llvm::None) {} | |||||
TensorCopyInsertionPass(const OneShotBufferizationOptions &options) | |||||
: options(options) {} | |||||
void getDependentDialects(DialectRegistry ®istry) const override { | |||||
registry.insert<bufferization::BufferizationDialect>(); | |||||
} | |||||
void runOnOperation() override { | |||||
if (options) { | |||||
if (failed(insertTensorCopies(getOperation(), *options))) | |||||
signalPassFailure(); | |||||
} else { | |||||
OneShotBufferizationOptions options; | |||||
options.allowReturnAllocs = allowReturnAllocs; | |||||
options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; | |||||
options.createDeallocs = createDeallocs; | |||||
if (mustInferMemorySpace) | |||||
options.defaultMemorySpace = None; | |||||
if (failed(insertTensorCopies(getOperation(), options))) | |||||
signalPassFailure(); | |||||
} | |||||
} | |||||
private: | |||||
Optional<OneShotBufferizationOptions> options; | |||||
}; | |||||
} // namespace | |||||
std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() { | |||||
return std::make_unique<TensorCopyInsertionPass>(); | |||||
} | |||||
std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass( | |||||
const OneShotBufferizationOptions &options) { | |||||
return std::make_unique<TensorCopyInsertionPass>(options); | |||||
} |