diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -20,7 +20,10 @@
 struct InitTensorEliminationStep : public bufferization::PostAnalysisStep {
   /// A function that matches anchor OpOperands for InitTensorOp elimination.
-  using AnchorMatchFn = std::function<bool(OpOperand &)>;
+  /// If an OpOperand is matched, the function should populate the SmallVector
+  /// with all values that are needed during `RewriteFn` to produce the
+  /// replacement value.
+  using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
 
   /// A function that rewrites matched anchors.
   using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
 
 using namespace mlir;
@@ -444,6 +445,79 @@
 
 } // namespace
 
+/// Return true if all `neededValues` are in scope at the given
+/// `insertionPoint`.
+static bool
+neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
+                                   Operation *insertionPoint,
+                                   const SmallVector<Value> &neededValues) {
+  for (Value val : neededValues) {
+    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+      Block *owner = bbArg.getOwner();
+      if (!owner->findAncestorOpInBlock(*insertionPoint))
+        return false;
+    } else {
+      auto opResult = val.cast<OpResult>();
+      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
+        return false;
+    }
+  }
+  return true;
+}
+
+/// Return true if the given `insertionPoint` dominates all uses of
+/// `initTensorOp`.
+static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
+                                        Operation *insertionPoint,
+                                        Operation *initTensorOp) {
+  for (Operation *user : initTensorOp->getUsers())
+    if (!domInfo.dominates(insertionPoint, user))
+      return false;
+  return true;
+}
+
+/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
+/// that the replacement may use any value from `neededValues`.
+static Operation *
+findValidInsertionPoint(Operation *initTensorOp,
+                        const SmallVector<Value> &neededValues) {
+  DominanceInfo domInfo;
+
+  // Gather all possible insertion points: the location of `initTensorOp` and
+  // right after the definition of each value in `neededValues`.
+  SmallVector<Operation *> insertionPointCandidates;
+  insertionPointCandidates.push_back(initTensorOp);
+  for (Value val : neededValues) {
+    // Note: The anchor op is using all of `neededValues`, so:
+    // * in case of a block argument: There must be at least one op in the
+    //   block (the anchor op or one of its parents).
+    // * in case of an OpResult: There must be at least one op right after the
+    //   defining op (the anchor op or one of its parents).
+    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+      insertionPointCandidates.push_back(
+          &bbArg.getOwner()->getOperations().front());
+    } else {
+      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
+    }
+  }
+
+  // Select first matching insertion point.
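+  // Candidates are tried in the order in which they were gathered, so the
+  // original location of `initTensorOp` is preferred when it is valid.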
+  for (Operation *insertionPoint : insertionPointCandidates) {
+    // Check if all needed values are in scope.
+    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
+                                            neededValues))
+      continue;
+    // Check if the insertion point is before all uses.
+    if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
+      continue;
+    return insertionPoint;
+  }
+
+  // No suitable insertion point was found.
+  return nullptr;
+}
+
 /// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
 /// with the result of `rewriteFunc` if it is anchored on a matching
 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
@@ -462,8 +536,10 @@
     // Skip operands that do not bufferize inplace.
     if (!aliasInfo.isInPlace(operand))
       continue;
+    // All values that are needed to create the replacement op.
+    SmallVector<Value> neededValues;
     // Is this a matching OpOperand?
-    if (!anchorMatchFunc(operand))
+    if (!anchorMatchFunc(operand, neededValues))
       continue;
     SetVector<Value> maybeInitTensor =
         state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
@@ -492,8 +568,14 @@
       return WalkResult::skip();
     Value initTensor = maybeInitTensor.front();
 
+    // Find a suitable insertion point.
+    Operation *insertionPoint =
+        findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
+    if (!insertionPoint)
+      continue;
+
     // Create a replacement for the InitTensorOp.
-    b.setInsertionPoint(initTensor.getDefiningOp());
+    b.setInsertionPoint(insertionPoint);
     Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
     if (!replacement)
       continue;
@@ -552,7 +634,7 @@
   return eliminateInitTensors(
       op, state, aliasInfo,
       /*anchorMatchFunc=*/
-      [&](OpOperand &operand) {
+      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
         auto insertSliceOp =
             dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
         if (!insertSliceOp)
@@ -560,7 +642,19 @@
         // Only inplace bufferized InsertSliceOps are eligible.
         if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
           return false;
-        return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
+        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
+          return false;
+
+        // Collect all values that are needed to construct the replacement op.
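+        // The replacement op reads from the destination tensor at the same
+        // offsets/sizes/strides, so all of these values, as well as the
+        // destination itself, must be in scope at the new insertion point.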
+        neededValues.append(insertSliceOp.offsets().begin(),
+                            insertSliceOp.offsets().end());
+        neededValues.append(insertSliceOp.sizes().begin(),
+                            insertSliceOp.sizes().end());
+        neededValues.append(insertSliceOp.strides().begin(),
+                            insertSliceOp.strides().end());
+        neededValues.push_back(insertSliceOp.dest());
+
+        return true;
       },
       /*rewriteFunc=*/
       [](OpBuilder &b, Location loc, OpOperand &operand) {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref init-tensor-elimination" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref init-tensor-elimination" -canonicalize -split-input-file | FileCheck %s
 
 // -----
 
@@ -62,3 +62,62 @@
 
   return %r1: tensor<?xf32>
 }
+
+// -----
+
+// CHECK: func @insertion_point_inside_loop(
+// CHECK-SAME:   %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index)
+func @insertion_point_inside_loop(%t : tensor<?xf32>, %sz : index) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+
+  // CHECK-NOT: memref.alloc
+  %blank = linalg.init_tensor [5] : tensor<5xf32>
+
+  // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
+  %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
+    // CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[iv]]] [5] [1]
+    %iv_i32 = arith.index_cast %iv : index to i32
+    %f = arith.sitofp %iv_i32 : i32 to f32
+
+    // CHECK: linalg.fill(%{{.*}}, %[[subview]])
+    %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
+
+    // CHECK-NOT: memref.copy
+    %inserted = tensor.insert_slice %filled into %bb[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
+    scf.yield %inserted : tensor<?xf32>
+  }
+
+  return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK: func @insertion_point_outside_loop(
+// CHECK-SAME:   %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index, %[[idx:.*]]: index)
+func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
+                                   %idx : index) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+
+  // CHECK-NOT: memref.alloc
+  // CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[idx]]] [5] [1]
+  %blank = linalg.init_tensor [5] : tensor<5xf32>
+
+  // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
+  %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
+    %iv_i32 = arith.index_cast %iv : index to i32
+    %f = arith.sitofp %iv_i32 : i32 to f32
+
+    // CHECK: linalg.fill(%{{.*}}, %[[subview]])
+    %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
+
+    // CHECK-NOT: memref.copy
+    %inserted = tensor.insert_slice %filled into %bb[%idx][5][1] : tensor<5xf32> into tensor<?xf32>
+    scf.yield %inserted : tensor<?xf32>
+  }
+
+  return %r : tensor<?xf32>
+}