diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -21,47 +21,18 @@
 // TODO: Ops in the linalg dialect can directly implement this interface.
 
-/// Helper function for LinalgOp bufferization.
-/// When allocating a new buffer, analyze whether `op` wants to read form that
-/// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult
-allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
-                          SmallVectorImpl<Value> &resultBuffers,
-                          BufferizationState &state) {
-  // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(op);
-
-  // TODO: provide the proper interface to iterate on OpResults and get the
-  // matching OpOperands.
-  for (OpOperand *opOperand : op.getOutputOperands()) {
-    OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
-                            .getAliasingOpResult(*opOperand);
-    assert(opResult && "could not find correspond OpResult");
-    Value resultBuffer = getResultBuffer(b, opResult, state);
-    if (!resultBuffer)
-      return failure();
-    resultBuffers.push_back(resultBuffer);
-  }
-
-  if (op->getNumResults())
-    state.mapBuffer(op->getResults(), resultBuffers);
-
-  return success();
-}
-
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
                                        BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
 
   // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
   // basis.
   if (!op.hasTensorSemantics())
     return op->emitError() << "op does not have tensor semantics";
 
-  Location loc = op.getLoc();
   SmallVector<Value> newInputBuffers;
   newInputBuffers.reserve(op.getNumInputs());
   for (OpOperand *opOperand : op.getInputOperands()) {
@@ -71,10 +42,17 @@
     }
     newInputBuffers.push_back(state.lookupBuffer(opOperand->get()));
   }
+
   SmallVector<Value> newOutputBuffers;
-  // Try to allocate new buffers depending on op's inplace semantics.
-  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state)))
-    return failure();
+  for (OpOperand *opOperand : op.getOutputOperands()) {
+    OpResult opResult = op.getTiedOpResult(opOperand);
+    assert(opResult && "could not find correspond OpResult");
+    Value resultBuffer = getResultBuffer(b, opResult, state);
+    if (!resultBuffer)
+      return failure();
+    newOutputBuffers.push_back(resultBuffer);
+    state.mapBuffer(opResult, resultBuffer);
+  }
 
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
@@ -82,11 +60,7 @@
 
   // Set insertion point now that potential alloc/dealloc are introduced.
   b.setInsertionPoint(op);
-  op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands);
-
-  // Replace the results of the old op with the new output buffers.
-  if (op->getNumResults())
-    state.mapBuffer(op->getResults(), newOutputBuffers);
+  op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands);
 
   // The original op will be DCE'd away later.
@@ -99,37 +73,62 @@
                                                     OpTy> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
-    return (genericOp.isInputTensor(&opOperand) ||
-            genericOp.isInitTensor(&opOperand)) &&
-           genericOp.payloadUsesValueFromOperand(&opOperand);
+    return genericOp.payloadUsesValueFromOperand(&opOperand);
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    auto genericOp = cast<linalg::LinalgOp>(op);
-    return genericOp.isOutputTensor(&opOperand);
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    return static_cast<bool>(bufferizableOp.getAliasingOpResult(opOperand));
   }
 
   SmallVector<OpOperand *>
   getAliasingOpOperand(Operation *op, OpResult opResult) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
-    return {genericOp.getOutputTensorOperands()[opResult.getResultNumber()]};
+    OpOperand *tiedOperand =
+        genericOp.getOutputTensorOperands()[opResult.getResultNumber()];
+    bool onlyParallelIterators =
+        genericOp.getNumParallelLoops() == genericOp.getNumLoops();
+    bool tiedOperandUsed = genericOp.payloadUsesValueFromOperand(tiedOperand);
+
+    // If the output arg is used in the computation or at least one iterator is
+    // not parallel, try to bufferize inplace with the corresponding output
+    // tensor.
+    if (tiedOperandUsed || !onlyParallelIterators)
+      return {tiedOperand};
+
+    // Otherwise, try to bufferize inplace with one of the inputs.
+    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
+      if (opOperand->get().getType() != opResult.getType())
+        continue;
+      if (!genericOp.payloadUsesValueFromOperand(opOperand))
+        continue;
+      return {opOperand};
+    }
+
+    // No suitable input tensor found. Use output tensor.
+    // TODO: This operand could bufferize inplace with OpOperands that have the
+    // correct type, even if they are not used inside the computation.
+    return {tiedOperand};
   }
 
   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
     if (!opOperand.get().getType().isa<TensorType>())
       return OpResult();
-    // For now assume inputs are never inplaceable.
-    // TODO: refine this.
-    if (opOperand.getOperandNumber() < genericOp.getNumInputs())
-      return OpResult();
-    int64_t outputOperandIndex =
-        opOperand.getOperandNumber() - genericOp.getNumInputs();
-    int64_t numOutputBuffers = 0;
-    for (unsigned idx = 0; idx < outputOperandIndex; ++idx)
-      if (!genericOp.getOutputOperand(idx)->get().getType().isa<TensorType>())
-        ++numOutputBuffers;
-    return genericOp->getResult(outputOperandIndex - numOutputBuffers);
+
+    // Check all OpResults to see which one aliases with this OpOperand.
+    for (OpResult opResult : genericOp->getOpResults()) {
+      SmallVector<OpOperand *> aliasingOpOperands =
+          bufferizableOp.getAliasingOpOperand(opResult);
+      assert(aliasingOpOperands.size() <= 1 &&
+             "expected at most 1 aliasing OpOperand");
+      if (!aliasingOpOperands.empty() &&
+          (aliasingOpOperands.front() == &opOperand))
+        return opResult;
+    }
+
+    return OpResult();
   }
 
   BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -916,3 +916,51 @@
   }
   return %r : tensor
 }
+
+// -----
+
+// CHECK-LABEL: func @linalg_op_bufferizes_inplace_with_input
+// CHECK-SAME: %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
+func @linalg_op_bufferizes_inplace_with_input(
+    %t1: tensor<?x?xf32> {linalg.inplaceable = true},
+    %t2: tensor<?xf32>, %t3: tensor<?x?xf32>,
+    %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
+  // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t1]] : {{.*}})
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1)-> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%t1, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%t3 : tensor<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+        %add = arith.addf %arg0, %arg1 : f32
+        linalg.yield %add : f32
+    } -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @linalg_op_bufferizes_out_of_place_with_input
+// CHECK-SAME: %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
+func @linalg_op_bufferizes_out_of_place_with_input(
+    %t1: tensor<?x?xf32>, %t2: tensor<?xf32>, %t3: tensor<?x?xf32>,
+    %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+  // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[alloc]] : {{.*}})
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1)-> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%t1, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%t3 : tensor<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+        %add = arith.addf %arg0, %arg1 : f32
+        linalg.yield %add : f32
+    } -> tensor<?x?xf32>
+  // CHECK: return %[[alloc]]
+  return %r : tensor<?x?xf32>
+}
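
For reference, a minimal sketch of the other branch of the new getAliasingOpOperand heuristic: when the payload reads the output block argument, the result keeps aliasing the tied output operand even if an input of matching type is available. The function name, types, and comments below are illustrative assumptions; this snippet is not part of the patch or its test additions.

// Illustrative only (not in the patch): %arg1 is the output block argument and
// is read by the payload, so the result aliases the output operand %t2 rather
// than the same-typed input %t1.
func @linalg_op_output_arg_used_in_payload(
    %t1: tensor<?x?xf32> {linalg.inplaceable = true},
    %t2: tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x?xf32> {
  %r = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%t1 : tensor<?x?xf32>)
    outs(%t2 : tensor<?x?xf32>) {
      ^bb0(%arg0 : f32, %arg1 : f32) :
        // Reading %arg1 makes the tied output the aliasing operand.
        %add = arith.addf %arg0, %arg1 : f32
        linalg.yield %add : f32
    } -> tensor<?x?xf32>
  return %r : tensor<?x?xf32>
}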