diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -376,17 +376,27 @@
 
     // Determine the actual operand to introduce a clone for and rewire the
     // operand to point to the clone instead.
-    Value operand =
-        regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber())
-            [llvm::find(it->getSuccessorInputs(), blockArg).getIndex()];
+    // Rewire the operand at the *position* that feeds `blockArg`. Searching
+    // parentOp's operands by value is not enough: the same value may be
+    // passed for several block arguments, e.g.
+    // `scf.while (%x = %a, %y = %a)`, and a by-value search always returns
+    // the first occurrence.
+    auto operands =
+        regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber());
+    auto successorInputs = it->getSuccessorInputs();
+    auto inputIt = llvm::find(successorInputs, blockArg);
+    assert(inputIt != successorInputs.end() &&
+           "block argument is not a successor input of its region");
+    size_t argIndex = inputIt.getIndex();
+    size_t operandIndex = operands.getBeginOperandIndex() + argIndex;
+    Value operand = parentOp->getOperand(operandIndex);
+    assert(operand == operands[argIndex] &&
+           "region interface operands don't match parentOp operands");
     auto clone = introduceCloneBuffers(operand, parentOp);
     if (failed(clone))
       return failure();
 
-    auto op = llvm::find(parentOp->getOperands(), operand);
-    assert(op != parentOp->getOperands().end() &&
-           "parentOp does not contain operand");
-    parentOp->setOperand(op.getIndex(), *clone);
+    parentOp->setOperand(operandIndex, *clone);
     return success();
   }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1222,3 +1222,64 @@
   %1 = bufferization.clone %arg1 : memref to memref
   return %0 : memref
 }
+
+// -----
+
+// CHECK-LABEL: func @while_two_arg
+func @while_two_arg(%arg0: index) {
+  %a = memref.alloc(%arg0) : memref
+// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ALLOC:.*]], %[[ARG2:.*]] = %[[CLONE:.*]])
+  scf.while (%arg1 = %a, %arg2 = %a) : (memref, memref) -> (memref, memref) {
+// CHECK-NEXT: make_condition
+    %0 = "test.make_condition"() : () -> i1
+// CHECK-NEXT: bufferization.clone %[[ARG2]]
+// CHECK-NEXT: memref.dealloc %[[ARG2]]
+    scf.condition(%0) %arg1, %arg2 : memref, memref
+  } do {
+  ^bb0(%arg1: memref, %arg2: memref):
+// CHECK: %[[ALLOC2:.*]] = memref.alloc
+    %b = memref.alloc(%arg0) : memref
+// CHECK: memref.dealloc %[[ARG2]]
+// CHECK: %[[CLONE2:.*]] = bufferization.clone %[[ALLOC2]]
+// CHECK: memref.dealloc %[[ALLOC2]]
+    scf.yield %arg1, %b : memref, memref
+  }
+// CHECK: }
+// CHECK-NEXT: memref.dealloc %[[WHILE]]#1
+// CHECK-NEXT: memref.dealloc %[[ALLOC]]
+// CHECK-NEXT: return
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @while_three_arg
+func @while_three_arg(%arg0: index) {
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+  %a = memref.alloc(%arg0) : memref
+// CHECK-NEXT: %[[CLONE1:.*]] = bufferization.clone %[[ALLOC]]
+// CHECK-NEXT: %[[CLONE2:.*]] = bufferization.clone %[[ALLOC]]
+// CHECK-NEXT: %[[CLONE3:.*]] = bufferization.clone %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ALLOC]]
+// CHECK-NEXT: %[[WHILE:.*]]:3 = scf.while
+// FIXME: Which clone feeds which scf.while operand is non-deterministic, so
+// only require that each clone is passed as an init operand. (FileCheck has
+// no same-line unordered variant, so these are not pinned to this line.)
+// CHECK-DAG: = %[[CLONE1]]{{[,)]}}
+// CHECK-DAG: = %[[CLONE2]]{{[,)]}}
+// CHECK-DAG: = %[[CLONE3]]{{[,)]}}
+  scf.while (%arg1 = %a, %arg2 = %a, %arg3 = %a) : (memref, memref, memref) -> (memref, memref, memref) {
+    %0 = "test.make_condition"() : () -> i1
+    scf.condition(%0) %arg1, %arg2, %arg3 : memref, memref, memref
+  } do {
+  ^bb0(%arg1: memref, %arg2: memref, %arg3: memref):
+    %b = memref.alloc(%arg0) : memref
+    %q = memref.alloc(%arg0) : memref
+    scf.yield %q, %b, %arg2: memref, memref, memref
+  }
+// CHECK-DAG: memref.dealloc %[[WHILE]]#0
+// CHECK-DAG: memref.dealloc %[[WHILE]]#1
+// CHECK-DAG: memref.dealloc %[[WHILE]]#2
+// CHECK-NEXT: return
+  return
+}