diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -377,18 +377,14 @@ /// Creates a memcpy between two given buffers. void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const; - /// Lookup the memref buffer that is associated to the given tensor value. - /// Asserts if no buffer is associated. - Value lookupBuffer(RewriterBase &rewriter, Value tensor) const; - /// Return `true` if the given OpResult has been decided to bufferize inplace. bool isInPlace(OpOperand &opOperand) const; - /// Return the result buffer (memref) for a given OpResult (tensor). Allocate + /// Return the buffer (memref) for a given OpOperand (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place - /// bufferization is necessary. - FailureOr getResultBuffer(RewriterBase &rewriter, - OpResult result) const; + /// bufferization was decided. + FailureOr getBuffer(RewriterBase &rewriter, OpOperand &opOperand, + bool forceInPlace = false) const; /// Return dialect-specific bufferization state. template diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -347,74 +347,73 @@ }); } +static Value lookupBuffer(RewriterBase &rewriter, Value tensor) { + assert(tensor.getType().isa() && "unexpected non-tensor type"); + + // Replace "%t = to_tensor %m" with %m. + if (auto toTensorOp = tensor.getDefiningOp()) + return toTensorOp.memref(); + + // Insert to_memref op. + OpBuilder::InsertionGuard g(rewriter); + setInsertionPointAfter(rewriter, tensor); + Type memrefType; + if (auto rankedTensorType = tensor.getType().dyn_cast()) { + memrefType = getDynamicMemRefType(rankedTensorType); + } else { + memrefType = getUnrankedMemRefType( + tensor.getType().cast().getElementType()); + } + return rewriter.create(tensor.getLoc(), memrefType, + tensor); +} + /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. FailureOr -mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer( - RewriterBase &rewriter, OpResult result) const { +mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer( + RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const { OpBuilder::InsertionGuard guard(rewriter); - Operation *op = result.getOwner(); - SmallVector aliasingOperands = getAliasingOpOperand(result); - assert(!aliasingOperands.empty() && "could not get aliasing OpOperand"); - OpOperand *opOperand = aliasingOperands.front(); - Value operand = opOperand->get(); + Operation *op = opOperand.getOwner(); + Location loc = op->getLoc(); + Value operand = opOperand.get(); Value operandBuffer = lookupBuffer(rewriter, operand); - // Make sure that all OpOperands are the same buffer. If this is not the case, - // we would have to materialize a memref value. - // TODO: Should be looking for checking for "equivalent buffers" instead of - // operator== here, but equivalent buffers for scf.if yield values are not - // set up yet. - if (aliasingOperands.size() > 1 && - !llvm::all_of(aliasingOperands, [&](OpOperand *o) { - return lookupBuffer(rewriter, o->get()) == operandBuffer; - })) - return FailureOr(op->emitError("result buffer is ambiguous")); - - // If bufferizing out-of-place, allocate a new buffer. - if (!aliasInfo.isInPlace(*opOperand)) { - // Ops with multiple aliasing operands can currently not bufferize - // out-of-place. - assert( - aliasingOperands.size() == 1 && - "ops with multiple aliasing OpOperands cannot bufferize out-of-place"); - Location loc = op->getLoc(); - // Move insertion point right after `operandBuffer`. That is where the - // allocation should be inserted (in the absence of allocation hoisting). - setInsertionPointAfter(rewriter, operandBuffer); - // Allocate the result buffer. - FailureOr resultBuffer = - createAlloc(rewriter, loc, operandBuffer, options.createDeallocs); - if (failed(resultBuffer)) - return failure(); - bool skipCopy = false; - // Do not copy if the last preceding write of `operand` is an op that does - // not write (skipping ops that merely create aliases). E.g., InitTensorOp. - // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA - // use-def chain, it returns that value, regardless of whether it is a - // memory write or not. - Value lastWrite = findLastPrecedingWrite(operand); - if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) - if (!bufferizableOp.isMemoryWrite(lastWrite.cast(), *this)) - skipCopy = true; - // Do not copy if the copied data is never read. (Neither by this op nor by - // any following op.) - if (!bufferizesToMemoryRead(*opOperand) && !isValueRead(result)) - skipCopy = true; - // Do not copy if this op does not read the data, but writes it. - if (bufferizesToMemoryWrite(*opOperand) && - !bufferizesToMemoryRead(*opOperand)) - skipCopy = true; - if (!skipCopy) { - // The copy happens right before the op that is bufferized. - rewriter.setInsertionPoint(op); - createMemCpy(rewriter, loc, operandBuffer, *resultBuffer); - } + + if (forceInPlace || aliasInfo.isInPlace(opOperand)) + return operandBuffer; + + // Bufferizing out-of-place: Allocate a new buffer. + // Move insertion point right after `operandBuffer`. That is where the + // allocation should be inserted (in the absence of allocation hoisting). + setInsertionPointAfter(rewriter, operandBuffer); + // Allocate the result buffer. + FailureOr resultBuffer = + createAlloc(rewriter, loc, operandBuffer, options.createDeallocs); + if (failed(resultBuffer)) + return failure(); + // Do not copy if the last preceding write of `operand` is an op that does + // not write (skipping ops that merely create aliases). E.g., InitTensorOp. + // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA + // use-def chain, it returns that value, regardless of whether it is a + // memory write or not. + Value lastWrite = findLastPrecedingWrite(operand); + if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) + if (!bufferizableOp.isMemoryWrite(lastWrite.cast(), *this)) + return resultBuffer; + // Do not copy if the copied data is never read. + OpResult aliasingOpResult = getAliasingOpResult(opOperand); + if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) && + !isValueRead(aliasingOpResult)) + return resultBuffer; + // Do not copy if this op does not read the data, but writes it. + if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand)) return resultBuffer; - } - // Bufferizing in-place. No need to allocate a new buffer. - return operandBuffer; + // The copy happens right before the op that is bufferized. + rewriter.setInsertionPoint(op); + createMemCpy(rewriter, loc, operandBuffer, *resultBuffer); + return resultBuffer; } void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues( @@ -593,28 +592,6 @@ return isa(bbArg.getOwner()->getParentOp()); } -Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer( - RewriterBase &rewriter, Value tensor) const { - assert(tensor.getType().isa() && "unexpected non-tensor type"); - - // Replace "%t = to_tensor %m" with %m. - if (auto toTensorOp = tensor.getDefiningOp()) - return toTensorOp.memref(); - - // Insert to_memref op. - OpBuilder::InsertionGuard g(rewriter); - setInsertionPointAfter(rewriter, tensor); - Type memrefType; - if (auto rankedTensorType = tensor.getType().dyn_cast()) { - memrefType = getDynamicMemRefType(rankedTensorType); - } else { - memrefType = getUnrankedMemRefType( - tensor.getType().cast().getElementType()); - } - return rewriter.create(tensor.getLoc(), memrefType, - tensor); -} - bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace( OpOperand &opOperand) const { return aliasInfo.isInPlace(opOperand); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -46,15 +46,19 @@ newInputBuffers.push_back(opOperand->get()); continue; } - newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get())); + // Input operands are never written to. + newInputBuffers.push_back( + *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true)); } // New output operands for the cloned op. SmallVector newOutputBuffers; - for (OpOperand *opOperand : op.getOutputOperands()) { - OpResult opResult = op.getTiedOpResult(opOperand); - assert(opResult && "could not find correspond OpResult"); - FailureOr resultBuffer = state.getResultBuffer(rewriter, opResult); + for (OpResult opResult : op->getOpResults()) { + SmallVector aliasingOpOperands = + state.getAliasingOpOperand(opResult); + assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand"); + FailureOr resultBuffer = + state.getBuffer(rewriter, *aliasingOpOperands.front()); if (failed(resultBuffer)) return failure(); newOutputBuffers.push_back(*resultBuffer); @@ -284,24 +288,23 @@ // Compute new inputs, outputs and results. SmallVector newInputs, newOutputs, newResults; - for (Value value : tiledLoopOp.inputs()) { - if (value.getType().isa()) { - newInputs.push_back(state.lookupBuffer(rewriter, value)); - } else { - newInputs.push_back(value); - } - } - int nextResultNum = 0; - for (Value value : tiledLoopOp.outputs()) { - if (value.getType().isa()) { - FailureOr buffer = state.getResultBuffer( - rewriter, tiledLoopOp->getResult(nextResultNum++)); - if (failed(buffer)) + for (int i = tiledLoopOp.getNumControlOperands(); + i < tiledLoopOp->getNumOperands(); ++i) { + OpOperand &operand = tiledLoopOp->getOpOperand(i); + Value rewrittenValue = operand.get(); + if (rewrittenValue.getType().isa()) { + FailureOr bufferOrFailure = state.getBuffer(rewriter, operand); + if (failed(bufferOrFailure)) return failure(); - newOutputs.push_back(*buffer); - newResults.push_back(*buffer); + rewrittenValue = *bufferOrFailure; + } + if (i < + tiledLoopOp.getNumControlOperands() + tiledLoopOp.getNumInputs()) { + newInputs.push_back(rewrittenValue); } else { - newOutputs.push_back(value); + newOutputs.push_back(rewrittenValue); + if (operand.get().getType().isa()) + newResults.push_back(rewrittenValue); } } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -351,7 +351,7 @@ // Cast values at the call site if necessary. returnValues.push_back( - getNonCastedValue(state.lookupBuffer(rewriter, returnVal))); + getNonCastedValue(*state.getBuffer(rewriter, returnOperand))); } // 2. Rewrite the terminator without the inPlace bufferizable values. @@ -659,7 +659,8 @@ // Return operands that are equivalent to some bbArg, are not // returned. Value buffer = - state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx)); + *state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx), + /*forceInPlace=*/true); replacementValues[returnValIdx] = buffer; newOperands[*bbArgIdx] = buffer; continue; @@ -690,9 +691,9 @@ // Retrieve buffers for tensor operands. Tensor operand buffers, who's // corresponding FuncOp bbArgs are equivalent to a returned tensor, were // already stored in `newOperands` during Step 1. - Value buffer = newOperands[idx] - ? newOperands[idx] - : state.lookupBuffer(rewriter, tensorOperand); + Value buffer = newOperands[idx] ? newOperands[idx] + : *state.getBuffer(rewriter, opOperand, + /*forceInPlace=*/true); // Caller / callee type mistmatch is handled with a CastOp. auto memRefType = bufferizedFuncType.getInput(idx); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -280,19 +280,17 @@ }; // Construct a new scf.for op with memref instead of tensor values. - bool resultBufferFailure = false; - SmallVector initArgs = - convert(forOp.getInitArgs(), [&](Value val, int64_t index) { - FailureOr resultBuffer = - state.getResultBuffer(rewriter, forOp->getOpResult(index)); - if (failed(resultBuffer)) { - resultBufferFailure = true; - return Value(); - } - return *resultBuffer; - }); - if (resultBufferFailure) - return failure(); + SmallVector initArgs; + for (OpOperand &opOperand : forOp.getIterOpOperands()) { + if (opOperand.get().getType().isa()) { + FailureOr resultBuffer = state.getBuffer(rewriter, opOperand); + if (failed(resultBuffer)) + return failure(); + initArgs.push_back(*resultBuffer); + } else { + initArgs.push_back(opOperand.get()); + } + } auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), initArgs); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -53,7 +53,7 @@ // The result buffer still has the old (pre-cast) type. FailureOr resultBuffer = - state.getResultBuffer(rewriter, castOp->getResult(0)); + state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/); if (failed(resultBuffer)) return failure(); auto sourceMemRefType = resultBuffer->getType().cast(); @@ -106,7 +106,7 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationState &state) const { auto dimOp = cast(op); - Value v = state.lookupBuffer(rewriter, dimOp.source()); + Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/); replaceOpWithNewBufferizedOp(rewriter, op, v, dimOp.index()); return success(); } @@ -143,7 +143,9 @@ const BufferizationState &state) const { auto extractSliceOp = cast(op); Location loc = extractSliceOp.getLoc(); - Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source()); + Value srcMemref = + *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/, + /*forceInPlace=*/true); auto srcMemrefType = srcMemref.getType().cast(); auto dstTensorType = extractSliceOp.result().getType().cast(); @@ -206,7 +208,8 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationState &state) const { auto extractOp = cast(op); - Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor()); + Value srcMemref = + *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/); replaceOpWithNewBufferizedOp(rewriter, op, srcMemref, extractOp.indices()); return success(); @@ -244,7 +247,7 @@ const BufferizationState &state) const { auto insertOp = cast(op); FailureOr destMemref = - state.getResultBuffer(rewriter, insertOp->getOpResult(0)); + state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/); if (failed(destMemref)) return failure(); rewriter.create(insertOp.getLoc(), insertOp.scalar(), @@ -412,7 +415,7 @@ // When bufferizing out-of-place, `getResultBuffer` allocates. FailureOr dstMemref = - state.getResultBuffer(rewriter, insertSliceOp->getResult(0)); + state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/); if (failed(dstMemref)) return failure(); @@ -430,7 +433,8 @@ // Copy tensor. If this tensor.insert_slice has a matching // tensor.extract_slice, the copy operation will eventually fold away. - Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source()); + Value srcMemref = + *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); state.createMemCpy(rewriter, loc, srcMemref, subView); replaceOpWithBufferizedValues(rewriter, op, *dstMemref); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -48,7 +48,8 @@ "only tensor types expected"); // TransferReadOp always reads from the bufferized op.source(). - Value buffer = state.lookupBuffer(rewriter, readOp.source()); + Value buffer = + *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/); replaceOpWithNewBufferizedOp( rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(), readOp.permutation_map(), readOp.padding(), readOp.mask(), @@ -99,7 +100,7 @@ // Leave the previous transfer_write to dead code as it still has uses at // this point. FailureOr resultBuffer = - state.getResultBuffer(rewriter, op->getResult(0)); + state.getBuffer(rewriter, op->getOpOperand(1) /*source*/); if (failed(resultBuffer)) return failure(); rewriter.create(