diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -20,8 +20,91 @@
 using namespace mlir;
 using namespace mlir::bufferization;
 
+/// Resolve all operands that are also used inside of repetitive regions of the
+/// same op. Such cases are not fully supported by One-Shot Bufferize.
+///
+/// Only tensor operands that bufferize to a memory write are handled. For each
+/// such operand, a `bufferization.alloc_tensor` copy of it is inserted right
+/// before the op, and every use of the operand inside a repetitive region of
+/// the op is redirected to that copy. All other uses are left untouched.
+///
+/// E.g.:
+/// %r = scf.for ... iter_args(%t = %tensor) -> tensor<?xf32> {
+///   "some_use"(%tensor)
+///   ...
+/// }
+///
+/// Is converted to:
+/// %tensor_copy = bufferization.alloc_tensor copy(%tensor)
+/// %r = scf.for ... iter_args(%t = %tensor) -> tensor<?xf32> {
+///   "some_use"(%tensor_copy)
+///   ...
+/// }
+static void
+resolveUsesInRepetitiveRegions(Operation *op,
+                               const BufferizationOptions &options) {
+  IRRewriter rewriter(op->getContext());
+  AnalysisState state(options);
+
+  // Look for repetitive ops (loops).
+  op->walk([&](RegionBranchOpInterface regionBranchOp) {
+    // Skip non-bufferizable ops.
+    auto bufferizableOp = options.dynCastBufferizableOp(regionBranchOp);
+    if (!bufferizableOp)
+      return WalkResult::advance();
+
+    // Find all operands that are also used inside of a repetitive region of
+    // this op.
+    for (OpOperand &opOperand : regionBranchOp->getOpOperands()) {
+      Value operand = opOperand.get();
+      // Skip non-tensor operands.
+      if (!operand.getType().isa<TensorType>())
+        continue;
+      // Skip operands that do not bufferize to memory writes.
+      if (!bufferizableOp.bufferizesToMemoryWrite(opOperand, state))
+        continue;
+
+      // Gather all uses inside repetitive regions.
+      SmallVector<OpOperand *> usesInsideRegion;
+      for (OpOperand &use : operand.getUses()) {
+        Operation *owner = use.getOwner();
+        if (!regionBranchOp->isProperAncestor(owner))
+          continue;
+        // `owner` is nested in exactly one region of `regionBranchOp`, so at
+        // most one region can match: stop at the first match.
+        for (Region &r : regionBranchOp->getRegions()) {
+          if (r.findAncestorOpInRegion(*owner) &&
+              regionBranchOp.isRepetitiveRegion(r.getRegionNumber())) {
+            usesInsideRegion.push_back(&use);
+            break;
+          }
+        }
+      }
+      // Nothing to do if the operand is not used inside a repetitive region.
+      if (usesInsideRegion.empty())
+        continue;
+
+      // Insert a tensor copy and replace all uses inside of repetitive regions.
+      // The uses were collected up front, so rewriting them here does not
+      // interfere with the traversal of `operand.getUses()` above.
+      rewriter.setInsertionPoint(regionBranchOp);
+      auto tensorCopy = rewriter.create<AllocTensorOp>(
+          regionBranchOp->getLoc(), operand.getType().cast<TensorType>(),
+          /*dynamicSizes=*/ValueRange(),
+          /*copy=*/operand, /*memory_space=*/IntegerAttr());
+      for (OpOperand *use : usesInsideRegion)
+        use->set(tensorCopy);
+    }
+
+    return WalkResult::advance();
+  });
+}
+
 LogicalResult mlir::bufferization::insertTensorCopies(
     Operation *op, const OneShotBufferizationOptions &options) {
+  // Preprocessing: Resolve currently unsupported bufferization cases.
+  resolveUsesInRepetitiveRegions(op, options);
+
   OneShotAnalysisState state(op, options);
   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
   // analysis depending on whether function boundary bufferization is enabled or
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -168,7 +168,8 @@ %c16 = arith.constant 16 : index // Hoisted alloc.
- // CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 128 : i64} : memref<8x16xf32> + // CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 128 : i64} : memref<128x192xf32> + // CHECK: memref.copy %[[C]], %[[ALLOC]] // CHECK: scf.for %[[I:.*]] = %0 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %C) -> (tensor<128x192xf32>) { @@ -180,12 +181,14 @@ %3 = tensor.extract_slice %B[0, %arg5] [256, 16] [1, 1] : tensor<256x192xf32> to tensor<256x16xf32> - // %4 does not match an insert_slice, it cannot be bufferized inplace and needs to alloc. + // C was already replaced with a copy by preprocessing, so no copy is + // needed here. + // CHECK: %[[C_SLICE:.*]] = memref.subview %[[ALLOC]] %4 = tensor.extract_slice %C[%arg3, %arg5] [8, 16] [1, 1] : tensor<128x192xf32> to tensor<8x16xf32> // linalg.fill is inplace. - // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[ALLOC]] : memref<8x16xf32>) + // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[C_SLICE]] %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK: scf.for %[[K:.*]] = @@ -196,7 +199,7 @@ tensor<256x16xf32> to tensor<32x16xf32> // linalg.matmul is inplace as well as the enclosing scf.for. - // CHECK: linalg.matmul ins({{.*}} outs(%[[ALLOC]] + // CHECK: linalg.matmul ins({{.*}} outs(%[[C_SLICE]] %10 = linalg.matmul ins(%8, %9 : tensor<8x32xf32>, tensor<32x16xf32>) outs(%arg8 : tensor<8x16xf32>) -> tensor<8x16xf32> @@ -207,15 +210,16 @@ // that is not in place. So we must insert a copy of the small buffer into // the bigger buffer.
// CHECK: %[[T:.*]] = memref.subview %[[C]][%[[I]], %[[J]]] [8, 16] [1, 1] - // CHECK: memref.copy %[[ALLOC]], %[[T]] + // CHECK: memref.copy %[[C_SLICE]], %[[T]] %7 = tensor.insert_slice %6 into %arg6[%arg3, %arg5] [8, 16] [1, 1] : tensor<8x16xf32> into tensor<128x192xf32> - // CHECK: memref.dealloc %[[ALLOC]] scf.yield %7 : tensor<128x192xf32> } scf.yield %2 : tensor<128x192xf32> } + + // CHECK: memref.dealloc %[[ALLOC]] return %0 : tensor<128x192xf32> } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -233,15 +233,17 @@ // CHECK-LABEL: func @scf_for_yield_non_equivalent( // CHECK-SAME: %[[t:.*]]: memref, %lb : index, %ub : index, %step : index) -> tensor {
@@ -709,3 +711,34 @@
   %f1 = tensor.extract %r0#1[%step] : tensor<?xf32>
   return %f0, %f1: f32, f32
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_for_yield_alias_of_non_equivalent(
+func.func @scf_for_yield_alias_of_non_equivalent(%sz: index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 5.0 : f32
+
+  // CHECK: %[[generate:.*]] = memref.alloc
+  %0 = tensor.generate %sz {
+  ^bb0(%i: index):
+    tensor.yield %cst : f32
+  } : tensor<?xf32>
+
+  // A copy is inserted because %0 is used inside the loop.
+  // CHECK: %[[generate_copy:.*]] = memref.alloc
+  // CHECK: memref.copy %[[generate]], %[[generate_copy]]
+  // CHECK: scf.for
+  %r = scf.for %iv = %c0 to %sz step %c1 iter_args(%t = %0) -> tensor<?xf32> {
+    %iv_sub = arith.subi %iv, %c1 : index
+    // CHECK: memref.subview %[[generate_copy]]
+    %ll = tensor.extract_slice %0[%iv_sub][%sz][1] : tensor<?xf32> to tensor<?xf32>
+    %l = tensor.extract %ll[%c0] : tensor<?xf32>
+    %double = arith.mulf %cst, %l : f32
+    // CHECK: memref.store %{{.*}}, %[[generate]]
+    %s = tensor.insert %double into %t[%iv] : tensor<?xf32>
+    scf.yield %s : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}