diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -630,15 +630,6 @@
   if (llvm::all_of(state.findValueInReverseUseDefChain(value, matchesSlice),
                    matchesSlice))
     return true;
-
-  // Look for equivalent values.
-  auto isEquivalent = [&](Value val) {
-    return state.areEquivalentBufferizedValues(val, insertSliceOp.getDest());
-  };
-  if (llvm::all_of(state.findValueInReverseUseDefChain(
-                       value, isEquivalent, /*followEquivalentOnly=*/true),
-                   isEquivalent))
-    return true;
   return false;
 }
 
@@ -727,6 +718,36 @@
 struct InsertSliceOpInterface
     : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                      tensor::InsertSliceOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    RankedTensorType destType = insertSliceOp.getDestType();
+
+    // The source is always read.
+    if (&opOperand == &op->getOpOperand(0) /*src*/)
+      return true;
+
+    // For the destination, it depends...
+    assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest");
+
+    // Dest is not read if it is entirely overwritten. E.g.:
+    // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
+    bool allOffsetsZero =
+        llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
+          return isConstantIntValue(ofr, 0);
+        });
+    bool sizesMatchDestSizes = llvm::all_of(
+        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](auto &it) {
+          return getConstantIntValue(it.value()) ==
+                 destType.getDimSize(it.index());
+        });
+    bool allStridesOne =
+        llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+          return isConstantIntValue(ofr, 1);
+        });
+    return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+  }
+
   bool isNotConflicting(Operation *op, OpOperand *uRead,
                         OpOperand *uConflictingWrite,
                         const AnalysisState &state) const {
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -126,12 +126,18 @@
 
 // -----
 
-// CHECK-LABEL: func @tensor_cast_in_place(
-// CHECK-SAME: %[[A:.*]]: memref<?xf32{{.*}}>
+// This test case could bufferize in-place with a better analysis. However, it
+// is simpler to let the canonicalizer fold away the tensor.insert_slice.
+
+// CHECK-LABEL: func @tensor_cast_not_in_place(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: memref.copy %[[A]], %[[alloc]]
 // CHECK: %[[subview:.*]] = memref.subview %[[A]][{{.*}}] [4] [1] : {{.*}} to memref<4xf32
-// CHECK: memref.copy %[[A]], %[[subview]]
-func.func @tensor_cast_in_place(
-    %A : tensor<?xf32> {bufferization.writable = true}, %idx: index)
+// CHECK: memref.copy %[[alloc]], %[[subview]]
+func.func @tensor_cast_not_in_place(
+    %A : tensor<?xf32> {bufferization.writable = true},
+    %B : tensor<?xf32> {bufferization.writable = false}, %idx: index)
   -> (tensor<?xf32>)
 {
   %r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
@@ -241,13 +247,16 @@
 
 // -----
 
+// This test case could bufferize in-place with a better analysis. However, it
+// is simpler to let the canonicalizer fold away the tensor.insert_slice.
+
 // CHECK-LABEL: func @insert_equivalent_tensor
 func.func @insert_equivalent_tensor(%t: tensor<10xf32>) -> tensor<10xf32> {
-  // CHECK-NOT: memref.alloc
+  // CHECK: memref.alloc
   %cst = arith.constant 4.200000e+01 : f32
   // CHECK: linalg.fill
   %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
-  // CHECK-NOT: memref.copy
+  // CHECK: memref.copy
   %1 = tensor.insert_slice %0 into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
   return %1 : tensor<10xf32>
 }
@@ -279,3 +288,45 @@
 // CHECK-DAG: memref.dealloc %[[padded_alloc]]
   return %2 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_regression(
+// CHECK-SAME: %[[t:.*]]: memref<10xf32,{{.*}}>, %[[b:.*]]: memref<5xf32
+func.func @insert_slice_regression(%t: tensor<10xf32>, %b: tensor<5xf32>) -> tensor<10xf32> {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10xf32>
+  // CHECK: linalg.fill {{.*}} outs(%[[alloc]] : memref<10xf32>)
+  %1 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
+
+  // Read %1 so that it does not DCE away.
+  %vec = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<10xf32>
+  vector.print %vec : vector<10xf32>
+
+  // Write back a different value (not %1).
+  // CHECK: %[[subview:.*]] = memref.subview %[[t]][0] [5] [1]
+  // CHECK: memref.copy %[[b]], %[[subview]]
+  %2 = tensor.insert_slice %b into %t[0][5][1] : tensor<5xf32> into tensor<10xf32>
+  return %2 : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_full_overwrite(
+// CHECK-SAME: %[[t:.*]]: memref<10xf32,{{.*}}>, %[[b:.*]]: memref<10xf32,{{.*}}>
+func.func @insert_slice_full_overwrite(%t: tensor<10xf32>, %b: tensor<10xf32>) -> tensor<10xf32> {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: linalg.fill {{.*}} outs(%[[t]] : memref<10xf32,{{.*}}>)
+  %1 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
+
+  // Read %1 so that it does not DCE away.
+  %vec = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<10xf32>
+  vector.print %vec : vector<10xf32>
+
+  // Write back a different value (not %1).
+  // CHECK: memref.copy %[[b]], %[[t]]
+  %2 = tensor.insert_slice %b into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
+  return %2 : tensor<10xf32>
+}
\ No newline at end of file
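
Note on the "canonicalizer fold" mentioned in the two new test comments: it
refers to the existing tensor.insert_slice identity fold, which replaces the
op with its source when source and destination have the same static type and
the offsets/sizes/strides are 0/full-size/1. A minimal hand-written MLIR
sketch (illustration only, not part of the patch):

  // %0 and %t both have type tensor<10xf32>, and the slice spans all of %t.
  %1 = tensor.insert_slice %0 into %t[0][10][1]
      : tensor<10xf32> into tensor<10xf32>
  // The fold replaces all uses of %1 with %0, so after canonicalization
  // One-Shot Bufferize never sees the insert_slice at all.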