diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -21,40 +21,12 @@
 // TODO: Ops in the linalg dialect can directly implement this interface.
 
-/// Helper function for LinalgOp bufferization.
-/// When allocating a new buffer, analyze whether `op` wants to read form that
-/// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult
-allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
-                          SmallVectorImpl<Value> &resultBuffers,
-                          BufferizationState &state) {
-  // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(op);
-
-  // TODO: provide the proper interface to iterate on OpResults and get the
-  // matching OpOperands.
-  for (OpOperand *opOperand : op.getOutputOperands()) {
-    OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
-                            .getAliasingOpResult(*opOperand);
-    assert(opResult && "could not find correspond OpResult");
-    Value resultBuffer = state.getResultBuffer(opResult);
-    if (!resultBuffer)
-      return failure();
-    resultBuffers.push_back(resultBuffer);
-  }
-
-  if (op->getNumResults())
-    state.mapBuffer(op->getResults(), resultBuffers);
-
-  return success();
-}
-
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
                                        BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
 
   // Nothing to do. This op is already bufferized.
   if (op.hasBufferSemantics())
@@ -65,7 +37,6 @@
   if (!op.hasTensorSemantics())
     return op->emitError() << "op does not have tensor semantics";
 
-  Location loc = op.getLoc();
   SmallVector<Value> newInputBuffers;
   newInputBuffers.reserve(op.getNumInputs());
   for (OpOperand *opOperand : op.getInputOperands()) {
@@ -75,10 +46,17 @@
     }
     newInputBuffers.push_back(state.lookupBuffer(opOperand->get()));
   }
+
   SmallVector<Value> newOutputBuffers;
-  // Try to allocate new buffers depending on op's inplace semantics.
-  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state)))
-    return failure();
+  for (OpOperand *opOperand : op.getOutputOperands()) {
+    OpResult opResult = op.getTiedOpResult(opOperand);
+    assert(opResult && "could not find corresponding OpResult");
+    Value resultBuffer = state.getResultBuffer(opResult);
+    if (!resultBuffer)
+      return failure();
+    newOutputBuffers.push_back(resultBuffer);
+    state.mapBuffer(opResult, resultBuffer);
+  }
 
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
@@ -87,53 +65,103 @@
   // Set insertion point now that potential alloc/dealloc are introduced.
   b.setInsertionPoint(op);
   auto bufferizedOp = cast<LinalgOp>(
-      op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands));
-
-  // Replace the results of the old op with the new output buffers.
-  if (op->getNumResults())
-    state.mapBuffer(op->getResults(), newOutputBuffers);
+      op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
 
   // The original op will be DCE'd away later.
   return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state);
 }
 
+/// Linalg OpResults usually bufferize inplace with their tied output
+/// OpOperands. However, if an output OpOperand is not used in the computation,
+/// it is better to bufferize inplace with an actually used input OpOperand;
+/// less memory will be touched that way.
+///
+/// Example:
+/// O(i, j) = A(i, j) + B(j) --> bufferizes inplace to: A(i, j) += B(j)
+///
+/// O(i, j) = A(j, i) + B(j) --> cannot bufferize inplace with A because
+/// indexing maps are not identical
+///
+/// O(i, j) += A(i, j) + B(j) --> Output is used in computation.
+/// This could bufferize inplace with A:
+/// A(i, j) += O(i, j) + B(j)
+/// However, we choose to bufferize inplace with O here, as there is no clear
+/// benefit of choosing A. TODO: We may want to consider both options and make
+/// an informed decision during analysis in the future.
+static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
+  DenseMap<OpOperand *, OpResult> mapping;
+  for (OpResult opResult : op->getOpResults()) {
+    OpOperand *tiedOperand =
+        op.getOutputTensorOperands()[opResult.getResultNumber()];
+    AffineMap outputIndexingMap = op.getTiedIndexingMap(tiedOperand);
+    bool onlyParallelIterators = op.getNumParallelLoops() == op.getNumLoops();
+    bool tiedOperandUsed = op.payloadUsesValueFromOperand(tiedOperand);
+
+    // If the output arg is used in the computation or at least one iterator is
+    // not parallel, try to bufferize inplace with the corresponding output
+    // tensor.
+    if (tiedOperandUsed || !onlyParallelIterators) {
+      mapping[tiedOperand] = opResult;
+      continue;
+    }
+
+    // Otherwise, try to bufferize inplace with one of the inputs.
+    OpOperand *chosenOperand = nullptr;
+    for (OpOperand *opOperand : op.getInputTensorOperands()) {
+      if (opOperand->get().getType() != opResult.getType())
+        continue;
+      if (!op.payloadUsesValueFromOperand(opOperand))
+        continue;
+      if (op.getTiedIndexingMap(opOperand) != outputIndexingMap)
+        continue;
+      // Skip if another OpResult already bufferizes inplace with this
+      // OpOperand.
+      if (mapping.count(opOperand))
+        continue;
+      assert(op.getTiedIndexingMap(opOperand).isProjectedPermutation() &&
+             "expected projected permutation");
+      chosenOperand = opOperand;
+      break;
+    }
+
+    // No suitable input tensor found. Use output tensor.
+    // TODO: This operand could bufferize inplace with OpOperands that have the
+    // correct type, even if they are not used inside the computation.
+    if (!chosenOperand)
+      chosenOperand = tiedOperand;
+
+    mapping[chosenOperand] = opResult;
+  }
+  return mapping;
+}
+
 template <typename OpTy>
 struct LinalgOpInterface
     : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                     OpTy> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
-    return (genericOp.isInputTensor(&opOperand) ||
-            genericOp.isInitTensor(&opOperand)) &&
-           genericOp.payloadUsesValueFromOperand(&opOperand);
+    return genericOp.payloadUsesValueFromOperand(&opOperand);
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    auto genericOp = cast<linalg::LinalgOp>(op);
-    return genericOp.isOutputTensor(&opOperand);
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    return static_cast<bool>(bufferizableOp.getAliasingOpResult(opOperand));
   }
 
   SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
                                                 OpResult opResult) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
-    return {genericOp.getOutputTensorOperands()[opResult.getResultNumber()]};
+    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
+      if (pairs[opOperand] == opResult)
+        return {opOperand};
+    return {};
   }
 
   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
-    if (!opOperand.get().getType().isa<TensorType>())
-      return OpResult();
-    // For now assume inputs are never inplaceable.
-    // TODO: refine this.
-    if (opOperand.getOperandNumber() < genericOp.getNumInputs())
-      return OpResult();
-    int64_t outputOperandIndex =
-        opOperand.getOperandNumber() - genericOp.getNumInputs();
-    int64_t numOutputBuffers = 0;
-    for (unsigned idx = 0; idx < outputOperandIndex; ++idx)
-      if (!genericOp.getOutputOperand(idx)->get().getType().isa<TensorType>())
-        ++numOutputBuffers;
-    return genericOp->getResult(outputOperandIndex - numOutputBuffers);
+    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
+    return pairs[&opOperand];
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -940,7 +940,7 @@
     %t2: tensor<?xf32> {linalg.inplaceable = true})
   -> (tensor<?xf32>, tensor<?xf32>){
   // CHECK: linalg.generic
-  // CHECK-SAME: {__inplace_results_attr__ = ["true", "false"]
+  // CHECK-SAME: {__inplace_results_attr__ = ["true", "true"]
   %o:2 = linalg.generic #trait ins(%t1 : tensor<?xf32>)
                               outs (%t2, %t2 : tensor<?xf32>, tensor<?xf32>) {
       ^bb(%0: f32, %1: f32, %2 : f32) :
@@ -948,12 +948,45 @@
     } -> (tensor<?xf32>, tensor<?xf32>)
 
   // CHECK: return
-  // CHECK-SAME: {__equivalent_func_args__ = [1, -1]}
+  // CHECK-SAME: {__equivalent_func_args__ = [0, 1]}
   return %o#0, %o#1 : tensor<?xf32>, tensor<?xf32>
 }
 
 // -----
 
+#accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>
+]
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+// CHECK-LABEL: func @linalg_op_same_out_tensors_2
+func @linalg_op_same_out_tensors_2(
+    %t1: tensor<?xf32> {linalg.inplaceable = true},
+    %t2: tensor<?xf32> {linalg.inplaceable = true})
+  -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>){
+
+  // CHECK: linalg.generic
+  // CHECK-SAME: {__inplace_results_attr__ = ["true", "true", "false"]
+  %o:3 = linalg.generic #trait
+          ins(%t1 : tensor<?xf32>)
+          outs (%t2, %t2, %t2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+      ^bb(%0: f32, %1: f32, %2 : f32, %3 : f32) :
+        linalg.yield %0, %0, %0 : f32, f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
+
+  // CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, 1, -1]}
+  return %o#0, %o#1, %o#2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @double_insert_slice_into_alias
 func @double_insert_slice_into_alias(
     %v1: vector<32x90xf32>,
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1036,3 +1036,75 @@
 // CHECK-SAME:  outs(%[[ARG2]] :
 // CHECK:       %[[YIELD:.+]] = memref.load %[[ARG0]]
 // CHECK:       linalg.yield %[[YIELD]]
+
+// -----
+
+// CHECK-LABEL: func @linalg_op_bufferizes_inplace_with_input
+//  CHECK-SAME:   %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
+func @linalg_op_bufferizes_inplace_with_input(
+    %t1: tensor<?x?xf32> {linalg.inplaceable = true},
+    %t2: tensor<?xf32>, %t3: tensor<?x?xf32>,
+    %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
+  // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t1]] : {{.*}})
+  %r = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+    ins(%t1, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%t3 : tensor<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+        %add = arith.addf %arg0, %arg1 : f32
+        linalg.yield %add : f32
+    } -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @linalg_op_bufferizes_out_of_place_with_input
+//  CHECK-SAME:   %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
+func @linalg_op_bufferizes_out_of_place_with_input(
+    %t1: tensor<?x?xf32>, %t2: tensor<?xf32>, %t3: tensor<?x?xf32>,
+    %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+  // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[alloc]] : {{.*}})
+  %r = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+    ins(%t1, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%t3 : tensor<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+        %add = arith.addf %arg0, %arg1 : f32
+        linalg.yield %add : f32
+    } -> tensor<?x?xf32>
+  // CHECK: return %[[alloc]]
+  return %r : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @linalg_op_output_cannot_alias_with_input
+//  CHECK-SAME:   %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
+func @linalg_op_output_cannot_alias_with_input(
+    %t1: tensor<?x?xf32> {linalg.inplaceable = true},
+    %t2: tensor<?xf32>, %t3: tensor<?x?xf32> {linalg.inplaceable = true},
+    %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
+  // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t3]] : {{.*}})
+  %r = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+                       affine_map<(d0, d1) -> (d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+    ins(%t1, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%t3 : tensor<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+        %add = arith.addf %arg0, %arg1 : f32
+        linalg.yield %add : f32
+    } -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
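
Reviewer note: the sketch below is not part of the patch; the function and
value names are made up for illustration. It shows the pattern that the new
computeAliasingPairs heuristic targets: the output %C is never read in the
payload and all iterators are parallel, so the analysis pairs the op's result
with the used input %A (matching type and indexing map) instead of with %C.

// %C is write-only in the payload (%c is unused), so the result may
// bufferize inplace with %A's buffer rather than %C's.
#id2d  = affine_map<(d0, d1) -> (d0, d1)>
#bcast = affine_map<(d0, d1) -> (d1)>

func @pairing_example(%A: tensor<?x?xf32> {linalg.inplaceable = true},
                      %B: tensor<?xf32>,
                      %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#id2d, #bcast, #id2d],
                       iterator_types = ["parallel", "parallel"]}
      ins(%A, %B : tensor<?x?xf32>, tensor<?xf32>)
      outs(%C : tensor<?x?xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):  // %c unused: output is write-only.
      %sum = arith.addf %a, %b : f32
      linalg.yield %sum : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

After bufferization the generic reads from and writes to %A's buffer, i.e.
ins(%A, %B) ... outs(%A), with no allocation or copy for %C, mirroring the
@linalg_op_bufferizes_inplace_with_input test above.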