diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1041,10 +1041,39 @@
   return success();
 }
 
+/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice,
+/// we can return the InsertSliceOp's source directly.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = tensor.insert_slice %slice into %input[0, 0] [64, 64] [1, 1]
+/// %1 = tensor.extract_slice %0[0, 0] [64, 64] [1, 1]
+/// ```
+///
+/// folds %1 into %slice. The fold requires %slice's type to equal the
+/// extract's result type; offsets, sizes, and strides are compared by
+/// OpFoldResult identity, so a constant SSA value such as `%c0` does not
+/// match the attribute `0`.
+// TODO: This only checks the immediate producer; extend to go up the
+// insert/extract chain if the slices are disjoint.
+static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
+  auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>();
+
+  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+  if (insertOp && insertOp.source().getType() == extractOp.getType() &&
+      insertOp.isSameAs(extractOp, isSame))
+    return insertOp.source();
+
+  return {};
+}
+
 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
+  if (Value slice = foldExtractAfterInsertSlice(*this))
+    return slice;
   return OpFoldResult();
 }
 
@@ -1085,11 +1114,48 @@
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+/// If we have two consecutive InsertSliceOps writing to the same slice, the
+/// second one fully overwrites what the first one wrote, so we can make the
+/// second InsertSliceOp insert into the first one's destination instead.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
+/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
+/// ```
+///
+/// Both ops must have the same source type and identical offsets, sizes, and
+/// strides (compared by OpFoldResult identity). On success, `insertOp` has
+/// been updated in place.
+static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
+  auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>();
+
+  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+  if (!prevInsertOp ||
+      prevInsertOp.getSourceType() != insertOp.getSourceType() ||
+      !prevInsertOp.isSameAs(insertOp, isSame))
+    return failure();
+
+  insertOp.destMutable().assign(prevInsertOp.dest());
+  return success();
+}
+
 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
       getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
+  // On success `*this` was updated in place; returning its own result
+  // signals an in-place fold.
+  if (succeeded(foldInsertAfterInsertSlice(*this)))
+    return getResult();
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -409,7 +409,8 @@
   // CHECK-DAG:   tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
       // Does not hoist, 2 slice / insert_slice for %arg8.
       %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      // Extract with a different stride to make sure we cannot fold this extract with the above insert.
+      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
       %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
       // CHECK:     scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -532,3 +532,33 @@
     : tensor<?x5x?xf32> into tensor<?x?x?xf32>
   return %r : tensor<?x?x?xf32>
 }
+
+// -----
+
+// Extracting exactly the slice that was just inserted folds to the inserted value.
+// CHECK-LABEL: func @fold_extract_insert
+//  CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
+func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32>
+  // CHECK: return %[[SLICE]]
+  return %1 : tensor<4x?x8xf32>
+}
+
+// -----
+
+// Two inserts into the same slice fold into one insert into the original destination; the dead first insert is erased.
+// CHECK-LABEL: func @fold_overlapping_insert
+//  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
+func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  // CHECK-NOT: tensor.insert_slice
+  %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
+  %1 = tensor.insert_slice %slice2 into %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  // CHECK: return %[[INSERT]]
+  return %1 : tensor<?x?x?xf32>
+}