diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -390,8 +390,12 @@
   /// In the above example, Values with a star satisfy the condition. When
   /// starting the traversal from Value 1, the resulting SetVector is:
   /// { 2, 7, 8, 5 }
-  SetVector<Value> findValueInReverseUseDefChain(
-      Value value, llvm::function_ref<bool(Value)> condition) const;
+  ///
+  /// If `followEquivalentOnly` is set, only equivalent OpOperands are selected.
+  SetVector<Value>
+  findValueInReverseUseDefChain(Value value,
+                                llvm::function_ref<bool(Value)> condition,
+                                bool followEquivalentOnly = false) const;
 
   /// Find the Values of the last preceding write of a given Value.
   ///
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -398,7 +398,8 @@
 // evaluates to true. OpOperands of such matching Values are not traversed any
 // further.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
-    Value value, llvm::function_ref<bool(Value)> condition) const {
+    Value value, llvm::function_ref<bool(Value)> condition,
+    bool followEquivalentOnly) const {
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
 
@@ -410,8 +411,19 @@
     }
 
     OpResult opResult = value.cast<OpResult>();
+    BufferizableOpInterface bufferizableOp =
+        options.dynCastBufferizableOp(opResult.getDefiningOp());
     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
-    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
+
+    // Stop iterating in either one of these cases:
+    // * The current op is not bufferizable or excluded in the filter.
+    // * There are no OpOperands to follow.
+    // * There is an OpOperand, but it is not an equivalent tensor (only if
+    //   `followEquivalentOnly` is set).
+    if (!bufferizableOp || opOperands.empty() ||
+        (followEquivalentOnly &&
+         bufferizableOp.bufferRelation(opResult, *this) !=
+             BufferRelation::Equivalent)) {
       result.insert(value);
       continue;
     }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -611,9 +611,9 @@
 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
 /// equivalent operand / result and same offset/sizes/strides specification).
 template <typename OpTy>
-static bool areEquivalentExtractSliceOps(const AnalysisState &state,
-                                         ExtractSliceOp extractSliceOp,
-                                         OpTy insertSliceOp) {
+static bool areEquivalentSlices(const AnalysisState &state,
+                                ExtractSliceOp extractSliceOp,
+                                OpTy insertSliceOp) {
   if (!extractSliceOp || !insertSliceOp)
     return false;
   if (extractSliceOp != insertSliceOp &&
@@ -626,20 +626,31 @@
   return true;
 }
 
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
+/// Return true if `value` is originating from the InsertSliceOp's destination
+/// or an ExtractSliceOp that matches the given InsertSliceOp.
 template <typename OpTy>
-static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
-                                      OpTy insertSliceOp) {
-  auto condition = [&](Value val) {
+static bool matchesInsertDestination(const AnalysisState &state, Value value,
+                                     OpTy insertSliceOp) {
+  // Look for matching slices.
+  auto matchesSlice = [&](Value val) {
     if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
+      if (areEquivalentSlices(state, extractSliceOp, insertSliceOp))
         return true;
     return false;
   };
+  if (llvm::all_of(state.findValueInReverseUseDefChain(value, matchesSlice),
+                   matchesSlice))
+    return true;
 
-  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
-                      condition);
+  // Look for equivalent values.
+  auto isEquivalent = [&](Value val) {
+    return state.areEquivalentBufferizedValues(val, insertSliceOp.getDest());
+  };
+  if (llvm::all_of(state.findValueInReverseUseDefChain(
+                       value, isEquivalent, /*followEquivalentOnly=*/true),
+                   isEquivalent))
+    return true;
+  return false;
 }
 
 template <typename OpTy>
@@ -661,8 +672,8 @@
 
       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
-                                    insertSliceOp))
+          matchesInsertDestination(state, uConflictingWrite->get(),
+                                   insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
@@ -679,7 +690,7 @@
 
       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+          matchesInsertDestination(state, uRead->get(), insertSliceOp))
         // Case 2: The read of the source tensor and the write to the dest
         // tensor via an InsertSliceOp is not a conflict if the read is
         // reading exactly that part of an equivalent tensor that the
@@ -712,8 +723,8 @@
       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
           state.areEquivalentBufferizedValues(uRead->get(),
                                               insertSliceOp.getSource()) &&
-          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
-                                    insertSliceOp))
+          matchesInsertDestination(state, insertSliceOp.getSource(),
+                                   insertSliceOp))
         return true;
 
     return false;
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -126,15 +126,12 @@
 
 // -----
 
-// CHECK-LABEL: func @tensor_cast_not_in_place(
-//  CHECK-SAME:     %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
-//       CHECK:   %[[alloc:.*]] = memref.alloc
-//       CHECK:   memref.copy %[[A]], %[[alloc]]
+// CHECK-LABEL: func @tensor_cast_in_place(
+//  CHECK-SAME:     %[[A:.*]]: memref<?xf32{{.*}}>
 //       CHECK:   %[[subview:.*]] = memref.subview %[[A]][{{.*}}] [4] [1] : {{.*}} to memref<4xf32
-//       CHECK:   memref.copy %[[alloc]], %[[subview]]
-func.func @tensor_cast_not_in_place(
-    %A : tensor<?xf32> {bufferization.writable = true},
-    %B : tensor<?xf32> {bufferization.writable = false}, %idx: index)
+//       CHECK:   memref.copy %[[A]], %[[subview]]
+func.func @tensor_cast_in_place(
+    %A : tensor<?xf32> {bufferization.writable = true}, %idx: index)
   -> (tensor<?xf32>)
 {
   %r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
@@ -243,3 +240,16 @@
   %r = tensor.extract %0[%idx, %idx] : tensor<?x?xindex>
   return %r : index
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_equivalent_tensor
+func.func @insert_equivalent_tensor(%t: tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK-NOT: memref.alloc
+  %cst = arith.constant 4.200000e+01 : f32
+  // CHECK: linalg.fill
+  %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
+  // CHECK-NOT: memref.copy
+  %1 = tensor.insert_slice %0 into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
+  return %1 : tensor<10xf32>
+}