diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -62,6 +62,14 @@
 /// b) whose buffer uses would be free of memory hazards.
 std::unique_ptr createLinalgComprehensiveFuncBufferizePass();
 
+/// This pass implements a cross-dialect bufferization approach and performs an
+/// analysis to determine which op operands and results may be bufferized in the
+/// same buffers. The analysis is performed on topologically sorted CallOp and
+/// FuncOp within a module. It provides analyses and bufferization across
+/// function boundaries. Within a single function body, the bufferization used
+/// is that provided by `LinalgComprehensiveFuncBufferizePass`.
+std::unique_ptr createLinalgComprehensiveModuleBufferizePass();
+
 /// Create a pass to convert Linalg operations which work on tensors to use
 /// buffers instead.
 std::unique_ptr> createLinalgBufferizePass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -32,7 +32,7 @@
     This pass implements a cross-dialect bufferization approach and performs an
     analysis to determine which op operands and results may be bufferized in the
     same buffers. The analysis is performed on SSA use-def chains starting from
-    function operands that are annotated with the 'inplaceable' attribute
+    function operands that are annotated with the 'inplaceable' attribute.
   }];
   let options = [
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
@@ -42,6 +42,25 @@
   let constructor = "mlir::createLinalgComprehensiveFuncBufferizePass()";
 }
 
+def LinalgComprehensiveModuleBufferize :
+    Pass<"linalg-comprehensive-module-bufferize", "ModuleOp"> {
+  let summary = "Bufferize (tensor into memref) for a Module.";
+  let description = [{
+    This pass implements a cross-dialect bufferization approach and performs an
+    analysis to determine which op operands and results may be bufferized in the
+    same buffers. The analysis is performed on topologically sorted CallOp and
+    FuncOp within a module. It provides analyses and bufferization across
+    function boundaries. Within a single function body, the bufferization used
+    is that provided by `-linalg-comprehensive-func-bufferize`.
+  }];
+  let options = [
+    Option<"testAnalysisOnly", "test-analysis-only", "bool",
+            /*default=*/"false",
+           "Only runs inplaceability analysis (for testing purposes only)">
+  ];
+  let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
+}
+
 def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -375,6 +375,10 @@
   /// attribute that was erased, or nullptr if there was no attribute with such
   /// name.
   Attribute removeArgAttr(unsigned index, Identifier name);
+  Attribute removeArgAttr(unsigned index, StringRef name) {
+    return removeArgAttr(
+        index, Identifier::get(name, this->getOperation()->getContext()));
+  }
 
   //===--------------------------------------------------------------------===//
   // Result Attributes
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -16,7 +16,7 @@
 // Composability with extensible set of ops is not a first-class concern.
 //
 // Bufferization occurs by:
-// a. performing an inPlace analysis `inPlaceAnalysisFuncOpInternals`
+// a. performing an inPlace analysis `inPlaceAnalysisFuncOpBody`
 // which marks each operation within the function with the
 // `kInPlaceResultsAttrName` attribute.
 // b. traversing each operation in the function and rewriting it in
@@ -132,6 +132,19 @@
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
+//===----------------------------------------------------------------------===//
+// Generic helpers.
+//===----------------------------------------------------------------------===//
+
+/// Return the FuncOp called by `callOp`.
+static FuncOp getCalledFunction(CallOpInterface callOp) {
+  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast();
+  if (!sym)
+    return nullptr;
+  return dyn_cast_or_null(
+      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -276,6 +289,25 @@
   return InPlaceSpec::None;
 }
 
+/// Set the attribute that triggers inplace bufferization on a FuncOp argument
+/// `bbArg`.
+static void
+setInPlaceFuncArgument(BlockArgument bbArg,
+                       InPlaceSpec inPlaceSpec = InPlaceSpec::True) {
+  auto funcOp = cast(bbArg.getOwner()->getParentOp());
+  funcOp.setArgAttr(
+      bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName,
+      BoolAttr::get(bbArg.getContext(), inPlaceSpec == InPlaceSpec::True));
+}
+
+/// Remove the attribute that triggers inplace bufferization on a FuncOp
+/// argument `bbArg`.
+static void removeInPlaceFuncArgument(BlockArgument bbArg) {
+  auto funcOp = cast(bbArg.getOwner()->getParentOp());
+  funcOp.removeArgAttr(bbArg.getArgNumber(),
+                       LinalgDialect::kInplaceableAttrName);
+}
+
 LLVM_ATTRIBUTE_UNUSED static InPlaceSpec getInPlace(Value v) {
   if (auto bbArg = v.dyn_cast())
     return getInPlace(bbArg);
@@ -305,7 +337,8 @@
 static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
   return
       // clang-format off
-      isa(opOperand.getOwner())) return false;
+  // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
+  // of the matching bbArg may. It is the responsibility of the caller to
+  // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
+  // conservative.
+  if (isa<CallOpInterface>(opOperand.getOwner()))
+    return true;
   if (auto linalgOp = dyn_cast(opOperand.getOwner()))
     return linalgOp.isInputTensor(&opOperand) ||
            linalgOp.isInitTensor(&opOperand);
@@ -473,6 +516,19 @@
 static bool
 bufferizesToMemoryWrite(OpOperand &opOperand,
                         InPlaceSpec inPlaceSpec = InPlaceSpec::None) {
+  // These terminators are not writes.
+  if (isa(opOperand.getOwner()))
+    return false;
+  // ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
+  // may.
+  if (isa(opOperand.getOwner()))
+    return false;
+  // CallOpInterface alone doesn't bufferize to a memory write, one of the uses
+  // of the matching bbArg may. It is the responsibility of the caller to
+  // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
+  // conservative.
+  if (isa<CallOpInterface>(opOperand.getOwner()))
+    return true;
   Optional maybeOpResult = getAliasingOpResult(opOperand);
   // Unknown op that returns a tensor. The inplace analysis does not support
   // it. Conservatively return true.
@@ -482,13 +538,6 @@
   // This does not bufferize to a write.
   if (!*maybeOpResult)
     return false;
-  // These terminators are not writes.
-  if (isa(opOperand.getOwner()))
-    return false;
-  // ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
-  // may.
-  if (maybeOpResult->getDefiningOp())
-    return false;
   // If we have a matching OpResult, this is a write.
   // Additionally allow to restrict to only inPlace write, if so specified.
   return inPlaceSpec == InPlaceSpec::None ||
@@ -521,7 +570,11 @@
     Equivalent
   };
 
-  explicit BufferizationAliasInfo(FuncOp funcOp);
+  explicit BufferizationAliasInfo(Operation *rootOp);
+
+  /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
+  /// beginning the alias and equivalence sets only contain `v` itself.
+  void createAliasInfoEntry(Value v);
 
   /// Return true if the buffer to which `operand` would bufferize aliases a
   /// buffer that is known to not be writeable. This implies that the matching
This implies that the matching @@ -664,33 +717,28 @@ }; } // namespace -BufferizationAliasInfo::BufferizationAliasInfo(FuncOp funcOp) { - funcOp.walk([&](Operation *op) { - for (Value v : op->getResults()) { - if (!v.getType().isa()) - continue; - assert(getInPlace(v) == InPlaceSpec::None && - "unexpected inplace in analysis."); - DenseSet selfSet; - selfSet.insert(v); - aliasInfo.try_emplace(v, selfSet); - equivalentInfo.insert(v); - } - for (Region &r : op->getRegions()) { - for (Block &b : r.getBlocks()) { - for (auto bbArg : b.getArguments()) { - if (!bbArg.getType().isa()) - continue; - DenseSet selfSet; - selfSet.insert(bbArg); - aliasInfo.try_emplace(bbArg, selfSet); - equivalentInfo.insert(bbArg); - } - } - } +BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { + rootOp->walk([&](Operation *op) { + for (Value v : op->getResults()) + if (v.getType().isa()) + createAliasInfoEntry(v); + for (Region &r : op->getRegions()) + for (Block &b : r.getBlocks()) + for (auto bbArg : b.getArguments()) + if (bbArg.getType().isa()) + createAliasInfoEntry(bbArg); }); } +/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the +/// beginning the alias and equivalence sets only contain `v` itself. +void BufferizationAliasInfo::createAliasInfoEntry(Value v) { + DenseSet selfSet; + selfSet.insert(v); + aliasInfo.try_emplace(v, selfSet); + equivalentInfo.insert(v); +} + /// Return true if the buffer to which `operand` would bufferize aliases a /// buffer that is known to not be writeable. This implies that the matching /// OpResult cannot be bufferized inplace. @@ -1684,8 +1732,8 @@ /// Analyze the `funcOp` body to determine which OpResults are inplaceable. 
static LogicalResult -inPlaceAnalysisFuncOpInternals(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { +inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo) { LLVM_DEBUG(llvm::dbgs() << "\n\n"); LDBG("Begin InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n'); assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() && @@ -1821,7 +1869,7 @@ BufferizationAliasInfo aliasInfo(funcOp); // If the analysis fails, just return. This is expected to reset the IR and no // single OpResult should be marked inPlace. - if (failed(inPlaceAnalysisFuncOpInternals(funcOp, aliasInfo, domInfo))) { + if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo))) { signalPassFailure(); return; } @@ -1841,3 +1889,122 @@ std::unique_ptr mlir::createLinalgComprehensiveFuncBufferizePass() { return std::make_unique(); } + +//===----------------------------------------------------------------------===// +// Bufferization entry-point for modules. +//===----------------------------------------------------------------------===// + +/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by +/// callee-caller order (i.e. callees without callers first). +/// Store the map of FuncOp to all its callers in `callerMap`. +/// Return `failure()` if a cycle of calls is detected or if we are unable to +/// retrieve the called FuncOp from any CallOpInterface. +static LogicalResult +getFuncOpsOrderedByCalls(ModuleOp moduleOp, + SmallVectorImpl &orderedFuncOps, + DenseMap> &callerMap) { + // For each FuncOp, the set of functions called by it (i.e. the union of + // symbols of all nested CallOpInterfaceOp). + DenseMap> calledBy; + // For each FuncOp, the number of CallOpInterface it contains. 
+ DenseMap numberCallOpsContainedInFuncOp; + WalkResult res = moduleOp.walk([&](FuncOp funcOp) { + numberCallOpsContainedInFuncOp[funcOp] = 0; + return funcOp.walk([&](CallOpInterface callOp) { + FuncOp calledFunction = getCalledFunction(callOp); + if (!calledFunction) + return WalkResult::interrupt(); + auto it = callerMap.try_emplace(calledFunction, DenseSet{}); + it.first->getSecond().insert(callOp); + if (calledBy[calledFunction].count(funcOp) == 0) { + calledBy[calledFunction].insert(funcOp); + numberCallOpsContainedInFuncOp[funcOp]++; + } + return WalkResult::advance(); + }); + }); + if (res.wasInterrupted()) + return failure(); + // Iteratively remove function operation that do not call any of the + // functions remaining in the callCounter map and add them to the worklist. + while (!numberCallOpsContainedInFuncOp.empty()) { + auto it = llvm::find_if(numberCallOpsContainedInFuncOp, + [](auto entry) { return entry.getSecond() == 0; }); + if (it == numberCallOpsContainedInFuncOp.end()) + return moduleOp.emitOpError( + "expected callgraph to be free of circular dependencies."); + orderedFuncOps.push_back(it->getFirst()); + for (auto callee : calledBy[it->getFirst()]) + numberCallOpsContainedInFuncOp[callee]--; + numberCallOpsContainedInFuncOp.erase(it); + } + return success(); +} + +namespace { +struct LinalgComprehensiveModuleBufferize + : public LinalgComprehensiveModuleBufferizeBase< + LinalgComprehensiveModuleBufferize> { + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } +}; +} // end namespace + +void LinalgComprehensiveModuleBufferize::runOnOperation() { + ModuleOp moduleOp = getOperation(); + + SmallVector orderedFuncOps; + DenseMap> callerMap; + if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) + return signalPassFailure(); + + DominanceInfo domInfo(moduleOp); + BufferizationAliasInfo aliasInfo(moduleOp); + // Interestingly, all function args that 
are not visible outside of a module + // can be fully bufferized inplace by guaranteeing the CallOp is bufferized + // inplace. Therefore, we just bufferize funcOp as if none of its results were + // inplaceable, detect which operands are cloned internally and decide what to + // do at call sites. + for (FuncOp funcOp : orderedFuncOps) { + // No body => no analysis. + if (funcOp.body().empty()) + continue; + + // In a first approximation: + // ========================= + // If the function is called, we can allocate on the caller side which lets + // us force inplace arguments at function boundaries. + // TODO: do not rely on this behavior. + if (callerMap.find(funcOp) != callerMap.end()) + for (BlockArgument bbArg : funcOp.getArguments()) + if (bbArg.getType().isa()) + setInPlaceFuncArgument(bbArg); + + // If the analysis fails, just return. + if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo))) { + signalPassFailure(); + return; + } + + // TODO: Bufferization phase. + } + // Don't drop the attributes if we only want to report the analysis. + if (testAnalysisOnly) + return; + + // Post-pass cleanup of inplaceable attributes. 
+ moduleOp.walk( + [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); }); + moduleOp.walk([&](FuncOp op) { + for (BlockArgument bbArg : op.getArguments()) + removeInPlaceFuncArgument(bbArg); + }); +} + +std::unique_ptr mlir::createLinalgComprehensiveModuleBufferizePass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -0,0 +1,84 @@ +// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=test-analysis-only -split-input-file | FileCheck %s + +func private @foo(tensor<64xf32>) + +// CHECK-LABEL: dependence_through_call +func @dependence_through_call(%I : tensor<64xf32> {linalg.inplaceable = true}) { + %f1 = constant 1.000000e+00 : f32 + %f2 = constant 2.000000e+00 : f32 + + // 2. %B already bufferizes inplace, %A would alias and have a different + // value. The calls to `foo` are determined to read conservatively, so %A + // cannot bufferize inplace. + // CHECK: fill + // CHECK-SAME: {__inplace_results_attr__ = ["false"]} + %A = linalg.fill(%f1, %I) : f32, tensor<64xf32> -> tensor<64xf32> + + // 1. Bufferizes inplace: no alias to %A is yet possible. 
+ // CHECK: fill + // CHECK-SAME: {__inplace_results_attr__ = ["true"]} + %B = linalg.fill(%f2, %I) : f32, tensor<64xf32> -> tensor<64xf32> + + call @foo(%A) : (tensor<64xf32>) -> () + call @foo(%B) : (tensor<64xf32>) -> () + + return +} + +// ----- + +func private @foo(tensor<64xf32>) + +func private @bar(%A : tensor<64xf32>) { + call @foo(%A) : (tensor<64xf32>) -> () + return +} + +func @read_dependence_through_scf_and_call( + %I : tensor<64xf32> {linalg.inplaceable = true}, + %I2 : tensor<64xf32> {linalg.inplaceable = true}) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c10 = constant 10 : index + %f1 = constant 1.000000e+00 : f32 + %f2 = constant 2.000000e+00 : f32 + + // 5. %B bufferizes inplace, %A would alias and have a different value. + // The calls to `foo` are determined to read conservatively, so %A cannot + // bufferize inplace. + // CHECK: fill + // CHECK-SAME: {__inplace_results_attr__ = ["false"]} + %A = linalg.fill(%f1, %I) : f32, tensor<64xf32> -> tensor<64xf32> + + // 4. Bufferizes inplace: no alias to %A is yet possible. + // CHECK: fill + // CHECK-SAME: {__inplace_results_attr__ = ["true"]} + %B = linalg.fill(%f2, %I) : f32, tensor<64xf32> -> tensor<64xf32> + + // 3. Does not read or write, bufferizes inplace. + // CHECK: scf.for + // CHECK: {__inplace_results_attr__ = ["true", "true"]} + %r:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%0 = %A, %1 = %B) + -> (tensor<64xf32>, tensor<64xf32>) + { + scf.yield %0, %1 : tensor<64xf32>, tensor<64xf32> + } + call @foo(%r#0) : (tensor<64xf32>) -> () + call @foo(%r#1) : (tensor<64xf32>) -> () + + // 2. %B2 already bufferizes inplace, %A2 would alias and have a different + // value. The calls to `foo` are determined to read conservatively, so %A2 + // cannot bufferize inplace. + // CHECK: fill + // CHECK-SAME: {__inplace_results_attr__ = ["false"]} + %A2 = linalg.fill(%f1, %I2) : f32, tensor<64xf32> -> tensor<64xf32> + + // 1. Bufferizes inplace: no alias to %A2 is yet possible. 
+ // CHECK: fill + // CHECK-SAME: {__inplace_results_attr__ = ["true"]} + %B2 = linalg.fill(%f2, %I2) : f32, tensor<64xf32> -> tensor<64xf32> + + call @bar(%A2) : (tensor<64xf32>) -> () + call @bar(%B2) : (tensor<64xf32>) -> () + return +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file -verify-diagnostics + +// ----- + +// expected-error @-3 {{expected callgraph to be free of circular dependencies}} + +func @foo() { + call @bar() : () -> () + return +} + +func @bar() { + call @foo() : () -> () + return +}