diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -762,13 +762,6 @@ OpBuilder::InsertionGuard g(rewriter); auto whileOp = cast(op); auto conditionOp = whileOp.getConditionOp(); - auto yieldOp = whileOp.getYieldOp(); - - // Indices of all bbArgs that have tensor type. These are the ones that - // are bufferized. The "before" and "after" regions may have different args. - DenseSet indicesBefore = getTensorIndices(whileOp.getInits()); - DenseSet indicesAfter = - getTensorIndices(whileOp.getAfterArguments()); // For every yielded value, is the value equivalent to its corresponding // bbArg? @@ -783,8 +776,9 @@ for (int64_t idx = 0; idx < static_cast(conditionOp.getArgs().size()); ++idx) { Value value = conditionOp.getArgs()[idx]; - if (!indicesBefore.contains(idx) || - equivalentYieldsBefore.contains(idx)) { + if (!value.getType().isa() || + (equivalentYieldsAfter.contains(idx) && + equivalentYieldsBefore.contains(idx))) { beforeYieldValues.push_back(value); continue; } @@ -799,27 +793,6 @@ conditionOp.getArgsMutable().assign(beforeYieldValues); }); - // Update "after" region. - rewriter.setInsertionPoint(yieldOp); - SmallVector afterYieldValues; - for (int64_t idx = 0; - idx < static_cast(yieldOp.getResults().size()); ++idx) { - Value value = yieldOp.getResults()[idx]; - if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) { - afterYieldValues.push_back(value); - continue; - } - FailureOr alloc = - allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value, - /*escape=*/true, state.getOptions()); - if (failed(alloc)) - return failure(); - afterYieldValues.push_back(*alloc); - } - rewriter.updateRootInPlace(yieldOp, [&]() { - yieldOp.getResultsMutable().assign(afterYieldValues); - }); - return success(); } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt %s \ +// RUN: -one-shot-bufferize="allow-return-allocs create-deallocs=0" \ +// RUN: -split-input-file | \ +// RUN: FileCheck %s --dump-input=always + +// A regression test to check that different before and after argument types are +// bufferized successfully. +func.func @different_before_after_args() -> tensor { + %true = arith.constant true + %cst = arith.constant dense<0.0> : tensor + %0 = scf.while (%arg4 = %true) : (i1) -> (tensor) { + scf.condition(%true) %cst : tensor + } do { + ^bb0(%arg4: tensor): + scf.yield %true : i1 + } + return %0 : tensor +} + +// CHECK-LABEL: @different_before_after_args \ No newline at end of file diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir @@ -98,9 +98,7 @@ ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>): // CHECK: } do { // CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>): - // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[b1]]) {bufferization.escape = [true]} : tensor<5xi1> - // CHECK-DAG: %[[yield3:.*]] = bufferization.alloc_tensor() copy(%[[b0]]) {bufferization.escape = [true]} : tensor<5xi1> - // CHECK: scf.yield %[[yield2]], %[[yield3]] + // CHECK: scf.yield %[[b1]], %[[b0]] // CHECK: } scf.yield %b1, %b0 : tensor<5xi1>, tensor<5xi1> } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -430,8 +430,8 @@ %idx: index) -> (tensor<5xi1>, tensor<5xi1>) { - // CHECK: %[[clone1:.*]] = bufferization.clone %[[arg1]] - // CHECK: %[[clone0:.*]] = bufferization.clone %[[arg0]] + // CHECK-DAG: %[[clone1:.*]] = bufferization.clone %[[arg1]] + // CHECK-DAG: %[[clone0:.*]] = bufferization.clone %[[arg0]] // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[clone0]], %[[w1:.*]] = %[[clone1]]) {{.*}} { %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1) : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) { @@ -454,19 +454,13 @@ // CHECK: } do { // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>): // CHECK: memref.store %{{.*}}, %[[b0]] - // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1> - // CHECK: memref.copy %[[b1]], %[[a3]] + // CHECK: %[[casted1:.*]] = memref.cast %[[b1]] + // CHECK: %[[casted0:.*]] = memref.cast %[[b0]] + // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]] // CHECK: memref.dealloc %[[b1]] - // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1> - // CHECK: memref.copy %[[b0]], %[[a2]] + // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]] // CHECK: memref.dealloc %[[b0]] - // CHECK: %[[casted3:.*]] = memref.cast %[[a3]] - // CHECK: %[[casted2:.*]] = memref.cast %[[a2]] - // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]] - // CHECK: memref.dealloc %[[a2]] - // CHECK: %[[cloned3:.*]] = bufferization.clone %[[casted3]] - // CHECK: memref.dealloc %[[a3]] - // CHECK: scf.yield %[[cloned3]], %[[cloned2]] + // CHECK: scf.yield %[[cloned1]], %[[cloned0]] // CHECK: } %pos = "dummy.some_op"() : () -> (index) %val = "dummy.another_op"() : () -> (i1)