diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1041,10 +1041,27 @@
   return success();
 }
 
+/// If the source of `extractOp` is produced by a tensor.insert_slice that uses
+/// the same offsets, sizes and strides, and the inserted value already has the
+/// type of the extracted slice, returns that inserted value. Returns a null
+/// Value otherwise.
+static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
+  auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>();
+
+  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+  if (insertOp && insertOp.source().getType() == extractOp.getType() &&
+      insertOp.isSameAs(extractOp, isSame))
+    return insertOp.source();
+
+  return {};
+}
+
 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
+  if (Value slice = foldExtractAfterInsertSlice(*this))
+    return slice;
   return OpFoldResult();
 }
 
@@ -1242,12 +1259,52 @@
     return success();
   }
 };
+
+/// If we have two consecutive InsertSliceOp writing to the same slice, the
+/// first one can be removed.
+///
+/// Example:
+///
+/// ```mlir
+///   %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
+///   %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
+/// ```
+struct OverlappingInsertSliceOpRemover final
+    : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>();
+
+    auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+    if (!prevInsertOp ||
+        prevInsertOp.source().getType() != insertOp.source().getType() ||
+        !prevInsertOp.isSameAs(insertOp, isSame))
+      return failure();
+
+    // The earlier insert is completely overwritten, so insert directly into
+    // its destination. The earlier op itself is not erased by this pattern.
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        insertOp, insertOp.source(), prevInsertOp.dest(),
+        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
+        insertOp.getMixedStrides());
+    return success();
+  }
+};
 } // namespace
 
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
-              InsertSliceOpSourceCastInserter>(context);
+              InsertSliceOpSourceCastInserter, OverlappingInsertSliceOpRemover>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -409,7 +409,8 @@
       // CHECK-DAG: tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
       // Does not hoist, 2 slice / insert_slice for %arg8.
       %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      // Extract with a different stride to make sure we cannot fold this extract with the above insert.
+      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
       %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
       // CHECK: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -532,3 +532,30 @@
     : tensor<?x5x?xf32> into tensor<?x?x?xf32>
   return %r : tensor<?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_extract_insert
+// CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
+func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32>
+  // CHECK: return %[[SLICE]]
+  return %1 : tensor<4x?x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_overlapping_insert
+// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
+func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
+  %1 = tensor.insert_slice %slice2 into %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  // CHECK: return %[[INSERT]]
+  return %1 : tensor<?x?x?xf32>
+}