[mlir][SCF] Bufferization: leave non-tensor scf.for / scf.while iter_args untouched

This patch changes the SCF BufferizableOpInterface implementation so that
loop-carried values whose type is not a tensor are no longer sent through
bufferization::getBufferType / castBuffer:

  * scf.for: when the loop result corresponding to an init arg is not a
    tensor, the init arg is appended to castedInitArgs as-is (the loop
    `continue`s before getBufferType/castBuffer are reached).
  * scf.while: when the "before" block argument corresponding to an init
    arg is not a tensor, the init arg is appended to castedInitArgs as-is.
  * scf.while: when computing argsTypesAfter (the "after" bbArg types, which
    are also the WhileOp result types), non-tensor block arguments keep
    their original type; only tensor ones are mapped via getBufferType.

Tests (mlir/test/Dialect/SCF/one-shot-bufferize.mlir):
  * @scf_while now carries an extra `index` iter_arg next to the tensor one,
    and returns result #0.
  * New @non_tensor_for_arg: an scf.for with an `index` iter_arg next to a
    tensor iter_arg. It only has a CHECK-LABEL, i.e. it verifies that
    bufferization succeeds on such a loop, not the shape of the output IR.

Review notes:
  * NOTE(review): this copy looks flattened (hunks joined onto a few lines)
    and angle-bracket contents appear stripped (e.g. `isa()`,
    `SmallVector castedInitArgs`, `->cast()`, bare `tensor` / `memref`
    types). Presumably these were `isa<TensorType>()`, `SmallVector<Value>`,
    `cast<Type>()`, `tensor<...>` etc.; restore them from the upstream
    commit before use. As-is, this text will neither apply nor compile.
  * NOTE(review): in @scf_while, scf.condition forwards the function
    argument %idx rather than the before-region block argument %i. As a
    result, %i is unused in the "before" region and the index value is not
    really loop-carried through it. This is possibly intentional, since the
    test still expects `scf.while : () -> ()` after bufferization. Confirm
    before changing it.
  * The argsTypesAfter lambda still dereferences the result of
    getBufferType without checking for failure (pre-existing
    "TODO: error handling" kept by this patch).
  * The test file is left without a trailing newline
    (see the "No newline at end of file" marker at the end).

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -601,8 +601,13 @@ SmallVector castedInitArgs; for (const auto &it : llvm::enumerate(initArgs)) { Value initArg = it.value(); - auto targetType = - bufferization::getBufferType(forOp->getResult(it.index()), options); + Value result = forOp->getResult(it.index()); + // If the type is not a tensor, bufferization doesn't need to touch it. + if (!result.getType().isa()) { + castedInitArgs.push_back(initArg); + continue; + } + auto targetType = bufferization::getBufferType(result, options); if (failed(targetType)) return failure(); castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType)); @@ -846,8 +851,13 @@ SmallVector castedInitArgs; for (const auto &it : llvm::enumerate(initArgs)) { Value initArg = it.value(); - auto targetType = bufferization::getBufferType( - whileOp.getBeforeArguments()[it.index()], options); + Value beforeArg = whileOp.getBeforeArguments()[it.index()]; + // If the type is not a tensor, bufferization doesn't need to touch it. + if (!beforeArg.getType().isa()) { + castedInitArgs.push_back(initArg); + continue; + } + auto targetType = bufferization::getBufferType(beforeArg, options); if (failed(targetType)) return failure(); castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType)); @@ -856,6 +866,8 @@ // The result types of a WhileOp are the same as the "after" bbArg types. 
SmallVector argsTypesAfter = llvm::to_vector( llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { + if (!bbArg.getType().isa()) + return bbArg.getType(); // TODO: error handling return bufferization::getBufferType(bbArg, options)->cast(); })); diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -344,13 +344,14 @@ // CHECK-SAME: %[[arg0:.*]]: memref func.func @scf_while(%arg0: tensor, %idx: index) -> tensor { // CHECK: scf.while : () -> () { - %res = scf.while (%arg1 = %arg0) : (tensor) -> tensor { + %res:2 = scf.while (%arg1 = %arg0, %i = %idx) : + (tensor, index) -> (tensor, index) { // CHECK: %[[condition:.*]] = memref.load %[[arg0]] // CHECK: scf.condition(%[[condition]]) %condition = tensor.extract %arg1[%idx] : tensor - scf.condition(%condition) %arg1 : tensor + scf.condition(%condition) %arg1, %idx : tensor, index } do { - ^bb0(%arg2: tensor): + ^bb0(%arg2: tensor, %i: index): // CHECK: } do { // CHECK: memref.store %{{.*}}, %[[arg0]] // CHECK: scf.yield @@ -358,11 +359,11 @@ %pos = "dummy.some_op"() : () -> (index) %val = "dummy.another_op"() : () -> (i1) %1 = tensor.insert %val into %arg2[%pos] : tensor - scf.yield %1 : tensor + scf.yield %1, %i : tensor, index } // CHECK: return - return %res : tensor + return %res#0 : tensor } // ----- @@ -853,3 +854,19 @@ %x = tensor.extract %r[%c1] : tensor return %x : f32 } + +// ----- + +// CHECK-LABEL: func @non_tensor_for_arg +func.func @non_tensor_for_arg(%A : tensor {bufferization.writable = true}) + -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2.0 : f32 + %c10 = arith.constant 10 : index + %r1:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%idx = %c1, %t = %A) -> (index, tensor) { + %t2 = tensor.insert %c2 into %t[%idx] : tensor + scf.yield %idx, %t2 : index, tensor + } + return 
%r1#1 : tensor +} \ No newline at end of file