diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -467,6 +467,50 @@
     return true;
   }
 
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+      return failure();
+
+    // According to the `getAliasing...` implementations, a bufferized OpResult
+    // may alias only with the corresponding bufferized init_arg and with no
+    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
+    // but not with any other OpOperand. If a corresponding OpResult/init_arg
+    // pair bufferizes to equivalent buffers, this aliasing requirement is
+    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
+    // (New buffer copies do not alias with any buffer.)
+    auto forOp = cast<scf::ForOp>(op);
+    auto yieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(yieldOp);
+
+    // For every yielded value, is the value equivalent to its corresponding
+    // bbArg?
+    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
+        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
+    SmallVector<Value> yieldValues;
+    for (int64_t idx = 0, e = yieldOp.getResults().size(); idx < e; ++idx) {
+      Value value = yieldOp.getResults()[idx];
+      // Only tensors are bufferized. Tensors that are equivalent to their
+      // bbArg already satisfy the aliasing requirement; yield them as is.
+      if (!value.getType().isa<TensorType>() ||
+          equivalentYields.contains(idx)) {
+        yieldValues.push_back(value);
+        continue;
+      }
+      Value alloc = rewriter.create<bufferization::AllocTensorOp>(
+          yieldOp.getLoc(), value.getType().cast<RankedTensorType>(),
+          /*dynamicSizes=*/ValueRange(), value, /*escape=*/true);
+      yieldValues.push_back(alloc);
+    }
+
+    rewriter.updateRootInPlace(
+        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
+    return success();
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
@@ -631,6 +675,77 @@
     return true;
   }
 
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+      return failure();
+
+    // According to the `getAliasing...` implementations, a bufferized OpResult
+    // may alias only with the corresponding bufferized init_arg and with no
+    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
+    // but not with any other OpOperand. If a corresponding OpResult/init_arg
+    // pair bufferizes to equivalent buffers, this aliasing requirement is
+    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
+    // (New buffer copies do not alias with any buffer.)
+    OpBuilder::InsertionGuard g(rewriter);
+    auto whileOp = cast<scf::WhileOp>(op);
+    auto conditionOp = whileOp.getConditionOp();
+    auto yieldOp = whileOp.getYieldOp();
+
+    // The "before" and "after" regions may have different args: scf.condition
+    // forwards values that match the "after" bbArgs, and scf.yield forwards
+    // values that match the "before" bbArgs (i.e., the inits). Therefore,
+    // check each forwarded value's type directly instead of the init types.
+
+    // For every yielded value, is the value equivalent to its corresponding
+    // bbArg?
+    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
+        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
+    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
+        whileOp.getAfterArguments(), yieldOp.getResults(), state);
+
+    // Update "before" region.
+    rewriter.setInsertionPoint(conditionOp);
+    SmallVector<Value> beforeYieldValues;
+    for (int64_t idx = 0, e = conditionOp.getArgs().size(); idx < e; ++idx) {
+      Value value = conditionOp.getArgs()[idx];
+      if (!value.getType().isa<TensorType>() ||
+          equivalentYieldsBefore.contains(idx)) {
+        beforeYieldValues.push_back(value);
+        continue;
+      }
+      Value alloc = rewriter.create<bufferization::AllocTensorOp>(
+          conditionOp.getLoc(), value.getType().cast<RankedTensorType>(),
+          /*dynamicSizes=*/ValueRange(), value, /*escape=*/true);
+      beforeYieldValues.push_back(alloc);
+    }
+    rewriter.updateRootInPlace(conditionOp, [&]() {
+      conditionOp.getArgsMutable().assign(beforeYieldValues);
+    });
+
+    // Update "after" region.
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> afterYieldValues;
+    for (int64_t idx = 0, e = yieldOp.getResults().size(); idx < e; ++idx) {
+      Value value = yieldOp.getResults()[idx];
+      if (!value.getType().isa<TensorType>() ||
+          equivalentYieldsAfter.contains(idx)) {
+        afterYieldValues.push_back(value);
+        continue;
+      }
+      Value alloc = rewriter.create<bufferization::AllocTensorOp>(
+          yieldOp.getLoc(), value.getType().cast<RankedTensorType>(),
+          /*dynamicSizes=*/ValueRange(), value, /*escape=*/true);
+      afterYieldValues.push_back(alloc);
+    }
+    rewriter.updateRootInPlace(yieldOp, [&]() {
+      yieldOp.getResultsMutable().assign(afterYieldValues);
+    });
+
+    return success();
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto whileOp = cast<scf::WhileOp>(op);
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s -tensor-copy-insertion="allow-return-allocs" -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -tensor-copy-insertion="bufferize-function-boundaries allow-return-allocs" -split-input-file -o /dev/null
+
+// CHECK-LABEL: func @scf_for(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>, %[[B:.*]]: tensor<?xf32>
+func.func @scf_for(%A : tensor<?xf32>, %B : tensor<?xf32>,
+                   %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {escape = false} : tensor<?xf32>
+  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {escape = false} : tensor<?xf32>
+  // CHECK:   %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[iter1:.*]] = %[[A_copy]], %[[iter2:.*]] = %[[B_copy]])
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    // CHECK: scf.yield %[[iter1]], %[[iter2]]
+    scf.yield %tA, %tB : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %r0#0, %r0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_swapping_yields(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>, %[[B:.*]]: tensor<?xf32>
+func.func @scf_for_swapping_yields(%A : tensor<?xf32>, %B : tensor<?xf32>,
+                                   %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {escape = false} : tensor<?xf32>
+  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {escape = false} : tensor<?xf32>
+  // CHECK:   %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[iter1:.*]] = %[[A_copy]], %[[iter2:.*]] = %[[B_copy]])
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    // Yield tensors in different order.
+    // CHECK-DAG: %[[yield1:.*]] = bufferization.alloc_tensor() copy(%[[iter2]]) {escape = true} : tensor<?xf32>
+    // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[iter1]]) {escape = true} : tensor<?xf32>
+    // CHECK: scf.yield %[[yield1]], %[[yield2]]
+    scf.yield %tB, %tA : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %r0#0, %r0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_while(
+//  CHECK-SAME:     %[[A:.*]]: tensor<5xi1>, %[[B:.*]]: tensor<5xi1>
+func.func @scf_while(%A: tensor<5xi1>, %B: tensor<5xi1>, %idx: index)
+  -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {escape = false} : tensor<5xi1>
+  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {escape = false} : tensor<5xi1>
+  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[A_copy]], %[[w1:.*]] = %[[B_copy]]) {{.*}} {
+  %r0, %r1 = scf.while (%w0 = %A, %w1 = %B)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    // CHECK: %[[condition:.*]] = tensor.extract %[[w0]]
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    // Yield tensors in the same order.
+    // CHECK: scf.condition(%[[condition]]) %[[w0]], %[[w1]]
+    scf.condition(%condition) %w0, %w1 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    // CHECK: } do {
+    // CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>):
+    // CHECK: scf.yield %[[b0]], %[[b1]]
+    // CHECK: }
+    scf.yield %b0, %b1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_while_non_equiv_condition_and_body(
+//  CHECK-SAME:     %[[A:.*]]: tensor<5xi1>, %[[B:.*]]: tensor<5xi1>
+func.func @scf_while_non_equiv_condition_and_body(%A: tensor<5xi1>,
+                                                  %B: tensor<5xi1>,
+                                                  %idx: index)
+  -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {escape = false} : tensor<5xi1>
+  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {escape = false} : tensor<5xi1>
+  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[A_copy]], %[[w1:.*]] = %[[B_copy]]) {{.*}} {
+  %r0, %r1 = scf.while (%w0 = %A, %w1 = %B)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    // CHECK: %[[condition:.*]] = tensor.extract %[[w0]]
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    // Yield tensors in different order.
+    // CHECK-DAG: %[[yield0:.*]] = bufferization.alloc_tensor() copy(%[[w1]]) {escape = true} : tensor<5xi1>
+    // CHECK-DAG: %[[yield1:.*]] = bufferization.alloc_tensor() copy(%[[w0]]) {escape = true} : tensor<5xi1>
+    // CHECK: scf.condition(%[[condition]]) %[[yield0]], %[[yield1]]
+    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    // CHECK: } do {
+    // CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>):
+    // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[b1]]) {escape = true} : tensor<5xi1>
+    // CHECK-DAG: %[[yield3:.*]] = bufferization.alloc_tensor() copy(%[[b0]]) {escape = true} : tensor<5xi1>
+    // CHECK: scf.yield %[[yield2]], %[[yield3]]
+    // CHECK: }
+    scf.yield %b1, %b0 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}