diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -30,6 +30,7 @@ static constexpr int64_t kBufferAlignments = 128; class BufferizationAliasInfo; +class BufferizableOpInterface; struct BufferizationOptions; class BufferizationState; struct PostAnalysisStep; @@ -92,6 +93,21 @@ std::make_unique(std::forward(args)...)); } + /// Return `true` if `dialectFilter` is unset or allows the op's dialect. + bool isOpAllowListed(Operation *op) const { + // Use the OperationName: `getDialect()` is null for unregistered ops. + return !dialectFilter.hasValue() || + dialectFilter->contains(op->getName().getDialectNamespace()); + } + + /// Try to cast the given op to BufferizableOpInterface if the op is allow + /// listed. + BufferizableOpInterface dynCastBufferizableOp(Operation *op) const; + + /// Try to cast the given value to BufferizableOpInterface if the op is allow + /// listed. + BufferizableOpInterface dynCastBufferizableOp(Value value) const; + /// Helper functions for allocation, deallocation, memory copying. std::unique_ptr allocationFns; @@ -114,6 +130,16 @@ /// Registered post analysis steps. PostAnalysisStepList postAnalysisSteps; + + /// Only bufferize ops from dialects that are allow-listed by the filter. + /// All other ops are ignored. This option controls the scope of partial + /// bufferization. + /// + /// Note: If no filter is specified, all ops are bufferized (as long as they + /// implement BufferizableOpInterface). If a filter is specified, + /// `allowUnknownOps` should be enabled. Otherwise, bufferization would fail + /// when encountering an op that is forbidden by the filter. + Optional> dialectFilter; }; /// Specify fine-grain relationship between buffers to enable more analysis.
@@ -128,7 +154,8 @@ /// equivalence classes to support bufferization. class BufferizationAliasInfo { public: - explicit BufferizationAliasInfo(Operation *rootOp); + explicit BufferizationAliasInfo(Operation *rootOp, + const BufferizationOptions &options); // BufferizationAliasInfo should be passed as a reference. BufferizationAliasInfo(const BufferizationAliasInfo &) = delete; @@ -265,7 +292,7 @@ /// starting the traversal from Value 1, the resulting SetVector is: /// { 2, 7, 8, 5 } llvm::SetVector -findValueInReverseUseDefChain(Value value, +findValueInReverseUseDefChain(Value value, const BufferizationOptions &options, std::function condition); /// Find the Value of the last preceding write of a given Value. @@ -276,7 +303,7 @@ /// /// Note: When reaching an end of the reverse SSA use-def chain, that value /// is returned regardless of whether it is a memory write or not. -Value findLastPrecedingWrite(Value value); +Value findLastPrecedingWrite(Value value, const BufferizationOptions &options); /// Dialect-specific bufferization state. Analysis/bufferization information /// that is specific to ops from a certain dialect can be stored in derived @@ -300,7 +327,7 @@ class BufferizationState { public: BufferizationState(Operation *op, const BufferizationOptions &options) - : aliasInfo(op), options(options), builder(op->getContext()) {} + : aliasInfo(op, options), options(options), builder(op->getContext()) {} // BufferizationState should be passed as a reference. 
BufferizationState(const BufferizationState &) = delete; diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -252,6 +252,7 @@ /*methodName=*/"isNotConflicting", /*args=*/(ins "OpOperand *":$uRead, "OpOperand *":$uWrite, + "BufferizationState &":$state, "const BufferizationAliasInfo &":$aliasInfo), /*methodBody=*/"", /*defaultImplementation=*/[{ diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -78,7 +78,8 @@ // BufferizationAliasInfo //===----------------------------------------------------------------------===// -BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { +BufferizationAliasInfo::BufferizationAliasInfo( + Operation *rootOp, const BufferizationOptions &options) { rootOp->walk([&](Operation *op) { for (Value v : op->getResults()) if (v.getType().isa()) @@ -93,6 +94,8 @@ // Set up alias sets for OpResults that must bufferize in-place. This should // be done before making any other bufferization decisions. 
rootOp->walk([&](BufferizableOpInterface bufferizableOp) { + if (!options.isOpAllowListed(bufferizableOp)) + return WalkResult::skip(); for (OpResult opResult : bufferizableOp->getOpResults()) { if (opResult.getType().isa()) if (bufferizableOp.mustBufferizeInPlace(opResult)) { @@ -105,6 +108,7 @@ markInPlace(opResult); } } + return WalkResult::advance(); }); } @@ -197,6 +201,21 @@ } } +BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: + BufferizationOptions::dynCastBufferizableOp(Operation *op) const { + if (isOpAllowListed(op)) + return dyn_cast(op); + return nullptr; +} + +BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: + BufferizationOptions::dynCastBufferizableOp(Value value) const { + if (auto bufferizableOp = value.getDefiningOp()) + if (isOpAllowListed(bufferizableOp.getOperation())) + return bufferizableOp; + return nullptr; +} + /// Determine which OpOperand* will alias with `result` if the op is bufferized /// in place. Return an empty vector if the op is not bufferizable. SmallVector @@ -283,7 +302,8 @@ // further. llvm::SetVector mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain( - Value value, std::function condition) { + Value value, const BufferizationOptions &options, + std::function condition) { llvm::SetVector result, workingSet; workingSet.insert(value); @@ -296,7 +316,7 @@ OpResult opResult = value.cast(); SmallVector opOperands = getAliasingOpOperand(opResult); - if (opOperands.empty()) { + if (opOperands.empty() || !options.isOpAllowListed(value.getDefiningOp())) { result.insert(value); continue; } @@ -310,13 +330,13 @@ // Find the Value of the last preceding write of a given Value. 
Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite( - Value value) { + Value value, const BufferizationOptions &options) { SetVector result = - findValueInReverseUseDefChain(value, [](Value value) { + findValueInReverseUseDefChain(value, options, [&](Value value) { Operation *op = value.getDefiningOp(); if (!op) return true; - auto bufferizableOp = dyn_cast(op); + auto bufferizableOp = options.dynCastBufferizableOp(op); if (!bufferizableOp) return true; return bufferizableOp.isMemoryWrite(value.cast()); @@ -374,9 +394,8 @@ // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA // use-def chain, it returns that value, regardless of whether it is a // memory write or not. - Value lastWrite = findLastPrecedingWrite(operand); - if (auto bufferizableOp = - lastWrite.getDefiningOp()) + Value lastWrite = findLastPrecedingWrite(operand, options); + if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) if (!bufferizableOp.isMemoryWrite(lastWrite.cast())) skipCopy = true; // Do not copy if the copied data is never read. @@ -433,7 +452,7 @@ // Bufferize using `BufferizableOpInterface`. Interface implementations are // responsible for bufferizing nested ops. - if (auto bufferizableOp = dyn_cast(op)) { + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) { b.setInsertionPoint(op); return bufferizableOp.bufferize(b, state); } @@ -646,8 +665,7 @@ if (options.allowUnknownOps) { // `tensor` was not bufferized yet. This should never happen with // bufferizable ops. - assert(!tensor.getDefiningOp() && - "tensor is not mapped"); + assert(!options.dynCastBufferizableOp(tensor) && "tensor is not mapped"); // Insert to_memref op. 
OpBuilder b(tensor.getContext()); setInsertionPointAfter(b, tensor); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -256,13 +256,13 @@ aliasInfo.applyOnAliases(value, [&](Value v) { // Query BufferizableOpInterface to see if the OpResult is writable. // TODO: Out-of-place bufferized OpResult could be considered writable. - if (auto bufferizableOp = v.getDefiningOp()) + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v)) if (bufferizableOp && bufferizableOp.isWritable(v, state)) return; // Query BufferizableOpInterface to see if the BlockArgument is writable. if (auto bbArg = v.dyn_cast()) - if (auto bufferizableOp = dyn_cast( + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp( bbArg.getOwner()->getParentOp())) if (bufferizableOp.isWritable(bbArg, state)) return; @@ -324,11 +324,12 @@ /// A conflict is: According to SSA use-def chains, a read R is supposed to read /// the result of a write W1. But because of bufferization decisions, R actually /// reads another write W2. -static bool -hasReadAfterWriteInterference(const DenseSet &usesRead, - const DenseSet &usesWrite, - const DominanceInfo &domInfo, - const BufferizationAliasInfo &aliasInfo) { +static bool hasReadAfterWriteInterference( + const DenseSet &usesRead, + const DenseSet &usesWrite, const DominanceInfo &domInfo, + BufferizationState &state, const BufferizationAliasInfo &aliasInfo) { + const BufferizationOptions &options = state.getOptions(); + for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); @@ -341,7 +342,7 @@ // In the above example, if uRead is the OpOperand of reading_op, lastWrite // is %0. 
Note that operations that create an alias but do not write (such // as ExtractSliceOp) are skipped. - Value lastWrite = findLastPrecedingWrite(uRead->get()); + Value lastWrite = findLastPrecedingWrite(uRead->get(), options); // Look for conflicting memory writes. Potential conflicts are writes to an // alias that have been decided to bufferize inplace. @@ -370,15 +371,15 @@ continue; // No conflict if the op interface says so. - if (auto bufferizableOp = dyn_cast(readingOp)) - if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, + if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) + if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state, aliasInfo)) continue; if (conflictingWritingOp != readingOp) if (auto bufferizableOp = - dyn_cast(conflictingWritingOp)) - if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, + options.dynCastBufferizableOp(conflictingWritingOp)) + if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state, aliasInfo)) continue; @@ -452,7 +453,7 @@ /// involving aliases of the given OpOperand are checked. 
bool wouldCreateReadAfterWriteInterference( OpOperand &operand, OpResult result, const DominanceInfo &domInfo, - const BufferizationAliasInfo &aliasInfo, + BufferizationState &state, const BufferizationAliasInfo &aliasInfo, bool checkConsistencyOnly = false) { #ifndef NDEBUG if (result) { @@ -496,7 +497,8 @@ if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); - return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo); + return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state, + aliasInfo); } /// Return true if bufferizing `opOperand` inplace with `opResult` would create @@ -555,7 +557,7 @@ bool foundInterference = wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) || - wouldCreateReadAfterWriteInterference(operand, result, domInfo, + wouldCreateReadAfterWriteInterference(operand, result, domInfo, state, aliasInfo); if (foundInterference) @@ -603,7 +605,7 @@ for (Operation *op : reverse(ops)) for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) - if (auto bufferizableOp = dyn_cast(op)) + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand)) if (failed(bufferizableInPlaceAnalysisImpl( opOperand, opResult, aliasInfo, state, domInfo))) @@ -633,9 +635,10 @@ /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. static void equivalenceAnalysis(SmallVector &ops, - BufferizationAliasInfo &aliasInfo) { + BufferizationAliasInfo &aliasInfo, + const BufferizationOptions &options) { for (Operation *op : ops) - if (auto bufferizableOp = dyn_cast(op)) + if (auto bufferizableOp = options.dynCastBufferizableOp(op)) for (OpResult opResult : op->getOpResults()) if (opResult.getType().isa()) if (aliasInfo.isInPlace(opResult)) { @@ -652,7 +655,8 @@ /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained /// in `op`. 
static void equivalenceAnalysis(Operation *op, - BufferizationAliasInfo &aliasInfo) { + BufferizationAliasInfo &aliasInfo, + const BufferizationOptions &options) { // Traverse ops in PostOrder: Nested ops first, then enclosing ops. SmallVector ops; op->walk([&](Operation *op) { @@ -662,21 +666,23 @@ ops.push_back(op); }); - equivalenceAnalysis(ops, aliasInfo); + equivalenceAnalysis(ops, aliasInfo, options); } /// Assert that the current bufferization decisions are consistent. static LogicalResult checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo, + BufferizationState &state, const BufferizationAliasInfo &aliasInfo) { + const BufferizationOptions &options = state.getOptions(); Operation *inconsistentOp = nullptr; WalkResult walkResult = op->walk([&](Operation *op) { - if (auto bufferizableOp = dyn_cast(op)) + if (auto bufferizableOp = options.dynCastBufferizableOp(op)) for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) { OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand); if (wouldCreateReadAfterWriteInterference( - opOperand, opResult, domInfo, aliasInfo, + opOperand, opResult, domInfo, state, aliasInfo, /*checkConsistencyOnly=*/true)) { // This error can happen for two reasons. Either the input IR // already has a read-after-write conflict. Or certain @@ -723,14 +729,14 @@ DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.aliasInfo; - if (failed(checkAliasInfoConsistency(op, domInfo, aliasInfo))) + if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) return failure(); // If the analysis fails, just return. 
if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo, options.analysisFuzzerSeed))) return failure(); - equivalenceAnalysis(op, aliasInfo); + equivalenceAnalysis(op, aliasInfo, options); auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) { for (const std::unique_ptr &step : steps) { @@ -740,7 +746,7 @@ // Analyze ops that were created by the PostAnalysisStep. if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) return failure(); - equivalenceAnalysis(newOps, aliasInfo); + equivalenceAnalysis(newOps, aliasInfo, options); } return success(); }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -402,6 +402,7 @@ std::function rewriteFunc, SmallVector &newOps) { OpBuilder b(op->getContext()); + const BufferizationOptions &options = state.getOptions(); WalkResult status = op->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { @@ -410,7 +411,7 @@ continue; SetVector maybeInitTensor = - findValueInReverseUseDefChain(operand.get(), [&](Value val) { + findValueInReverseUseDefChain(operand.get(), options, [&](Value val) { // Continue traversal until this function returns true. OpResult opResult = val.dyn_cast(); if (!opResult) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -286,6 +286,7 @@ /// Return true if `value` is originating from an ExtractSliceOp that matches /// the given InsertSliceOp. 
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, + const BufferizationOptions &options, Value value, InsertSliceOp insertOp) { auto condition = [&](Value val) { if (auto extractOp = val.getDefiningOp()) @@ -294,7 +295,7 @@ return false; }; - return llvm::all_of(findValueInReverseUseDefChain(value, condition), + return llvm::all_of(findValueInReverseUseDefChain(value, options, condition), condition); } @@ -326,7 +327,7 @@ } bool isNotConflicting(Operation *op, OpOperand *uRead, - OpOperand *uConflictingWrite, + OpOperand *uConflictingWrite, BufferizationState &state, const BufferizationAliasInfo &aliasInfo) const { Operation *readingOp = uRead->getOwner(); Operation *conflictingWritingOp = uConflictingWrite->getOwner(); @@ -343,8 +344,8 @@ // TODO: Use insertSliceOp.getDestOpOperand etc. when available. if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(), - insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), + uConflictingWrite->get(), insertSliceOp)) // Case 1: The main insight is that InsertSliceOp reads only part of // the destination tensor. The overwritten area is not read. 
If // uConflictingWrite writes into exactly the memory location that is @@ -361,7 +362,8 @@ if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), uRead->get(), + insertSliceOp)) // Case 2: The read of the source tensor and the write to the dest // tensor via an InsertSliceOp is not a conflict if the read is // reading exactly that part of an equivalent tensor that the @@ -394,8 +396,8 @@ if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && aliasInfo.areEquivalentBufferizedValues(uRead->get(), insertSliceOp.source()) && - hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(), - insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), + insertSliceOp.source(), insertSliceOp)) return true; return false; diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -85,6 +85,10 @@ *this, "analysis-fuzzer-seed", llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"), llvm::cl::init(0)}; + ListOption dialectFilter{ + *this, "dialect-filter", + llvm::cl::desc("Bufferize only ops from the specified dialects"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; }; } // namespace @@ -104,6 +108,12 @@ options.testAnalysisOnly = testAnalysisOnly; options.analysisFuzzerSeed = analysisFuzzerSeed; + if (dialectFilter.hasValue()) { + options.dialectFilter.emplace(); + for (const std::string &dialectNamespace : dialectFilter) + options.dialectFilter->insert(dialectNamespace); + } + Operation *op = getFunction().getOperation(); if (failed(runComprehensiveBufferize(op, options))) return;