diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -259,6 +259,150 @@
   }
 };
 
+/// Helper function for loop bufferization. Return the indices of all values
+/// that have a tensor type.
+static DenseSet<int64_t> getTensorIndices(ValueRange values) {
+  DenseSet<int64_t> result;
+  for (const auto &it : llvm::enumerate(values))
+    if (it.value().getType().isa<TensorType>())
+      result.insert(it.index());
+  return result;
+}
+
+/// Helper function for loop bufferization. Return the indices of all
+/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
+static DenseSet<int64_t> getEquivalentBuffers(ValueRange bbArgs,
+                                              ValueRange yieldedValues,
+                                              const AnalysisState &state) {
+  DenseSet<int64_t> result;
+  int64_t counter = 0;
+  for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
+    // Count every pair, so that the returned indices match the index space of
+    // `getTensorIndices`.
+    int64_t idx = counter++;
+    if (!std::get<0>(it).getType().isa<TensorType>())
+      continue;
+    if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
+      result.insert(idx);
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Cast the given buffer to the given
+/// memref type.
+static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
+  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+  // If the buffer already has the correct type, no cast is needed.
+  if (buffer.getType() == type)
+    return buffer;
+  // TODO: In case `type` has a layout map that is not the fully dynamic
+  // one, we may not be able to cast the buffer. In that case, the loop
+  // iter_arg's layout map must be changed (see uses of `castBuffer`).
+  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+         "scf.while op bufferization: cast incompatible");
+  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
+}
+
+/// Helper function for loop bufferization. Return the bufferized values of the
+/// given OpOperands. If an operand is not a tensor, return the original value.
+/// Return failure if a buffer could not be retrieved for a tensor operand.
+static FailureOr<SmallVector<Value>>
+getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+           BufferizationState &state) {
+  SmallVector<Value> result;
+  for (OpOperand &opOperand : operands) {
+    if (opOperand.get().getType().isa<TensorType>()) {
+      FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
+      if (failed(resultBuffer))
+        return failure();
+      result.push_back(*resultBuffer);
+    } else {
+      result.push_back(opOperand.get());
+    }
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Compute the buffer that should be
+/// yielded from a loop block (loop body or loop condition). If the given
+/// tensor is equivalent to the corresponding block argument (as indicated by
+/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new
+/// buffer copy must be yielded.
+///
+/// According to the `BufferizableOpInterface` implementation of scf loops, a
+/// bufferized OpResult may alias only with the corresponding bufferized
+/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
+/// the i-th init_arg, but not with any other OpOperand. If a corresponding
+/// OpResult/init_arg pair bufferizes to equivalent buffers (as indicated by
+/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
+/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
+/// alias with any buffer.)
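+///
+/// For illustration (hypothetical IR, simplified): in
+///
+///   %0 = scf.for %iv = %lb to %ub step %s iter_args(%a = %A)
+///       -> (tensor<?xf32>) {
+///     %t = tensor.insert %f into %a[%iv] : tensor<?xf32>
+///     scf.yield %t : tensor<?xf32>
+///   }
+///
+/// %t typically bufferizes in-place, so it is equivalent to %a and its buffer
+/// can be yielded directly. If the body yielded an unrelated tensor instead,
+/// a new allocation and copy would be yielded.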
+static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
+                              BaseMemRefType type, bool isEquivalent,
+                              BufferizationState &state) {
+  assert(tensor.getType().isa<TensorType>() && "expected tensor");
+  ensureToMemrefOpIsValid(tensor, type);
+  Value yieldedVal =
+      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());
+
+  if (isEquivalent)
+    // Yielded value is equivalent to the corresponding iter_arg bbArg.
+    // Yield the value directly. Most IR should be like that. Everything
+    // else must be resolved with copies and is potentially inefficient.
+    // By default, such problematic IR would already have been rejected
+    // during `verifyAnalysis`, unless `allow-return-allocs`.
+    return castBuffer(rewriter, yieldedVal, type);
+
+  // It is not certain that the yielded value and the iter_arg bbArg
+  // have the same buffer. Allocate a new buffer and copy. The yielded
+  // buffer will get deallocated by `deallocateBuffers`.
+
+  // TODO: There are cases in which it is not necessary to return a new
+  // buffer allocation. E.g., when equivalent values are yielded in a
+  // different order. This could be resolved with copies.
+  Optional<Value> yieldedAlloc = state.createAlloc(
+      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
+  // TODO: We should rollback, but for now just assume that this always
+  // succeeds.
+  assert(yieldedAlloc.hasValue() && "could not create alloc");
+  LogicalResult copyStatus = bufferization::createMemCpy(
+      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions());
+  (void)copyStatus;
+  assert(succeeded(copyStatus) && "could not create memcpy");
+
+  // The iter_arg memref type may have a layout map. Cast the new buffer
+  // to the same type if needed.
+  return castBuffer(rewriter, *yieldedAlloc, type);
+}
+
+/// Helper function for loop bufferization. Given a range of values, apply
+/// `func` to those values whose indices are contained in `tensorIndices`.
+/// Store all other values unmodified in the result vector.
+static SmallVector<Value>
+convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
+                    llvm::function_ref<Value(Value, int64_t)> func) {
+  SmallVector<Value> result;
+  for (const auto &it : llvm::enumerate(values)) {
+    size_t idx = it.index();
+    Value val = it.value();
+    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Given a list of pre-bufferization
+/// yielded values, compute the list of bufferized yielded values.
+static SmallVector<Value>
+getYieldedValues(RewriterBase &rewriter, ValueRange values,
+                 TypeRange bufferizedTypes,
+                 const DenseSet<int64_t> &tensorIndices,
+                 const DenseSet<int64_t> &equivalentTensors,
+                 BufferizationState &state) {
+  return convertTensorValues(
+      values, tensorIndices, [&](Value val, int64_t index) {
+        return getYieldedBuffer(rewriter, val,
+                                bufferizedTypes[index].cast<BaseMemRefType>(),
+                                equivalentTensors.contains(index), state);
+      });
+}
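+
+// Sketch of the scf.for bufferization implemented below (hypothetical IR,
+// simplified; layout maps omitted). A tensor loop such as
+//
+//   %0 = scf.for %iv = %lb to %ub step %s iter_args(%a = %A)
+//       -> (tensor<?xf32>) { ... }
+//
+// becomes a loop over memrefs, with `bufferization.to_tensor` ops
+// reconnecting the not-yet-bufferized loop body to the new iter_args:
+//
+//   %m = bufferization.to_memref %A : memref<?xf32>
+//   %0 = scf.for %iv = %lb to %ub step %s iter_args(%a = %m)
+//       -> (memref<?xf32>) {
+//     %a_t = bufferization.to_tensor %a : memref<?xf32>
+//     ...
+//     scf.yield %yielded_buffer : memref<?xf32>
+//   }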
+
 
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
 /// memrefs.
 struct ForOpInterface
@@ -312,78 +456,38 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
-    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    auto oldYieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
     Block *oldLoopBody = &forOp.getLoopBody().front();
 
-    // Helper function for casting MemRef buffers.
-    auto castBuffer = [&](Value buffer, Type type) {
-      assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
-      assert(buffer.getType().isa<BaseMemRefType>() &&
-             "expected BaseMemRefType");
-      // If the buffer already has the correct type, no cast is needed.
-      if (buffer.getType() == type)
-        return buffer;
-      // TODO: In case `type` has a layout map that is not the fully dynamic
-      // one, we may not be able to cast the buffer. In that case, the loop
-      // iter_arg's layout map must be changed (see uses of `castBuffer`).
-      assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
-             "scf.for op bufferization: cast incompatible");
-      return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
-          .getResult();
-    };
-
     // Indices of all iter_args that have tensor type. These are the ones that
     // are bufferized.
-    DenseSet<int64_t> indices;
+    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
     // For every yielded value, is the value equivalent to its corresponding
     // bbArg?
-    SmallVector<bool> equivalentYields;
-    for (const auto &it : llvm::enumerate(forOp.getInitArgs())) {
-      if (it.value().getType().isa<TensorType>()) {
-        indices.insert(it.index());
-        BufferRelation relation = bufferizableOp.bufferRelation(
-            forOp->getResult(it.index()), state.getAnalysisState());
-        equivalentYields.push_back(relation == BufferRelation::Equivalent);
-      } else {
-        equivalentYields.push_back(false);
-      }
-    }
+    DenseSet<int64_t> equivalentYields =
+        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
+                             state.getAnalysisState());
 
-    // Given a range of values, apply `func` to those marked in `indices`.
-    // Otherwise, store the unmodified value in the result vector.
-    auto convert = [&](ValueRange values,
-                       llvm::function_ref<Value(Value, int64_t)> func) {
-      SmallVector<Value> result;
-      for (const auto &it : llvm::enumerate(values)) {
-        size_t idx = it.index();
-        Value val = it.value();
-        result.push_back(indices.contains(idx) ? func(val, idx) : val);
-      }
-      return result;
-    };
+    // The new memref init_args of the loop.
+    FailureOr<SmallVector<Value>> maybeInitArgs =
+        getBuffers(rewriter, forOp.getIterOpOperands(), state);
+    if (failed(maybeInitArgs))
+      return failure();
+    SmallVector<Value> initArgs = *maybeInitArgs;
 
     // Construct a new scf.for op with memref instead of tensor values.
-    SmallVector<Value> initArgs;
-    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
-      if (opOperand.get().getType().isa<TensorType>()) {
-        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
-        if (failed(resultBuffer))
-          return failure();
-        initArgs.push_back(*resultBuffer);
-      } else {
-        initArgs.push_back(opOperand.get());
-      }
-    }
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), initArgs);
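+    // Collect the types of the new (memref) init_args; yielded values will be
+    // cast to these types after the loop body has been bufferized.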
+    ValueRange initArgsRange(initArgs);
+    TypeRange initArgsTypes(initArgsRange);
     Block *loopBody = &newForOp.getLoopBody().front();
 
     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
     // iter_args of the new loop in ToTensorOps.
     rewriter.setInsertionPointToStart(loopBody);
-    SmallVector<Value> iterArgs =
-        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
+    SmallVector<Value> iterArgs = convertTensorValues(
+        newForOp.getRegionIterArgs(), indices, [&](Value val, int64_t index) {
           return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
         });
     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
@@ -399,42 +503,8 @@
     auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
     rewriter.setInsertionPoint(yieldOp);
     SmallVector<Value> yieldValues =
-        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
-          Type initArgType = initArgs[index].getType();
-          ensureToMemrefOpIsValid(val, initArgType);
-          Value yieldedVal =
-              bufferization::lookupBuffer(rewriter, val, state.getOptions());
-
-          if (equivalentYields[index])
-            // Yielded value is equivalent to the corresponding iter_arg bbArg.
-            // Yield the value directly. Most IR should be like that. Everything
-            // else must be resolved with copies and is potentially inefficient.
-            // By default, such problematic IR would already have been rejected
-            // during `verifyAnalysis`, unless `allow-return-allocs`.
-            return castBuffer(yieldedVal, initArgType);
-
-          // It is not certain that the yielded value and the iter_arg bbArg
-          // have the same buffer. Allocate a new buffer and copy. The yielded
-          // buffer will get deallocated by `deallocateBuffers`.
-
-          // TODO: There are cases in which it is not necessary to return a new
-          // buffer allocation. E.g., when equivalent values are yielded in a
-          // different order. This could be resolved with copies.
-          Optional<Value> yieldedAlloc = state.createAlloc(
-              rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false);
-          // TODO: We should rollback, but for now just assume that this always
-          // succeeds.
-          assert(yieldedAlloc.hasValue() && "could not create alloc");
-          LogicalResult copyStatus =
-              bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal,
-                                          *yieldedAlloc, state.getOptions());
-          (void)copyStatus;
-          assert(succeeded(copyStatus) && "could not create memcpy");
-
-          // The iter_arg memref type may have a layout map. Cast the new buffer
-          // to the same type if needed.
-          return castBuffer(*yieldedAlloc, initArgType);
-        });
+        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
+                         equivalentYields, state);
     yieldOp.getResultsMutable().assign(yieldValues);
 
     // Replace loop results.