diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" @@ -272,17 +273,13 @@ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - // Tensor iter_args of scf::ForOps are always considered as a write. This is - // to simplify the analysis. - // TODO: Consider doing sth. like isValueWritten. + // Tensor iter_args of scf::ForOps are always considered as a write. return true; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto forOp = cast(op); - if (!opOperand.get().getType().isa()) - return {}; return {forOp.getResultForOpOperand(opOperand)}; } @@ -293,7 +290,8 @@ auto forOp = cast(op); OpOperand &forOperand = forOp.getOpOperandForResult(opResult); auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - auto yieldOp = cast(&forOp.getLoopBody().front().back()); + auto yieldOp = + cast(forOp.getLoopBody().front().getTerminator()); bool equivalentYield = state.areEquivalentBufferizedValues( bbArg, yieldOp->getOperand(opResult.getResultNumber())); return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None; @@ -313,14 +311,25 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto forOp = cast(op); + auto bufferizableOp = cast(op); Block *oldLoopBody = &forOp.getLoopBody().front(); // Indices of all iter_args that have tensor type.
These are the ones that // are bufferized. DenseSet indices; - for (const auto &it : llvm::enumerate(forOp.getInitArgs())) - if (it.value().getType().isa()) + // For every yielded value, is the value equivalent to its corresponding + // bbArg? + SmallVector equivalentYields; + for (const auto &it : llvm::enumerate(forOp.getInitArgs())) { + if (it.value().getType().isa()) { indices.insert(it.index()); + BufferRelation relation = bufferizableOp.bufferRelation( + forOp->getResult(it.index()), state.getAnalysisState()); + equivalentYields.push_back(relation == BufferRelation::Equivalent); + } else { + equivalentYields.push_back(false); + } + } // Given a range of values, apply `func` to those marked in `indices`. // Otherwise, store the unmodified value in the result vector. @@ -374,8 +383,35 @@ SmallVector yieldValues = convert(yieldOp.getResults(), [&](Value val, int64_t index) { ensureToMemrefOpIsValid(val, initArgs[index].getType()); - return rewriter.create( + Value yieldedVal = rewriter.create( val.getLoc(), initArgs[index].getType(), val); + + if (equivalentYields[index]) + // Yielded value is equivalent to the corresponding iter_arg bbArg. + // Yield the value directly. Most IR should be like that. Everything + // else must be resolved with copies and is potentially inefficient. + // By default, such problematic IR would already have been rejected + // during `verifyAnalysis`, unless `allow-return-allocs`. + return yieldedVal; + + // It is not certain that the yielded value and the iter_arg bbArg + // have the same buffer. Allocate a new buffer and copy. The yielded + // buffer will get deallocated by `deallocateBuffers`. + + // TODO: There are cases in which it is not necessary to return a new + // buffer allocation. E.g., when equivalent values are yielded in a + // different order. This could be resolved with copies.
+ Optional yieldedAlloc = state.createAlloc( + rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false); + // TODO: We should rollback, but for now just assume that this always + // succeeds. + assert(yieldedAlloc.hasValue() && "could not create alloc"); + LogicalResult copyStatus = + bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal, + *yieldedAlloc, state.getOptions()); + (void)copyStatus; + assert(succeeded(copyStatus) && "could not create memcpy"); + return *yieldedAlloc; }); yieldOp.getResultsMutable().assign(yieldValues); @@ -385,12 +421,17 @@ return success(); } - /// Assert that yielded values of an scf.for op are aliasing with their - /// corresponding bbArgs. This is required because the i-th OpResult of an - /// scf.for op is currently assumed to alias with the i-th iter_arg (in the - /// absence of conflicts). + /// Assert that yielded values of an scf.for op are equivalent to their + /// corresponding bbArgs. Otherwise, an alloc+copy is inserted and yielded + /// from the loop. This could be a performance problem, so it must be + /// explicitly activated with `allow-return-allocs`. LogicalResult verifyAnalysis(Operation *op, const AnalysisState &state) const { + const auto &options = + static_cast(state.getOptions()); + if (options.allowReturnAllocs) + return success(); + auto forOp = cast(op); auto yieldOp = cast(forOp.getLoopBody().front().getTerminator()); @@ -405,13 +446,10 @@ // Note: This is overly strict. We should check for aliasing bufferized // values. But we don't have a "must-alias" analysis yet. if (!state.areEquivalentBufferizedValues(operand.get(), bbArg)) - // TODO: this could get resolved with copies but it can also turn into - // swaps so we need to be careful about order of copies.
return yieldOp->emitError() << "Yield operand #" << operand.getOperandNumber() << " does not bufferize to a buffer that is aliasing the " - "matching" - << " enclosing scf::for operand"; + "matching enclosing scf::for operand"; } return success(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -23,7 +23,7 @@ MLIRArithmetic MLIRBufferization MLIRBufferizationTransforms - MLIRDialectUtils + MLIRDialectUtils MLIRIR MLIRMemRef MLIRPass diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1218,3 +1218,108 @@ %f = tensor.extract %r[%j] : tensor return %f : f32 } + +// ----- + +// Note: This bufferizes to inefficient code, but bufferization should not see +// such IR in the first place. The iter_arg would canonicalize away. This test +// case is just to ensure that the bufferization generates correct code. + +// CHECK-LABEL: func @scf_for_yield_non_equivalent( +// CHECK-SAME: %[[t:.*]]: memref, %lb : index, %ub : index, %step : index) -> tensor { + %r = scf.for %i = %lb to %ub step %step iter_args(%a = %t) -> tensor { + scf.yield %t : tensor + } + + return %r : tensor +} + +// ----- + +// Note: This bufferizes to inefficient code, but bufferization should not see +// such IR in the first place. The iter_arg would canonicalize away. This test +// case is just to ensure that the bufferization generates correct code.
+ +// CHECK-LABEL: func @scf_for_yield_allocation( +// CHECK-SAME: %[[t:.*]]: memref, %lb : index, %ub : index, + %step : index) -> tensor { + %r = scf.for %i = %lb to %ub step %step iter_args(%a = %t) -> tensor { + %t2 = linalg.init_tensor [%i] : tensor + scf.yield %t2 : tensor + } + + return %r : tensor +} + +// ----- + +// TODO: The scf.yield could bufferize to 1 alloc and 2 copies (instead of +// 2 allocs and 2 copies). + +// CHECK-LABEL: func @scf_for_swapping_yields( +// CHECK-SAME: %[[A:.*]]: memref, %[[B:.*]]: memref + +func @scf_for_swapping_yields( + %A : tensor, %B : tensor {linalg.inplaceable = true}, + %C : tensor<4xf32>, %lb : index, %ub : index, %step : index) + -> (f32, f32) +{ +// CHECK-DAG: %[[clone1:.*]] = bufferization.clone %[[A]] +// CHECK-DAG: %[[clone2:.*]] = bufferization.clone %[[B]] +// CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[iter1:.*]] = %[[clone1]], %[[iter2:.*]] = %[[clone2]]) + %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B) + -> (tensor, tensor) + { +// CHECK: %[[sv1:.*]] = memref.subview %[[iter1]] +// CHECK: memref.copy %{{.*}}, %[[sv1]] + %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor +// CHECK: %[[sv2:.*]] = memref.subview %[[iter2]] +// CHECK: memref.copy %{{.*}}, %[[sv2]] + %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor + +// CHECK: %[[alloc2:.*]] = memref.alloc(%{{.*}}) +// CHECK: memref.copy %[[iter2]], %[[alloc2]] +// CHECK: memref.dealloc %[[iter2]] +// CHECK: %[[alloc1:.*]] = memref.alloc(%{{.*}}) +// CHECK: memref.copy %[[iter1]], %[[alloc1]] +// CHECK: memref.dealloc %[[iter1]] +// CHECK: %[[casted1:.*]] = memref.cast %[[alloc1]] +// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]] +// CHECK: scf.yield %[[casted2]], %[[casted1]] + // Yield tensors in different order.
+ scf.yield %ttB, %ttA : tensor, tensor + } + +// CHECK: %[[r0:.*]] = memref.load %[[for]]#0 +// CHECK: memref.dealloc %[[for]]#0 +// CHECK: %[[r1:.*]] = memref.load %[[for]]#1 +// CHECK: memref.dealloc %[[for]]#1 + %f0 = tensor.extract %r0#0[%step] : tensor + %f1 = tensor.extract %r0#1[%step] : tensor +// CHECK: return %[[r0]], %[[r1]] + return %f0, %f1: f32, f32 +}