diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -130,7 +130,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 
-#define DEBUG_TYPE "comprehensive-func-bufferize"
+#define DEBUG_TYPE "comprehensive-module-bufferize"
 
 using namespace mlir;
 using namespace linalg;
@@ -140,8 +140,8 @@
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
 // Forward declarations.
-static std::string printOperationInfo(Operation *);
-static std::string printValueInfo(Value);
+static std::string printOperationInfo(Operation *, bool prefix = true);
+static std::string printValueInfo(Value, bool prefix = true);
 
 //===----------------------------------------------------------------------===//
 // Generic helpers.
@@ -365,30 +365,31 @@
 }
 
 /// Print the operation name and bufferization information.
-static std::string printOperationInfo(Operation *op) {
+static std::string printOperationInfo(Operation *op, bool prefix) {
   std::string result;
   llvm::raw_string_ostream os(result);
   AsmState state(op->getParentOfType<FuncOp>());
-  os << op->getName();
+  StringRef tab = prefix ? "\n[" DEBUG_TYPE "]\t" : "";
+  os << tab << op->getName();
   SmallVector<Value> shapedOperands;
   for (OpOperand &opOperand : op->getOpOperands()) {
     std::string prefix =
-        llvm::formatv("\n\t-> #{0} ", opOperand.getOperandNumber());
+        llvm::formatv("{0} -> #{1} ", tab, opOperand.getOperandNumber());
     printTensorOrBufferInfo(prefix, opOperand.get(), state, os);
   }
   for (OpResult opResult : op->getOpResults()) {
     std::string prefix =
-        llvm::formatv("\n\t<- #{0} ", opResult.getResultNumber());
+        llvm::formatv("{0} <- #{1} ", tab, opResult.getResultNumber());
     printTensorOrBufferInfo(prefix, opResult, state, os);
   }
   return result;
 }
 
 /// Print the bufferization information for the defining op or block argument.
-static std::string printValueInfo(Value value) {
+static std::string printValueInfo(Value value, bool prefix) {
   auto *op = value.getDefiningOp();
   if (op)
-    return printOperationInfo(op);
+    return printOperationInfo(op, prefix);
   // Print the block argument bufferization information.
   std::string result;
   llvm::raw_string_ostream os(result);
@@ -552,6 +553,7 @@
       .Case([&](scf::ForOp op) {
         return &op.getIterOpOperands()[result.getResultNumber()];
       })
+      .Case([&](InitTensorOp op) { return nullptr; })
       .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
       .Case([&](LinalgOp op) {
         return op.getOutputTensorOperands()[result.getResultNumber()];
@@ -580,7 +582,7 @@
     return None;
   return TypeSwitch<Operation *, Optional<OpResult>>(opOperand.getOwner())
       // These terminators legitimately have no result.
-      .Case(
+      .Case(
          [&](auto op) { return OpResult(); })
      // DimOp has no tensor result.
      .Case([&](auto op) { return None; })
@@ -759,10 +761,12 @@
  void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
 
  /// Print to `os`.
-  void print(raw_ostream &os) const;
+  void printAliases(raw_ostream &os) const;
+  void printEquivalences(raw_ostream &os) const;
 
  /// Print to `errs()`.
-  void dump() const { print(llvm::errs()); }
+  void dumpAliases() const { printAliases(llvm::errs()); }
+  void dumpEquivalences() const { printEquivalences(llvm::errs()); }
 
 private:
  /// Check that aliasInfo for `v` exists and return a reference to it.
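The `prefix` flag threaded through the two printers above controls whether each emitted line carries the "\n[comprehensive-module-bufferize]\t" banner: the LDBG call sites want it so multi-line operation dumps stay aligned under the debug tag, while the alias/equivalence printers further down pass /*prefix=*/false and frame the output themselves. A minimal standalone sketch of that tab logic, using plain C++ strings in place of the Operation/AsmState machinery (printOpInfoSketch and its arguments are illustrative stand-ins, not the real API):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Mirrors the shape of printOperationInfo: one leading banner per line when
// `prefix` is true, bare lines when the caller supplies its own framing.
static std::string printOpInfoSketch(const std::string &opName,
                                     const std::vector<std::string> &operands,
                                     bool prefix = true) {
  std::ostringstream os;
  const std::string tab = prefix ? "\n[comprehensive-module-bufferize]\t" : "";
  os << tab << opName;
  for (std::size_t i = 0; i < operands.size(); ++i)
    os << tab << " -> #" << i << ' ' << operands[i];
  return os.str();
}

int main() {
  // Prefixed form, as the LDBG call sites use it.
  std::cout << printOpInfoSketch("linalg.matmul", {"%A", "%B"}) << '\n';
  // Unprefixed form, as printAliases/printEquivalences use it.
  std::cout << printOpInfoSketch("linalg.matmul", {"%A", "%B"}, false) << '\n';
}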
@@ -954,10 +958,12 @@
     setInPlaceOpResult(result, InPlaceSpec::True);
     if (mergeAliases(result, operand.get()))
       mergeAliasesToFixedPoint();
+    // Dump the updated alias analysis.
+    LLVM_DEBUG(dumpAliases());
     if (bufferRelation == BufferRelation::Equivalent)
       equivalentInfo.unionSets(result, operand.get());
-    // Dump the updated analysis.
-    LLVM_DEBUG(dump());
+    // Dump the updated equivalence analysis.
+    LLVM_DEBUG(dumpEquivalences());
   }
 
   /// Set the inPlace bufferization spec to false.
@@ -984,7 +990,7 @@
   Operation *opToBufferize = result.getDefiningOp();
   Value root = (*maybeAliasingOperand)->get();
   LDBG("----Start wouldCreateReadAfterWriteInterference\n");
-  LDBG("--------rootValue: " << printValueInfo(root) << "\n");
+  LDBG("--------aliasing rootValue: " << printValueInfo(root) << "\n");
 
   // Collect:
   //   1. all the inplace write uses of some alias of `root`.
@@ -1046,10 +1052,10 @@
   // At this point, aliasingWriteOp properly dominates aliasingReadOp or
   // there is no clear dominance and we need to be conservative.
   LDBG("---->found RaW interference\n");
-  LDBG("  Interfering read -> #" << uRead->getOperandNumber() << ":\n"
+  LDBG("  Interfering read -> #" << uRead->getOperandNumber() << ":"
       << printOperationInfo(aliasingReadOp) << '\n');
-  LDBG("  Interfering write -> #" << uWrite->getOperandNumber() << ":\n"
+  LDBG("  Interfering write -> #" << uWrite->getOperandNumber() << ":"
      << printOperationInfo(aliasingWriteOp) << '\n');
 
  LDBG("---->opportunity to clobber RaW interference\n");
@@ -1098,28 +1104,34 @@
 }
 
-void BufferizationAliasInfo::print(raw_ostream &os) const {
+void BufferizationAliasInfo::printAliases(raw_ostream &os) const {
   os << "\n/========================== AliasInfo "
         "==========================\n";
   for (auto it : aliasInfo) {
-    os << "|\n| -- source: " << printValueInfo(it.getFirst()) << '\n';
+    os << "|\n| -- source: " << printValueInfo(it.getFirst(), /*prefix=*/false)
+       << '\n';
     for (auto v : it.getSecond())
-      os << "| ---- target: " << printValueInfo(v) << '\n';
+      os << "| ---- target: " << printValueInfo(v, /*prefix=*/false) << '\n';
   }
   os << "|\n\\====================== End AliasInfo "
        "======================\n\n";
+}
+
+void BufferizationAliasInfo::printEquivalences(raw_ostream &os) const {
   os << "\n/********************* Equivalent Buffers *********************\n";
   for (auto it = equivalentInfo.begin(), eit = equivalentInfo.end(); it != eit;
        ++it) {
     if (!it->isLeader())
       continue;
     Value leader = it->getData();
-    os << "|\n| -- leader: " << printValueInfo(leader) << '\n';
+    os << "|\n| -- leader: " << printValueInfo(leader, /*prefix=*/false)
+       << '\n';
     for (auto mit = equivalentInfo.member_begin(it),
               meit = equivalentInfo.member_end();
          mit != meit; ++mit) {
       Value v = static_cast<Value>(*mit);
-      os << "| ---- equivalent member: " << printValueInfo(v) << '\n';
+      os << "| ---- equivalent member: " << printValueInfo(v, /*prefix=*/false)
+         << '\n';
     }
   }
   os << "|\n\\***************** End Equivalent Buffers *****************\n\n";
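printEquivalences above iterates an llvm::EquivalenceClasses container the canonical way: visit every element once, skip non-leaders, then expand each class from its leader with member_begin/member_end. A self-contained sketch of that pattern, assuming LLVM's ADT and Support headers are available, with int elements standing in for mlir::Value:

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::EquivalenceClasses<int> ec;
  ec.unionSets(1, 2); // class {1, 2}
  ec.unionSets(2, 3); // class {1, 2, 3}
  ec.insert(4);       // singleton class {4}

  // Same traversal shape as printEquivalences: one line per leader, then
  // one line per member of that leader's class.
  for (auto it = ec.begin(), eit = ec.end(); it != eit; ++it) {
    if (!it->isLeader())
      continue;
    llvm::outs() << "leader: " << it->getData() << '\n';
    for (auto mit = ec.member_begin(it), meit = ec.member_end(); mit != meit;
         ++mit)
      llvm::outs() << "  member: " << *mit << '\n';
  }
}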
@@ -1195,12 +1207,13 @@
   auto leaderIt = equivalentInfo.findLeader(valueToClobber);
   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
        ++mit) {
-    /// Note: the "would write to memory after bufferization" condition is
-    /// verified by `candidateOp` since it would produce a value that
-    /// bufferizes to an equivalent buffer.
     Operation *candidateOp = mit->v.getDefiningOp();
     if (!candidateOp)
       continue;
+    auto maybeAliasingOperand = getAliasingOpOperand(mit->v.cast<OpResult>());
+    if (!maybeAliasingOperand || !*maybeAliasingOperand ||
+        !bufferizesToMemoryWrite(**maybeAliasingOperand))
+      continue;
     LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp)
                                        << '\n');
     if (domInfo.properlyDominates(aliasingWriteOp, candidateOp) &&
@@ -2311,7 +2324,12 @@
   return success();
 }
 
-/// Analyze the `funcOp` body to determine which OpResults are inplaceable.
+/// Analyze the `funcOp` body to determine which OpResults are inplaceable:
+///   1. First, analyze InsertSliceOp greedily: we almost never want to
+///      bufferize the tensor "inserted into" to become out-of-place.
+///   2. Walk the other ops in reverse. This is a good starter heuristic.
+///      ExtractSliceOps are interleaved with other ops in traversal order.
+///
 static LogicalResult
 inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
                           const DominanceInfo &domInfo) {
@@ -2321,26 +2339,22 @@
          "expected a funcOp definition with a body");
 
   // Collect ops so we can build our own traversal.
-  SmallVector<ExtractSliceOp> extractSliceOps;
+  SmallVector<Operation *> otherOps;
   SmallVector<InsertSliceOp> insertSliceOps;
-  SmallVector<Operation *> nonSliceOps;
   funcOp.walk([&](Operation *op) {
-    if (auto extractSliceOp = dyn_cast<ExtractSliceOp>(op))
-      return extractSliceOps.push_back(extractSliceOp);
     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(op))
       return insertSliceOps.push_back(insertSliceOp);
     // No tensors => no buffers.
     if (none_of(op->getOperandTypes(), isaTensor) &&
         none_of(op->getResultTypes(), isaTensor))
      return;
-    nonSliceOps.push_back(op);
+    otherOps.push_back(op);
  });
 
-  // Bufferize InsertSliceOp greedily: we almost never want to bufferize
+  // First, analyze InsertSliceOp greedily: we almost never want to bufferize
  // the tensor "inserted into" to become out-of-place. This implementation
  // does not distinguish between different InsertSliceOp. If we want
  // finer-grained behavior, we could order the InsertSliceOp with some metric.
-  // Walk InsertSliceOp in reverse for better interference behavior.
  for (InsertSliceOp insertSliceOp : reverse(insertSliceOps)) {
    OpOperand &destOpOperand = insertSliceOp->getOpOperand(1);
    if (failed(bufferizableInPlaceAnalysis(
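The doc comment above summarizes the restructured traversal that the next hunk implements: insert_slice ops are peeled off and analyzed greedily first, then one reverse walk covers everything else, with extract_slice analyzed inline at its traversal position instead of in a trailing pass. A schematic sketch of that two-phase shape in plain C++ (Op, analyze, and the op-name strings are placeholders, not the real MLIR types):

#include <string>
#include <vector>

struct Op { std::string name; };

static bool analyze(const Op &) { return true; } // stand-in for the analysis

static bool analyzeFuncBodySketch(const std::vector<Op> &allOps) {
  std::vector<Op> insertSliceOps, otherOps;
  for (const Op &op : allOps) {
    if (op.name == "tensor.insert_slice")
      insertSliceOps.push_back(op); // phase 1 worklist
    else
      otherOps.push_back(op); // phase 2 worklist; extract_slice stays here
  }
  // Phase 1: greedy insert_slice analysis, in reverse.
  for (auto it = insertSliceOps.rbegin(); it != insertSliceOps.rend(); ++it)
    if (!analyze(*it))
      return false;
  // Phase 2: one reverse walk; extract_slice is handled at its traversal
  // position, interleaved with the ops around it, so aliases are captured
  // in the right order.
  for (auto it = otherOps.rbegin(); it != otherOps.rend(); ++it)
    if (!analyze(*it))
      return false;
  return true;
}

int main() {
  std::vector<Op> ops = {{"linalg.fill"},
                         {"tensor.extract_slice"},
                         {"tensor.insert_slice"},
                         {"linalg.matmul"}};
  return analyzeFuncBodySketch(ops) ? 0 : 1;
}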
@@ -2349,23 +2363,27 @@
       return failure();
   }
 
-  // Analyze all ops that return a tensors, except ExtractSliceOp and
-  // InsertSliceOp which are handled separately.
-  // Walk other ops in reverse for better interference behavior.
-  for (Operation *op : reverse(nonSliceOps))
-    for (OpOperand &opOperand : op->getOpOperands())
+  // Walk ops in reverse for better interference analysis.
+  for (Operation *op : reverse(otherOps)) {
+    for (OpOperand &opOperand : op->getOpOperands()) {
       if (OpResult result = getInplaceableOpResult(opOperand))
         if (result.getType().isa<TensorType>() &&
             failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
                                                domInfo)))
           return failure();
-
-  // Finally, bufferize ExtractSliceOp.
-  // Walk ExtractSliceOps in reverse for better clobbering behavior: it is
-  // easier to detect clobbers of smaller slices before larger ones.
-  for (ExtractSliceOp extractSliceOp : reverse(extractSliceOps))
-    if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
-      return failure();
+    }
+    // Special logic to analyze ExtractSliceOp.
+    // Note that ExtractSliceOp analysis needs to be interleaved with other ops
+    // to properly capture aliases.
+    // Walk ExtractSliceOps in reverse for better clobbering analysis behavior:
+    // it is easier to detect clobbers of smaller slices before larger ones.
+    if (auto extractSliceOp = dyn_cast<ExtractSliceOp>(op)) {
+      if (failed(
+              bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
+        return failure();
+      continue;
+    }
+  }
 
   LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -247,20 +247,20 @@
     %C : tensor<?x?xf32> {linalg.inplaceable = true})
   -> (tensor<4x4xf32>, tensor<4x4xf32>)
 {
-  // Step 3. %sB forward propagates to a write in %D but it is not inplace.
+  // Step 4. %sB forward propagates to a write in %D but it is not inplace.
   // So this is only ever read and can bufferize inplace.
   // CHECK: tensor.extract_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %sB = tensor.extract_slice %B[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
 
-  // Step 2. %sB has a read interference in %E, it does not bufferize inplace.
+  // Step 3. %sB has a read interference in %E, it does not bufferize inplace.
   // CHECK: linalg.matmul
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %D = linalg.matmul ins(%B, %C: tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%sB: tensor<4x4xf32>)
     -> tensor<4x4xf32>
 
-  // Step 4. %sC forward propagates to an inplace write in %E.
+  // Step 2. %sC forward propagates to an inplace write in %E.
   // %sC backward propagates to %C which is inplaceable.
   // As a consequence this is bufferized inplace.
   // CHECK: tensor.extract_slice
@@ -298,7 +298,7 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %sB = tensor.extract_slice %B[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
 
-  // Step 1. %sB backprops to the tensor.extract_slice producer which is not
+  // Step 3. %sB backprops to the tensor.extract_slice producer which is not
   // considered an interference. This bufferizes inplace.
   // CHECK: linalg.matmul
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
@@ -306,7 +306,7 @@
   %D = linalg.matmul ins(%B, %C: tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%sB: tensor<4x4xf32>)
     -> tensor<4x4xf32>
 
-  // Step 3. %sC forward propagates to an inplace write in %E.
+  // Step 2. %sC forward propagates to an inplace write in %E.
   // %sC backward propagates to %C which is inplaceable.
   // As a consequence this is bufferized inplace.
   // CHECK: tensor.extract_slice
@@ -482,7 +482,7 @@
     %lb : index, %ub : index, %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
-  // %r0 must be out of place because one use of %t in the subsequent production 
+  // %r0 must be out of place because one use of %t in the subsequent production
   // of %r1 is read.
   // CHECK: scf.for
   // CHECK-NEXT: call
@@ -503,7 +503,7 @@
     scf.yield %t : tensor<?xf32>
   }
 
-  // %r2 must be out of place because one use of %t in the subsequent production 
+  // %r2 must be out of place because one use of %t in the subsequent production
   // of %r3 is read.
   // CHECK: linalg.tiled_loop
   // CHECK-NEXT: call
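The "Step N" renumbering in the hunks above follows directly from the interleaved reverse walk: within each test body the analysis now visits the ops bottom-up, so the last-defined value gets the earliest step. A small sketch deriving the new numbers for the first test (value names simplified, assuming no insert_slice so phase 1 is empty):

#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Definition order in the test body.
  const std::vector<std::string> defs = {"%sB = tensor.extract_slice",
                                         "%D  = linalg.matmul",
                                         "%sC = tensor.extract_slice",
                                         "%E  = linalg.matmul"};
  // Reverse traversal order == analysis order == step number.
  int step = 1;
  for (auto it = defs.rbegin(); it != defs.rend(); ++it)
    std::printf("Step %d. %s\n", step++, it->c_str());
  // Prints Step 1 for %E, Step 2 for %sC, Step 3 for %D, Step 4 for %sB,
  // matching the renumbered comments.
}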
@@ -619,3 +619,86 @@
   call @bar(%B2) : (tensor<64xf32>) -> ()
   return
 }
+
+//===----------------------------------------------------------------------===//
+// Transitive cases through extract_slice.
+//===----------------------------------------------------------------------===//
+
+builtin.func @matmul_on_tensors(
+    %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true})
+    -> tensor<256x256xf32>
+{
+  %c0 = constant 0 : index
+  %cst_0 = constant 0.000000e+00 : f32
+  %cst_1 = constant 1.000000e+00 : f32
+
+  %7 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
+
+  // CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
+  // CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
+  %8 = linalg.fill(%cst_0, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %11 = linalg.fill(%cst_1, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  // CHECK: linalg.matmul
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  %sA = tensor.extract_slice %8[0, 0][256, 16][1, 1]: tensor<256x256xf32> to tensor<256x16xf32>
+  %sB = tensor.extract_slice %11[0, 0][16, 256][1, 1]: tensor<256x256xf32> to tensor<16x256xf32>
+  %r = linalg.matmul
+         ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
+        outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
+
+  return %r : tensor<256x256xf32>
+}
+
+// -----
+
+builtin.func @matmul_on_tensors(
+    %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true})
+    -> tensor<256x256xf32>
+{
+  %c0 = constant 0 : index
+  %cst_0 = constant 0.000000e+00 : f32
+  %cst_1 = constant 1.000000e+00 : f32
+
+  %7 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
+
+  // CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  // CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %8 = linalg.fill(%cst_0, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %9 = vector.transfer_read %arg0[%c0, %c0], %cst_0 {in_bounds = [false, true]} : tensor<518x518xf32>, vector<256x256xf32>
+  %10 = vector.transfer_write %9, %8[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, tensor<256x256xf32>
+
+  // CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  // CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %11 = linalg.fill(%cst_1, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %12 = vector.transfer_read %arg1[%c0, %c0], %cst_0 {in_bounds = [false, true]} : tensor<518x518xf32>, vector<256x256xf32>
+  %13 = vector.transfer_write %12, %11[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, tensor<256x256xf32>
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  // CHECK: linalg.matmul
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  %sA = tensor.extract_slice %10[0, 0][256, 16][1, 1]: tensor<256x256xf32> to tensor<256x16xf32>
+  %sB = tensor.extract_slice %13[0, 0][16, 256][1, 1]: tensor<256x256xf32> to tensor<16x256xf32>
+  %r = linalg.matmul
+         ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
+        outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
+
+  return %r : tensor<256x256xf32>
+}