diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1718,7 +1718,11 @@
 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
   auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
 
-  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+  // Operands match only if they fold to the same integer constant or are the
+  // same SSA value; two unrelated dynamic values must not compare equal.
+  auto isSame = [](OpFoldResult a, OpFoldResult b) {
+    return isEqualConstantIntOrValue(a, b);
+  };
   if (!prevInsertOp ||
       prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
       !prevInsertOp.isSameAs(insertOp, isSame))
@@ -1728,6 +1732,29 @@
   return success();
 }
 
+/// Folds round-trip extract/insert slice op pairs.
+/// Example:
+/// ```mlir
+/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+/// ```
+/// can be folded into %val. Returns the folded value, or a null Value when
+/// the pattern does not apply.
+static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
+  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
+
+  // Offsets, sizes and strides match only if they fold to the same integer
+  // constant or are the same SSA value; unrelated dynamic values do not.
+  auto isSame = [](OpFoldResult a, OpFoldResult b) {
+    return isEqualConstantIntOrValue(a, b);
+  };
+  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
+      !extractOp.isSameAs(insertOp, isSame))
+    return nullptr;
+
+  return extractOp.getSource();
+}
+
 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
       getSourceType() == getType() &&
@@ -1735,6 +1762,8 @@
     return this->getSource();
   if (succeeded(foldInsertAfterInsertSlice(*this)))
     return getResult();
+  if (auto result = foldInsertAfterExtractSlice(*this))
+    return result;
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -639,7 +639,7 @@
   %c1 = arith.constant 1: index
   %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
   // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
-  %1 = tensor.insert_slice %slice2 into %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  %1 = tensor.insert_slice %slice2 into %0[0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
   // CHECK: return %[[INSERT]]
   return %1 : tensor<?x?x?xf32>
 }
@@ -1443,7 +1443,7 @@
 // -----
 
 // CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
-// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>, 
+// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
 // CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
 // CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index
 func.func @canonicalize_parallel_insert_slice_indices(
@@ -1470,7 +1470,7 @@
 // -----
 
 // CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
-// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>, 
+// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
 // CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
 func.func @dont_fold_parallel_insert_slice(
   %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
@@ -1487,3 +1487,63 @@
   }
   return %2 : tensor<1x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice
+//  CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
+func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  %1 = tensor.insert_slice %0 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  // CHECK: return %[[INPUT]]
+  return %1: tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_mismatched_source_dst
+func.func @dont_fold_mismatched_source_dst(%input0: tensor<1x2x2x4xf32>, %input1: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.extract_slice
+  %0 = tensor.extract_slice %input0[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  // CHECK: tensor.insert_slice
+  %1 = tensor.insert_slice %0 into %input1[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %1: tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_mismatched_parameters
+func.func @dont_fold_mismatched_parameters(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.extract_slice
+  %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  // CHECK: tensor.insert_slice
+  %1 = tensor.insert_slice %0 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %1: tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_mismatched_dynamic_parameters
+func.func @dont_fold_mismatched_dynamic_parameters(%input: tensor<1x2x2x4xf32>, %i: index, %j: index) -> tensor<1x2x2x4xf32> {
+  // CHECK: tensor.extract_slice
+  %0 = tensor.extract_slice %input[0, %i, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  // CHECK: tensor.insert_slice
+  %1 = tensor.insert_slice %0 into %input[0, %j, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %1: tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_insert_slice_after_insert_slice_dynamic_offsets
+func.func @dont_fold_insert_slice_after_insert_slice_dynamic_offsets(
+    %input: tensor<?x?xf32>, %slice1: tensor<4x8xf32>, %slice2: tensor<4x8xf32>,
+    %i: index, %j: index) -> tensor<?x?xf32> {
+  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice
+  %0 = tensor.insert_slice %slice1 into %input[%i, 0] [4, 8] [1, 1] : tensor<4x8xf32> into tensor<?x?xf32>
+  // CHECK: tensor.insert_slice %{{.+}} into %[[INSERT0]]
+  %1 = tensor.insert_slice %slice2 into %0[%j, 0] [4, 8] [1, 1] : tensor<4x8xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}