diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1221,9 +1221,6 @@
     return false;
   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
     return false;
-  // TODO: Is the following needed?
-  if (!equivalentInfo.isEquivalent(st.result(), sti.source()))
-    return false;
   return true;
 }
 
@@ -1372,7 +1369,7 @@
     b.setInsertionPointToStart(bbArg.getOwner());
     loc = bbArg.getOwner()->getParentOp()->getLoc();
   } else {
-    b.setInsertionPointAfter(shapedValue.getDefiningOp());
+    b.setInsertionPoint(shapedValue.getDefiningOp());
     loc = shapedValue.getDefiningOp()->getLoc();
   }
 
@@ -1795,11 +1792,15 @@
   return success();
 }
 
-/// InitTensor always allocates.
+/// InitTensor always allocates (unless it was eliminated).
 /// TODO: consider hoisting across function boundaries prior to bufferization.
 static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
                                BlockAndValueMapping &bvm,
                                BufferizationAliasInfo &aliasInfo) {
+  // The InitTensorOp may have been eliminated.
+  if (initTensorOp->getUses().empty())
+    return success();
+
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(initTensorOp);
@@ -2779,6 +2780,89 @@
   }
 }
 
+/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp can be
+/// eliminated if it is eventually inserted into another tensor (and some other
+/// conditions are met).
+///
+/// E.g.:
+/// %0 = linalg.init_tensor
+/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
+/// %2 = tensor.insert_slice %1 into %t[10][20][1]
+///
+/// InitTensorOp elimination will try to fill %t inplace instead of filling a
+/// new allocation %0 and inserting it into %t. This is done by replacing the
+/// InitTensorOp with:
+///
+/// %0 = tensor.extract_slice %t[10][20][1]
+///
+/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
+/// those bufferize inplace in the absence of other conflicts.
+///
+/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert
+/// source's reverse use-def chain is eliminated if:
+/// * The InsertSliceOp was decided to bufferize inplace.
+/// * On the reverse use-def chain path from the InsertSliceOp to the
+///   InitTensorOp, all ops were decided to bufferize inplace and the buffer
+///   relation is "equivalent" (TODO: can be relaxed if needed).
+/// * The reverse use-def chain has exactly one end, which is the InitTensorOp.
+///
+/// Note that the newly inserted ExtractSliceOp may have to bufferize
+/// out-of-place due to RaW conflicts.
+static LogicalResult runInitTensorElimination(FuncOp funcOp,
+                                              BufferizationAliasInfo &aliasInfo,
+                                              DominanceInfo &domInfo) {
+  OpBuilder b(funcOp->getContext());
+
+  WalkResult status = funcOp->walk([&](tensor::InsertSliceOp insertOp) {
+    // Only inplace bufferized InsertSliceOps are eligible.
+    if (getInPlace(insertOp->getOpResult(0)) != InPlaceSpec::True)
+      return WalkResult::skip();
+
+    SetVector<Value> maybeInitTensor =
+        findValueInReverseUseDefChain(insertOp.source(), [](Value val) {
+          // Continue traversal until this function returns true.
+          OpResult opResult = val.dyn_cast<OpResult>();
+          if (!opResult)
+            return true;
+          if (getInPlace(opResult) != InPlaceSpec::True)
+            return true;
+          // Only equivalent tensors are supported at the moment. E.g., when
+          // taking a tensor.extract_slice of an init_tensor, we cannot
+          // currently eliminate the init_tensor.
+          SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
+          if (!llvm::all_of(opOperands, [](OpOperand *operand) {
+                return bufferRelation(*operand) == BufferRelation::Equivalent;
+              }))
+            return true;
+          return false;
+        });
+    // Replace only if the InsertSliceOp source originates from exactly one
+    // InitTensorOp.
+    if (maybeInitTensor.size() != 1 ||
+        !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
+      return WalkResult::skip();
+    Value initTensor = maybeInitTensor.front();
+
+    b.setInsertionPoint(initTensor.getDefiningOp());
+    auto extractOp = b.create<tensor::ExtractSliceOp>(
+        initTensor.getLoc(), insertOp.dest(), insertOp.getMixedOffsets(),
+        insertOp.getMixedSizes(), insertOp.getMixedStrides());
+    // Uses of the InitTensorOp are replaced here, but the op is not deleted.
+    // InitTensorOps without uses are ignored by the bufferization.
+    initTensor.replaceAllUsesWith(extractOp.result());
+    aliasInfo.createAliasInfoEntry(extractOp.result());
+
+    // Run analysis on the ExtractSliceOp.
+    if (failed(bufferizableInPlaceAnalysis(extractOp, aliasInfo, domInfo)))
+      return WalkResult::interrupt();
+
+    // Advance to the next operation.
+    return WalkResult::advance();
+  });
+
+  return failure(status.wasInterrupted());
+}
+
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
@@ -2818,6 +2902,13 @@
       return;
     }
 
+    // Try to eliminate InitTensorOps to avoid new allocations during the
+    // bufferization phase.
+    if (failed(runInitTensorElimination(funcOp, aliasInfo, domInfo))) {
+      signalPassFailure();
+      return;
+    }
+
     // Bufferization phase.
     if (!testAnalysisOnly) {
       BlockAndValueMapping tensorToBufferMap;
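To make the rewrite concrete, here is roughly what runInitTensorElimination does to the IR fragment from the doc comment above. This is an illustrative sketch only: the tensor<20xf32>/tensor<?xf32> shapes and the %t, %e names are made up for the example and do not appear in the patch.

// Before: %0 materializes as a fresh buffer that is filled and then copied
// into %t when the insert_slice bufferizes.
%0 = linalg.init_tensor [20] : tensor<20xf32>
%1 = linalg.fill(%cst, %0) : f32, tensor<20xf32> -> tensor<20xf32>
%2 = tensor.insert_slice %1 into %t[10] [20] [1] : tensor<20xf32> into tensor<?xf32>

// After: the uses of %0 are rewired to a slice of the destination, so the
// fill can bufferize in place on %t's buffer. The now-unused init_tensor is
// left behind; the new early return in bufferize(InitTensorOp) skips it.
%0 = linalg.init_tensor [20] : tensor<20xf32>
%e = tensor.extract_slice %t[10] [20] [1] : tensor<?xf32> to tensor<20xf32>
%1 = linalg.fill(%cst, %e) : f32, tensor<20xf32> -> tensor<20xf32>
%2 = tensor.insert_slice %1 into %t[10] [20] [1] : tensor<20xf32> into tensor<?xf32>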
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1013,3 +1013,51 @@
   return %o, %v3 : tensor<?xf32>, vector<5xf32>
 }
 
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// InitTensorOp elimination
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @buffer_forwarding_conflict
+func @buffer_forwarding_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  // Instead of allocating, share buffer with some inplace bufferization?
+  %0 = linalg.init_tensor [%arg1] : tensor<?xf32>
+
+  // CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+  %1 = linalg.fill(%cst, %0) : f32, tensor<?xf32> -> tensor<?xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %2 = tensor.insert_slice %1 into %arg0[0] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+  %3 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+  return %2, %3 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @buffer_forwarding_no_conflict
+func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+  // Instead of allocating, share buffer with some inplace bufferization?
+  %0 = linalg.init_tensor [%arg1] : tensor<?xf32>
+
+  // CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+  %1 = linalg.fill(%cst, %0) : f32, tensor<?xf32> -> tensor<?xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+  %2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+  return %2, %2 : tensor<?xf32>, tensor<?xf32>
+}
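These analysis tests only annotate inplace decisions via __inplace_results_attr__; nothing is bufferized yet. The RUN line sits outside the hunk context; assuming this file's existing harness, it is presumably of the form shown below, where test-analysis-only is the pass option backing the testAnalysisOnly flag seen in runOnOperation above:

// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=test-analysis-only -split-input-file | FileCheck %s

In @buffer_forwarding_conflict, the analysis traverses the function backwards, so the second insert_slice wins the inplace write into %arg0 and the forwarded extract_slice must bufferize out-of-place; in @buffer_forwarding_no_conflict there is only one insert_slice, so every op bufferizes inplace.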
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -755,3 +755,68 @@
   return %r1 : tensor<?xf32>
 }
 
+// -----
+
+// CHECK: func @buffer_forwarding_conflict(
+// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME: %[[sz:[0-9a-zA-Z]*]]: index
+func @buffer_forwarding_conflict(
+    %t: tensor<?xf32> {linalg.buffer_layout = affine_map<(d0) -> (d0)>, linalg.inplaceable = true},
+    %sz: index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %f0 = constant 0.0: f32
+  // Alloc is needed for the **first** insert_slice (due to backward traversal during analysis).
+  // CHECK: %[[DIM:.*]] = memref.dim %[[FUNC_ARG]]
+  // This allocs the whole dim to allow for a full clone of t.
+  // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
+
+  // init_tensor itself does not alloc but forwards to the **second**
+  // insert_slice. InitTensorOp elimination replaces the init_tensor with an
+  // out-of-place extract_slice.
+  // CHECK: %[[EXTRACT_SLICE_ALLOC:.*]] = memref.alloc(%[[sz]])
+  // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+  // TODO: This copy can be avoided because the copied data is never read.
+  // CHECK: linalg.copy(%[[T_SUBVIEW]], %[[EXTRACT_SLICE_ALLOC]])
+  %a = linalg.init_tensor[%sz] : tensor<?xf32>
+
+  // CHECK: linalg.fill({{.*}}, %[[EXTRACT_SLICE_ALLOC]]) : f32, memref<?xf32>
+  %f = linalg.fill(%f0, %a) : f32, tensor<?xf32> -> tensor<?xf32>
+
+  // CHECK: linalg.copy(%[[FUNC_ARG]], %[[ALLOC]]) : memref<?xf32>, memref<?xf32>
+  // CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32>
+  // CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]]) : memref<?xf32>, memref<?xf32>
+  %r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor<?xf32> into tensor<?xf32>
+
+  // CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[T_SUBVIEW]])
+  %r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
+
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK: func @buffer_forwarding_no_conflict(
+// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME: %[[sz:[0-9a-zA-Z]*]]: index
+func @buffer_forwarding_no_conflict(
+    %t: tensor<?xf32> {linalg.buffer_layout = affine_map<(d0) -> (d0)>, linalg.inplaceable = true},
+    %sz: index)
+  -> (tensor<?xf32>)
+{
+  %f0 = constant 0.0: f32
+
+  // init_tensor itself does not alloc but forwards to the insert_slice.
+  // InitTensorOp elimination replaces the init_tensor with an inplace
+  // extract_slice.
+  // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+  %a = linalg.init_tensor[%sz] : tensor<?xf32>
+
+  // CHECK: linalg.fill({{.*}}, %[[T_SUBVIEW]])
+  %f = linalg.fill(%f0, %a) : f32, tensor<?xf32> -> tensor<?xf32>
+
+  // Self-copy canonicalizes away later.
+  // CHECK: linalg.copy(%[[T_SUBVIEW]], %[[T_SUBVIEW]])
+  %r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
+
+  return %r1: tensor<?xf32>
+}
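The end-to-end file above exercises the full bufferization rather than the analysis-only mode. Its RUN line is also outside the hunk context; it presumably looks like the following (the exact pass options may differ, e.g. an allow-return-memref option for the returned tensors):

// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s

Note the linalg.copy(%[[T_SUBVIEW]], %[[T_SUBVIEW]]) in @buffer_forwarding_no_conflict: the insert_slice still bufferizes to a copy, but source and destination are now the same subview, and, as the test comment says, such self-copies are expected to fold away under a later canonicalization.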