Use dyn_cast instead of cast when testing whether `value` is a block argument
in the scf.while bufferization code ("Case 1: Block argument of the "before"
region"). A checked cast does not yield a null result for a value of another
kind, so the `if (auto bbArg = ...)` guard only works as a test with dyn_cast.
The added regression test yields a non-block-argument value (`%0`) from a
nested scf.while and only checks that the IR bufferizes.

NOTE(review): line breaks, indentation and the angle-bracket arguments
(`<TensorType>`, `<BlockArgument>`, `tensor<i32>`) were missing from the
extracted text and have been reconstructed here; confirm them against the
upstream commit before applying.

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -891,7 +891,7 @@
     assert(value.getType().isa<TensorType>() && "expected tensor type");
 
     // Case 1: Block argument of the "before" region.
-    if (auto bbArg = value.cast<BlockArgument>()) {
+    if (auto bbArg = value.dyn_cast<BlockArgument>()) {
       if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
         Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
         auto yieldOp = whileOp.getYieldOp();
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -876,3 +876,26 @@
   }
   return
 }
+
+// -----
+
+// This is a regression test. Just check that the IR bufferizes.
+
+// CHECK-LABEL: func @non_block_argument_yield
+func.func @non_block_argument_yield() {
+  %true = arith.constant true
+  %0 = bufferization.alloc_tensor() : tensor<i32>
+  %1 = scf.while (%arg0 = %0) : (tensor<i32>) -> (tensor<i32>) {
+    scf.condition(%true) %arg0 : tensor<i32>
+  } do {
+  ^bb0(%arg0: tensor<i32>):
+    %ret = scf.while (%arg1 = %0) : (tensor<i32>) -> (tensor<i32>) {
+      scf.condition(%true) %arg1 : tensor<i32>
+    } do {
+    ^bb0(%arg7: tensor<i32>):
+      scf.yield %0 : tensor<i32>
+    }
+    scf.yield %ret : tensor<i32>
+  }
+  return
+}