diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -99,6 +99,15 @@ return nullptr; } +namespace { +/// A replacement for a tensor.empty op. +struct EmptyTensorReplacement { + OpOperand *operand; + SmallVector neededValues; + bool isEquivalent; +}; +} // namespace + /// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced /// with the result of `rewriteFunc` if it is anchored on a matching /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def @@ -108,8 +117,9 @@ RewriterBase &rewriter, Operation *op, AnalysisState &state, AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) { OpBuilder::InsertionGuard g(rewriter); + DenseMap emptyTensors; - WalkResult status = op->walk([&](Operation *op) { + op->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { // Skip operands that do not bufferize inplace. if (!state.isInPlace(operand)) @@ -119,47 +129,64 @@ // Is this a matching OpOperand? if (!anchorMatchFunc(operand, neededValues)) continue; - SetVector maybeEmptyTensor = state.findValueInReverseUseDefChain( + SetVector maybeEmptyTensors = state.findValueInReverseUseDefChain( operand.get(), /*condition=*/[&](Value val) { return false; }, /*followEquivalentOnly=*/true); - // Replace only if the reverse use-def chain ends at exactly one - // tensor::EmptyOp. - if (maybeEmptyTensor.size() != 1 || - !maybeEmptyTensor.front().getDefiningOp()) - continue; - Value emptyTensor = maybeEmptyTensor.front(); - - // Replace only if the types match. - // TODO: This could be extended to support IR such as: - // %0 = tensor.empty() : tensor<128xf32> - // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) - // %2 = tensor.expand_shape %1 ... 
- // %3 = tensor.insert_slice %2 into ... - if (emptyTensor.getType() != operand.get().getType()) - continue; + for (Value v : maybeEmptyTensors) { + // Skip values on the reverse use-def chain that are not defined by a + // tensor::EmptyOp. + auto emptyTensor = v.getDefiningOp(); + if (!emptyTensor) + continue; - // Find a suitable insertion point. - Operation *insertionPoint = - findValidInsertionPoint(emptyTensor.getDefiningOp(), neededValues); - if (!insertionPoint) - continue; + // Replace only if the types match. + // TODO: This could be extended to support IR such as: + // %0 = tensor.empty() : tensor<128xf32> + // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) + // %2 = tensor.expand_shape %1 ... + // %3 = tensor.insert_slice %2 into ... + if (v.getType() != operand.get().getType()) + continue; - // Create a replacement for the tensor::EmptyOp. - rewriter.setInsertionPoint(insertionPoint); - Value replacement = rewriteFunc(rewriter, emptyTensor.getLoc(), operand); - if (!replacement) - continue; + // Check if there is already a good replacement for this op. + auto it = emptyTensors.find(emptyTensor.getOperation()); + if (it != emptyTensors.end() && it->second.isEquivalent) + continue; - // Replace the tensor::EmptyOp. - rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement); - } + // Find a suitable insertion point. + Operation *insertionPoint = + findValidInsertionPoint(emptyTensor.getOperation(), neededValues); + if (!insertionPoint) + continue; - // Advance to the next operation. - return WalkResult::advance(); + bool isEquivalent = + state.areEquivalentBufferizedValues(v, operand.get()); + emptyTensors[emptyTensor.getOperation()] = + EmptyTensorReplacement{&operand, neededValues, isEquivalent}; + } + } }); - return failure(status.wasInterrupted()); + for (const auto &it : emptyTensors) { + // Find a suitable insertion point.
+ Operation *insertionPoint = + findValidInsertionPoint(it.first, it.second.neededValues); + if (!insertionPoint) + continue; + + // Create a replacement for the tensor::EmptyOp. + rewriter.setInsertionPoint(insertionPoint); + Value replacement = + rewriteFunc(rewriter, it.first->getLoc(), *it.second.operand); + if (!replacement) + continue; + + // Replace the tensor::EmptyOp. + rewriter.replaceOp(it.first, replacement); + } + + return success(); } /// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be @@ -252,6 +279,7 @@ void EmptyTensorElimination::runOnOperation() { Operation *op = getOperation(); OneShotBufferizationOptions options; + options.allowReturnAllocs = true; OneShotAnalysisState state(op, options); if (failed(analyzeOp(op, state))) { signalPassFailure(); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -169,3 +169,38 @@ return %r1: tensor } + +// ----- + +// CHECK-LABEL: func @eleminate_multiple_ops( +// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref +// CHECK-SAME: %[[sz:[0-9a-zA-Z]*]]: index +func.func @eleminate_multiple_ops(%t: tensor {bufferization.buffer_layout = affine_map<(d0) -> (d0)>}, %sz: index, %c: i1) + -> (tensor) +{ + %cst1 = arith.constant 0.0: f32 + %cst2 = arith.constant 1.0: f32 + + // CHECK: %[[r:.*]] = scf.if %{{.*}} -> (memref + %if = scf.if %c -> tensor { + // CHECK: %[[T_SUBVIEW_1:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] + %a1 = tensor.empty(%sz) : tensor + // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref) -> tensor + // CHECK: scf.yield %[[T_SUBVIEW_1]] + scf.yield %f1 : tensor + } else { + // CHECK: %[[T_SUBVIEW_2:.*]] = 
memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] + %a2 = tensor.empty(%sz) : tensor + // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_2]] : memref) -> tensor + // CHECK: scf.yield %[[T_SUBVIEW_2]] + scf.yield %f2 : tensor + } + + // Self-copy could canonicalize away later. + // CHECK: %[[T_SUBVIEW_3:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] + // CHECK: memref.copy %[[r]], %[[T_SUBVIEW_3]] + %r1 = tensor.insert_slice %if into %t[42][%sz][1]: tensor into tensor + return %r1: tensor +}