Index: mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
===================================================================
--- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1582,32 +1582,14 @@
   return success();
 }
 
-/// Same logic for folding InsertSliceOp and ParallelInsertSliceOp, the return
-/// type varies though so we wrap it in a FailureOr.
-///
-/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
-template <typename InsertOpTy>
-FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
-  if (insertOp.getSourceType().hasStaticShape() &&
-      insertOp.getDestType().hasStaticShape() &&
-      insertOp.getSourceType() == insertOp.getDestType() &&
-      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
-          insertOp, insertOp.getDestType())))
-    return static_cast<OpFoldResult>(insertOp.getSource());
-  if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
-    // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
-    // return OpFoldResult().
-    if (std::is_same<InsertOpTy, InsertSliceOp>::value)
-      return static_cast<OpFoldResult>(insertOp->getResult(0));
-    else
-      return OpFoldResult();
-  }
-  return failure();
-}
-
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
-  auto maybeOpFoldResult = foldInsertOp(*this, operands);
-  return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
+OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
+      getSourceType() == getType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return this->getSource();
+  if (succeeded(foldInsertAfterInsertSlice(*this)))
+    return getResult();
+  return OpFoldResult();
 }
 
 LogicalResult InsertSliceOp::reifyResultShapes(
@@ -2368,7 +2350,7 @@
 LogicalResult
 ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<OpFoldResult> &results) {
-  return foldInsertOp(*this, operands);
+  return foldInsertAfterInsertSlice(*this);
 }
 
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
Index: mlir/test/Dialect/Tensor/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Tensor/canonicalize.mlir
+++ mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1466,3 +1466,24 @@
   }
   return %2 : tensor
 }
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
+//  CHECK-SAME:   %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
+//  CHECK-SAME:   %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
+func.func @dont_fold_parallel_insert_slice(
+    %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:      scf.foreach_thread () in () -> (tensor<1x5xf32>) {
+  // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
+  // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
+  %2 = scf.foreach_thread () in () -> (tensor<1x5xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %arg0 into %arg1[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
+    }
+  }
+  return %2 : tensor<1x5xf32>
+}