diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -230,6 +230,16 @@
   /// bufferized or not.
   bool bufferizeFunctionBoundaries = false;
 
+  /// Certain ops have aliasing OpOperand/OpResult invariants (e.g., scf.for).
+  /// If this flag is set to `false`, those invariants are no longer enforced
+  /// with buffer copies.
+  ///
+  /// Note: Deactivating this flag can lead to incorrect bufferization results
+  /// if the invariants are not guaranteed by other means. This flag is useful
+  /// together with `AlwaysCopyBufferizationState`, which bufferizes all
+  /// writing tensor OpOperands out-of-place.
+  bool enforceAliasingInvariants = true;
+
   /// This flag controls buffer types on function signatures.
   ///
   /// * InferLayoutMap: All function parameter types have a fully dynamic layout
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -468,6 +468,54 @@
     return true;
   }
 
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+      return failure();
+
+    if (!state.getOptions().enforceAliasingInvariants)
+      return success();
+
+    // According to the `getAliasing...` implementations, a bufferized OpResult
+    // may alias only with the corresponding bufferized init_arg and with no
+    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg,
+    // but not with any other OpOperand. If a corresponding OpResult/init_arg
+    // pair bufferizes to equivalent buffers, this aliasing requirement is
+    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
+    // (New buffer copies do not alias with any buffer.)
+    auto forOp = cast<scf::ForOp>(op);
+    auto yieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(yieldOp);
+
+    // Indices of all iter_args that have tensor type. These are the ones that
+    // are bufferized.
+    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
+    // For every yielded value, is the value equivalent to its corresponding
+    // bbArg?
+    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
+        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
+    SmallVector<Value> yieldValues;
+    for (int64_t idx = 0;
+         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
+      Value value = yieldOp.getResults()[idx];
+      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
+        yieldValues.push_back(value);
+        continue;
+      }
+      Value alloc = rewriter.create<bufferization::AllocTensorOp>(
+          yieldOp.getLoc(), value.getType().cast<TensorType>(),
+          /*dynamicSizes=*/ValueRange(), value, /*escape=*/true);
+      yieldValues.push_back(alloc);
+    }
+
+    rewriter.updateRootInPlace(
+        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
+    return success();
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
@@ -632,6 +680,82 @@
     return true;
   }
 
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+      return failure();
+
+    if (!state.getOptions().enforceAliasingInvariants)
+      return success();
+
+    // According to the `getAliasing...` implementations, a bufferized OpResult
+    // may alias only with the corresponding bufferized init_arg and with no
+    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg,
+    // but not with any other OpOperand. If a corresponding OpResult/init_arg
+    // pair bufferizes to equivalent buffers, this aliasing requirement is
+    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
+    // (New buffer copies do not alias with any buffer.)
+    OpBuilder::InsertionGuard g(rewriter);
+    auto whileOp = cast<scf::WhileOp>(op);
+    auto conditionOp = whileOp.getConditionOp();
+    auto yieldOp = whileOp.getYieldOp();
+
+    // Indices of all bbArgs that have tensor type. These are the ones that
+    // are bufferized. The "before" and "after" regions may have different args.
+    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
+    DenseSet<int64_t> indicesAfter =
+        getTensorIndices(whileOp.getAfterArguments());
+
+    // For every yielded value, is the value equivalent to its corresponding
+    // bbArg?
+    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
+        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
+    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
+        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
+
+    // Update "before" region.
+    rewriter.setInsertionPoint(conditionOp);
+    SmallVector<Value> beforeYieldValues;
+    for (int64_t idx = 0;
+         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
+      Value value = conditionOp.getArgs()[idx];
+      if (!indicesBefore.contains(idx) ||
+          equivalentYieldsBefore.contains(idx)) {
+        beforeYieldValues.push_back(value);
+        continue;
+      }
+      Value alloc = rewriter.create<bufferization::AllocTensorOp>(
+          conditionOp.getLoc(), value.getType().cast<TensorType>(),
+          /*dynamicSizes=*/ValueRange(), value, /*escape=*/true);
+      beforeYieldValues.push_back(alloc);
+    }
+    rewriter.updateRootInPlace(conditionOp, [&]() {
+      conditionOp.getArgsMutable().assign(beforeYieldValues);
+    });
+
+    // Update "after" region.
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> afterYieldValues;
+    for (int64_t idx = 0;
+         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
+      Value value = yieldOp.getResults()[idx];
+      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
+        afterYieldValues.push_back(value);
+        continue;
+      }
+      Value alloc = rewriter.create<bufferization::AllocTensorOp>(
+          yieldOp.getLoc(), value.getType().cast<TensorType>(),
+          /*dynamicSizes=*/ValueRange(), value, /*escape=*/true);
+      afterYieldValues.push_back(alloc);
+    }
+    rewriter.updateRootInPlace(yieldOp, [&]() {
+      yieldOp.getResultsMutable().assign(afterYieldValues);
+    });
+
+    return success();
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto whileOp = cast<scf::WhileOp>(op);
@@ -855,6 +979,42 @@
     return BufferRelation::Equivalent;
   }
 
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+      return failure();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto foreachThreadOp = cast<scf::ForeachThreadOp>(op);
+    for (OpResult opResult : foreachThreadOp->getOpResults()) {
+      SmallVector<OpOperand *> destOperands =
+          state.getAliasingOpOperand(opResult);
+      assert(destOperands.size() == 1 &&
+             "expected exactly one aliasing OpOperand");
+      assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
+             "expected ParallelInsertSliceOp");
+
+      // Nothing to do if there is no conflict.
+      if (state.isInPlace(*destOperands.front()))
+        continue;
+
+      // Create AllocTensorOp.
+      bool isYielded = state.isTensorYielded(opResult);
+      auto resultType = opResult.getType().cast<TensorType>();
+      Value alloc = rewriter.create<bufferization::AllocTensorOp>(
+          op->getLoc(), resultType, /*dynamicDims=*/ValueRange(),
+          /*copy=*/destOperands.front()->get(),
+          /*escape=*/isYielded);
+
+      // Update terminator operand.
+      rewriter.updateRootInPlace(destOperands.front()->getOwner(),
+                                 [&]() { destOperands.front()->set(alloc); });
+    }
+
+    return success();
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &b,
                           BufferizationState &state) const {
     OpBuilder::InsertionGuard g(b);
@@ -1010,6 +1170,11 @@
     return BufferRelation::Equivalent;
   }
 
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    return success();
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &b,
                           BufferizationState &state) const {
     // Will be bufferized as part of ForeachThreadOp.
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -0,0 +1,135 @@
+// RUN: mlir-opt %s -tensor-copy-insertion="allow-return-allocs" -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -tensor-copy-insertion="bufferize-function-boundaries allow-return-allocs" -split-input-file | FileCheck %s --check-prefix=CHECK-FUNC
+
+// CHECK-LABEL: func @scf_for(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>, %[[B:.*]]: tensor<?xf32>
+func.func @scf_for(%A : tensor<?xf32>, %B : tensor<?xf32>,
+                   %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {escape = false} : tensor<?xf32>
+  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {escape = false} : tensor<?xf32>
+  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[iter1:.*]] = %[[A_copy]], %[[iter2:.*]] = %[[B_copy]])
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    // CHECK: scf.yield %[[iter1]], %[[iter2]]
+    scf.yield %tA, %tB : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %r0#0, %r0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_swapping_yields(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>, %[[B:.*]]: tensor<?xf32>
+func.func @scf_for_swapping_yields(%A : tensor<?xf32>, %B : tensor<?xf32>,
+                                   %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {escape = false} : tensor<?xf32>
+  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {escape = false} : tensor<?xf32>
+  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[iter1:.*]] = %[[A_copy]], %[[iter2:.*]] = %[[B_copy]])
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    // Yield tensors in different order.
+    // CHECK-DAG: %[[yield1:.*]] = bufferization.alloc_tensor() copy(%[[iter2]]) {escape = true} : tensor<?xf32>
+    // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[iter1]]) {escape = true} : tensor<?xf32>
+    // CHECK: scf.yield %[[yield1]], %[[yield2]]
+    scf.yield %tB, %tA : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %r0#0, %r0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_while(
+// CHECK-SAME: %[[A:.*]]: tensor<5xi1>, %[[B:.*]]: tensor<5xi1>
+func.func @scf_while(%A: tensor<5xi1>, %B: tensor<5xi1>, %idx: index)
+  -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {escape = false} : tensor<5xi1>
+  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {escape = false} : tensor<5xi1>
+  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[A_copy]], %[[w1:.*]] = %[[B_copy]]) {{.*}} {
+  %r0, %r1 = scf.while (%w0 = %A, %w1 = %B)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    // CHECK: %[[condition:.*]] = tensor.extract %[[w0]]
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    // Yield tensors in the same order; no copies are needed.
+    // CHECK: scf.condition(%[[condition]]) %[[w0]], %[[w1]]
+    scf.condition(%condition) %w0, %w1 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    // CHECK: } do {
+    // CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>):
+    // CHECK: scf.yield %[[b0]], %[[b1]]
+    // CHECK: }
+    scf.yield %b0, %b1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_while_non_equiv_condition_and_body(
+// CHECK-SAME: %[[A:.*]]: tensor<5xi1>, %[[B:.*]]: tensor<5xi1>
+func.func @scf_while_non_equiv_condition_and_body(%A: tensor<5xi1>,
+                                                  %B: tensor<5xi1>,
+                                                  %idx: index)
+  -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {escape = false} : tensor<5xi1>
+  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {escape = false} : tensor<5xi1>
+  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[A_copy]], %[[w1:.*]] = %[[B_copy]]) {{.*}} {
+  %r0, %r1 = scf.while (%w0 = %A, %w1 = %B)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    // CHECK: %[[condition:.*]] = tensor.extract %[[w0]]
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    // Yield tensors in different order.
+    // CHECK-DAG: %[[yield0:.*]] = bufferization.alloc_tensor() copy(%[[w1]]) {escape = true} : tensor<5xi1>
+    // CHECK-DAG: %[[yield1:.*]] = bufferization.alloc_tensor() copy(%[[w0]]) {escape = true} : tensor<5xi1>
+    // CHECK: scf.condition(%[[condition]]) %[[yield0]], %[[yield1]]
+    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    // CHECK: } do {
+    // CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>):
+    // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[b1]]) {escape = true} : tensor<5xi1>
+    // CHECK-DAG: %[[yield3:.*]] = bufferization.alloc_tensor() copy(%[[b0]]) {escape = true} : tensor<5xi1>
+    // CHECK: scf.yield %[[yield2]], %[[yield3]]
+    // CHECK: }
+    scf.yield %b1, %b0 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_foreach_thread_out_of_place(
+// CHECK-SAME: %[[arg0:.*]]: tensor<100xf32>, %[[arg1:.*]]: tensor<100xf32>
+// CHECK-FUNC-LABEL: func @scf_foreach_thread_out_of_place(
+func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
+                                           %out: tensor<100xf32>) {
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 100 : index
+
+  // CHECK-FUNC-NOT: alloc_tensor
+  // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[arg1]]) {escape = false} : tensor<100xf32>
+  // CHECK: scf.foreach_thread
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
+      // CHECK: tensor.extract_slice
+      // CHECK: scf.foreach_thread.perform_concurrently
+      // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %[[alloc]]
+      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+      scf.foreach_thread.perform_concurrently {
+        scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+          tensor<1xf32> into tensor<100xf32>
+      }
+  }
+  return
+}
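For reference, the following is a sketch of the IR that tensor copy insertion is expected to produce for the @scf_for_swapping_yields test above, reconstructed from that test's CHECK lines; the SSA value names are chosen for illustration only and are not part of the patch.

func.func @scf_for_swapping_yields(%A : tensor<?xf32>, %B : tensor<?xf32>,
                                   %lb : index, %ub : index, %step : index)
  -> (tensor<?xf32>, tensor<?xf32>)
{
  // The loop may write to its iter_args in place, so the inits are copied
  // before entering the loop.
  %A_copy = bufferization.alloc_tensor() copy(%A) {escape = false} : tensor<?xf32>
  %B_copy = bufferization.alloc_tensor() copy(%B) {escape = false} : tensor<?xf32>
  %r0:2 = scf.for %i = %lb to %ub step %step
      iter_args(%tA = %A_copy, %tB = %B_copy) -> (tensor<?xf32>, tensor<?xf32>)
  {
    // The yields are swapped, so result #i would not bufferize to a buffer
    // equivalent to init_arg #i. Fresh copies are yielded to satisfy the
    // aliasing invariant enforced by resolveConflicts.
    %yield1 = bufferization.alloc_tensor() copy(%tB) {escape = true} : tensor<?xf32>
    %yield2 = bufferization.alloc_tensor() copy(%tA) {escape = true} : tensor<?xf32>
    scf.yield %yield1, %yield2 : tensor<?xf32>, tensor<?xf32>
  }
  return %r0#0, %r0#1 : tensor<?xf32>, tensor<?xf32>
}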