diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -104,63 +104,72 @@ /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def /// chain, starting from the OpOperand and always following the aliasing /// OpOperand, that eventually ends at a single tensor::EmptyOp. +/// +/// E.g.: +/// %0 = tensor.empty() : tensor<10xf32> +/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>) +/// %2 = tensor.insert_slice %1 into %t ... +/// +/// In the above example, the anchor is the source operand of the insert_slice +/// op. When tracing back the reverse use-def chain, we end up at a +/// tensor.empty op. LogicalResult mlir::bufferization::eliminateEmptyTensors( RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) { OpBuilder::InsertionGuard g(rewriter); - WalkResult status = op->walk([&](Operation *op) { + op->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { // Skip operands that do not bufferize inplace. if (!state.isInPlace(operand)) continue; // All values that are needed to create the replacement op. SmallVector<Value> neededValues; - // Is this a matching OpOperand? + // Is this an anchor? if (!anchorMatchFunc(operand, neededValues)) continue; - SetVector<Value> maybeEmptyTensor = state.findValueInReverseUseDefChain( - operand.get(), /*condition=*/[&](Value val) { return false; }, - /*followEquivalentOnly=*/true); - // Replace only if the reverse use-def chain ends at exactly one - // tensor::EmptyOp. - if (maybeEmptyTensor.size() != 1 || - !maybeEmptyTensor.front().getDefiningOp<tensor::EmptyOp>()) - continue; - Value emptyTensor = maybeEmptyTensor.front(); + // Find tensor.empty ops on the reverse SSA use-def chain.
Only follow + // equivalent tensors. I.e., stop when there are ops such as extract_slice + // on the path. + SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( + operand.get(), /*condition=*/ + [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, + /*followEquivalentOnly=*/true, /*alwaysIncludeLeaves=*/false); - // Replace only if the types match. - // TODO: This could be extended to support IR such as: - // %0 = tensor.empty() : tensor<128xf32> - // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) - // %2 = tensor.expand_shape %1 ... - // %3 = tensor.insert_slice %2 into ... - if (emptyTensor.getType() != operand.get().getType()) - continue; + for (Value v : emptyTensors) { + Operation *emptyTensorOp = v.getDefiningOp(); - // Find a suitable insertion point. - Operation *insertionPoint = - findValidInsertionPoint(emptyTensor.getDefiningOp(), neededValues); - if (!insertionPoint) - continue; + // Replace only if the types match. We do not support slices or casts. + // TODO: This could be extended to support IR such as: + // %0 = tensor.empty() : tensor<128xf32> + // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) + // %2 = tensor.expand_shape %1 ... + // %3 = tensor.insert_slice %2 into ... + if (v.getType() != operand.get().getType()) + continue; - // Create a replacement for the tensor::EmptyOp. - rewriter.setInsertionPoint(insertionPoint); - Value replacement = rewriteFunc(rewriter, emptyTensor.getLoc(), operand); - if (!replacement) - continue; + // Find a suitable insertion point. If no suitable insertion point for + // the replacement can be found, skip this replacement. + Operation *insertionPoint = + findValidInsertionPoint(emptyTensorOp, neededValues); + if (!insertionPoint) + continue; - // Replace the tensor::EmptyOp.
- rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement); - state.resetCache(); - } + rewriter.setInsertionPoint(insertionPoint); + Value replacement = + rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand); + if (!replacement) + continue; - // Advance to the next operation. - return WalkResult::advance(); + // Replace the tensor::EmptyOp. + rewriter.replaceOp(emptyTensorOp, replacement); + state.resetCache(); + } + } }); - return failure(status.wasInterrupted()); + return success(); } /// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be @@ -253,6 +262,7 @@ void EmptyTensorElimination::runOnOperation() { Operation *op = getOperation(); OneShotBufferizationOptions options; + options.allowReturnAllocs = true; OneShotAnalysisState state(op, options); if (failed(analyzeOp(op, state))) { signalPassFailure(); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -169,3 +169,38 @@ return %r1: tensor<?xf32> } + +// ----- + +// CHECK-LABEL: func @eleminate_multiple_ops( +// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32> +// CHECK-SAME: %[[sz:[0-9a-zA-Z]*]]: index +func.func @eleminate_multiple_ops(%t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>}, %sz: index, %c: i1) + -> (tensor<?xf32>) +{ + %cst1 = arith.constant 0.0: f32 + %cst2 = arith.constant 1.0: f32 + + // CHECK: %[[r:.*]] = scf.if %{{.*}} -> (memref + %if = scf.if %c -> tensor<?xf32> { + // CHECK: %[[T_SUBVIEW_1:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] + %a1 = tensor.empty(%sz) : tensor<?xf32> + // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32 + %f1 = linalg.fill ins(%cst1 : f32) outs(%a1 : tensor<?xf32>) -> tensor<?xf32> + // CHECK: scf.yield %[[T_SUBVIEW_1]] + scf.yield %f1 : tensor<?xf32> + }
else { + // CHECK: %[[T_SUBVIEW_2:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] + %a2 = tensor.empty(%sz) : tensor<?xf32> + // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_2]] : memref<?xf32 + %f2 = linalg.fill ins(%cst2 : f32) outs(%a2 : tensor<?xf32>) -> tensor<?xf32> + // CHECK: scf.yield %[[T_SUBVIEW_2]] + scf.yield %f2 : tensor<?xf32> + } + + // Self-copy could canonicalize away later. + // CHECK: %[[T_SUBVIEW_3:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] + // CHECK: memref.copy %[[r]], %[[T_SUBVIEW_3]] + %r1 = tensor.insert_slice %if into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32> + return %r1: tensor<?xf32> +}