diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -902,10 +902,12 @@
 
     auto conditionOp = whileOp.getConditionOp();
     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+      Block *block = conditionOp->getBlock();
       if (!isa<TensorType>(it.value().getType()))
         continue;
-      if (!state.areEquivalentBufferizedValues(
-              it.value(), conditionOp->getBlock()->getArgument(it.index())))
+      if (it.index() >= block->getNumArguments() ||
+          !state.areEquivalentBufferizedValues(it.value(),
+                                               block->getArgument(it.index())))
         return conditionOp->emitError()
                << "Condition arg #" << it.index()
                << " is not equivalent to the corresponding iter bbArg";
@@ -913,10 +915,12 @@
 
     auto yieldOp = whileOp.getYieldOp();
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+      Block *block = yieldOp->getBlock();
       if (!isa<TensorType>(it.value().getType()))
         continue;
-      if (!state.areEquivalentBufferizedValues(
-              it.value(), yieldOp->getBlock()->getArgument(it.index())))
+      if (it.index() >= block->getNumArguments() ||
+          !state.areEquivalentBufferizedValues(it.value(),
+                                               block->getArgument(it.index())))
         return yieldOp->emitError()
                << "Yield operand #" << it.index()
                << " is not equivalent to the corresponding iter bbArg";
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -324,3 +324,17 @@
 
 // This function may write to buffer(%ptr).
 func.func private @maybe_writing_func(%ptr : tensor<*xf32>)
+
+// -----
+
+func.func @regression_scf_while() {
+  %false = arith.constant false
+  %8 = bufferization.alloc_tensor() : tensor<10x10xf32>
+  scf.while (%arg0 = %8) : (tensor<10x10xf32>) -> () {
+    scf.condition(%false)
+  } do {
+    // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}}
+    scf.yield %8 : tensor<10x10xf32>
+  }
+  return
+}