diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -21,11 +21,17 @@
 
 /// Options for analysis-enabled bufferization.
 struct OneShotBufferizationOptions : public BufferizationOptions {
+  enum class AnalysisHeuristic { BottomUp, TopDown };
+
   OneShotBufferizationOptions() = default;
 
   /// Specifies whether returning newly allocated memrefs should be allowed.
   /// Otherwise, a pass failure is triggered.
   bool allowReturnAllocs = false;
+
+  /// The heuristic controls the order in which ops are traversed during the
+  /// analysis.
+  AnalysisHeuristic analysisHeuristic = AnalysisHeuristic::BottomUp;
 };
 
 /// The BufferizationAliasInfo class maintains a list of buffer aliases and
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -284,6 +284,9 @@
     Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
            /*default=*/"0",
            "Test only: Analyze ops in random order with a given seed (fuzzer)">,
+    Option<"analysisHeuristic", "analysis-heuristic", "std::string",
+           /*default=*/"\"bottom-up\"",
+           "Heuristic that controls the IR traversal during analysis (bottom-up, top-down)">,
     Option<"bufferizeFunctionBoundaries", "bufferize-function-boundaries",
            "bool", /*default=*/"0",
            "Bufferize function boundaries (experimental).">,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -172,6 +172,18 @@
   llvm_unreachable("invalid layout map option");
 }
 
+/// Parses the value of the `analysis-heuristic` pass option. The string comes
+/// from the pass pipeline (i.e., user input), so an unknown value is reported
+/// as a fatal error instead of being treated as unreachable code.
+static OneShotBufferizationOptions::AnalysisHeuristic
+parseHeuristicOption(const std::string &s) {
+  if (s == "bottom-up")
+    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
+  if (s == "top-down")
+    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
+  llvm::report_fatal_error("invalid analysis-heuristic option");
+}
+
 struct OneShotBufferizePass
     : public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
   OneShotBufferizePass() = default;
@@ -193,6 +205,7 @@
       opt.allowReturnAllocs = allowReturnAllocs;
       opt.allowUnknownOps = allowUnknownOps;
       opt.analysisFuzzerSeed = analysisFuzzerSeed;
+      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
       opt.copyBeforeWrite = copyBeforeWrite;
       opt.createDeallocs = createDeallocs;
       opt.functionBoundaryTypeConversion =
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -675,10 +675,35 @@
                              const BufferizationAliasInfo &aliasInfo,
                              const AnalysisState &state) {
   aliasInfo.applyOnAliases(root, [&](Value alias) {
-    for (auto &use : alias.getUses())
-      // Read to a value that aliases root.
-      if (state.bufferizesToMemoryRead(use))
+    for (auto &use : alias.getUses()) {
+      // Read of a value that aliases root.
+      if (state.bufferizesToMemoryRead(use)) {
         res.insert(&use);
+        continue;
+      }
+
+      // Read of a dependent value in the SSA use-def chain. E.g.:
+      //
+      // %0 = ...
+      // %1 = tensor.extract_slice %0 {not_analyzed_yet}
+      // "read"(%1)
+      //
+      // In the above example, getAliasingReads(%0) includes the first OpOperand
+      // of the tensor.extract_slice op. The extract_slice itself does not read
+      // but its aliasing result is eventually fed into an op that does.
+      //
+      // Note: This is considered a "read" only if the use does not bufferize to
+      // a memory write. (We already ruled out memory reads. In case of a memory
+      // write, the buffer would be entirely overwritten; in the above example
+      // there would then be no flow of data from the extract_slice operand to
+      // its result's uses.)
+      if (!state.bufferizesToMemoryWrite(use)) {
+        SmallVector<OpResult> opResults = state.getAliasingOpResult(use);
+        if (llvm::any_of(opResults,
+                         [&](OpResult r) { return state.isValueRead(r); }))
+          res.insert(&use);
+      }
+    }
   });
 }
 
@@ -837,14 +862,34 @@
     llvm::shuffle(ops.begin(), ops.end(), g);
   }
 
-  // Walk ops in reverse for better interference analysis.
-  for (Operation *op : reverse(ops))
+  // Analyze a single op.
+  auto analyzeOp = [&](Operation *op) {
     for (OpOperand &opOperand : op->getOpOperands())
       if (opOperand.get().getType().isa<TensorType>())
         if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
           if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo,
                                                      state, domInfo)))
             return failure();
+    return success();
+  };
+
+  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
+      static_cast<const OneShotBufferizationOptions &>(state.getOptions())
+          .analysisHeuristic;
+  if (heuristic == OneShotBufferizationOptions::AnalysisHeuristic::BottomUp) {
+    // Default: Walk ops in reverse for better interference analysis.
+    for (Operation *op : reverse(ops))
+      if (failed(analyzeOp(op)))
+        return failure();
+  } else if (heuristic ==
+             OneShotBufferizationOptions::AnalysisHeuristic::TopDown) {
+    // Walk ops in forward order.
+    for (Operation *op : ops)
+      if (failed(analyzeOp(op)))
+        return failure();
+  } else {
+    llvm_unreachable("unsupported heuristic");
+  }
 
   return success();
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -5,6 +5,9 @@
 // RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
 // RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
 
+// Run with top-down analysis.
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-unknown-ops analysis-heuristic=top-down" -split-input-file | FileCheck %s --check-prefix=CHECK-TOP-DOWN-ANALYSIS
+
 // Test without analysis: Insert a copy on every buffer write.
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-unknown-ops copy-before-write" -split-input-file | FileCheck %s --check-prefix=CHECK-COPY-BEFORE-WRITE
 
@@ -174,3 +177,24 @@
   // CHECK: return %[[r]]
   return %0 : tensor<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @read_of_alias
+// CHECK-TOP-DOWN-ANALYSIS-LABEL: func @read_of_alias
+func.func @read_of_alias(%t: tensor<100xf32>, %pos1: index, %pos2: index,
+                         %pos3: index, %pos4: index, %sz: index, %f: f32)
+  -> (f32, f32)
+{
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: memref.copy
+  // CHECK: memref.store %{{.*}}, %[[alloc]]
+  // CHECK-TOP-DOWN-ANALYSIS: %[[alloc:.*]] = memref.alloc
+  // CHECK-TOP-DOWN-ANALYSIS: memref.copy
+  // CHECK-TOP-DOWN-ANALYSIS: memref.store %{{.*}}, %[[alloc]]
+  %0 = tensor.insert %f into %t[%pos1] : tensor<100xf32>
+  %1 = tensor.extract_slice %t[%pos2][%sz][1] : tensor<100xf32> to tensor<?xf32>
+  %2 = tensor.extract %1[%pos3] : tensor<?xf32>
+  %3 = tensor.extract %0[%pos3] : tensor<100xf32>
+  return %2, %3 : f32, f32
+}