diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -461,18 +461,20 @@
   AnalysisState state(bufferizationOptions);
 
 #ifndef NDEBUG
-  // Ops with nested tensor ops are not supported yet. At the moment, this
-  // function just bufferizes the given op itself, but not its body.
-  op->walk([&](Operation *nestedOp) {
-    if (op == nestedOp)
-      return;
-    if (llvm::any_of(nestedOp->getOperands(),
-                     [](Value v) { return v.getType().isa<TensorType>(); }))
-      llvm_unreachable("ops with nested tensor ops are not supported yet");
-    if (llvm::any_of(nestedOp->getResults(),
-                     [](Value v) { return v.getType().isa<TensorType>(); }))
-      llvm_unreachable("ops with nested tensor ops are not supported yet");
-  });
+  if (!options.bufferizeDestinationOnly) {
+    // Ops with nested tensor ops are not supported yet. At the moment, this
+    // function just bufferizes the given op itself, but not its body.
+    op->walk([&](Operation *nestedOp) {
+      if (op == nestedOp)
+        return;
+      if (llvm::any_of(nestedOp->getOperands(),
+                       [](Value v) { return v.getType().isa<TensorType>(); }))
+        llvm_unreachable("ops with nested tensor ops are not supported yet");
+      if (llvm::any_of(nestedOp->getResults(),
+                       [](Value v) { return v.getType().isa<TensorType>(); }))
+        llvm_unreachable("ops with nested tensor ops are not supported yet");
+    });
+  }
 #endif // NDEBUG
 
   // Gather tensor results.
@@ -509,7 +511,7 @@
     if (!state.bufferizesToMemoryWrite(operand))
       continue;
     if (!isa<RankedTensorType>(operand.get().getType()))
-      return nullptr;
+      continue;
     addOutOfPlaceOperand(&operand);
   }
   // TODO: Support multiple buffers.
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -218,3 +218,27 @@
   %0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4, bufferize_destination_only} : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_for_destination(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?xf32, 4>
+//       CHECK:   memref.tensor_store %[[t]], %[[alloc]]
+//       CHECK:   %[[t2:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
+//       CHECK:   %[[for:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[t2]])
+//       CHECK:   memref.dealloc %[[alloc]]
+//       CHECK:   return %[[for]]
+func.func @scf_for_destination(%t: tensor<?xf32>, %lb: index, %ub: index, %step: index) -> tensor<?xf32> {
+  %r = scf.for %iv = %lb to %ub step %step iter_args(%a = %t) -> tensor<?xf32> {
+    %b = "test.foo"(%a) : (tensor<?xf32>) -> (tensor<?xf32>)
+    scf.yield %b : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4, bufferize_destination_only} : !transform.any_op
+}