diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1207,7 +1207,6 @@
   ];
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1552,7 +1552,6 @@
 
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
 /// can mutate the second InsertSliceOp's destination to the first one's.
-/// This works similarly when the second op is a ParallelInsertSliceOp.
 ///
 /// Example:
 ///
@@ -1568,9 +1567,8 @@
 /// ```
 ///
 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
-template <typename InsertOpTy>
-static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
-  auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
+static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
+  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
 
   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
   if (!prevInsertOp ||
@@ -1582,32 +1580,14 @@
   return success();
 }
 
-/// Same logic for folding InsertSliceOp and ParallelInsertSliceOp, the return
-/// type varies though so we wrap it in a FailureOr.
-///
-/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
-template <typename InsertOpTy>
-FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
-  if (insertOp.getSourceType().hasStaticShape() &&
-      insertOp.getDestType().hasStaticShape() &&
-      insertOp.getSourceType() == insertOp.getDestType() &&
-      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
-          insertOp, insertOp.getDestType())))
-    return static_cast<OpFoldResult>(insertOp.getSource());
-  if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
-    // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
-    // return OpFoldResult().
-    if (std::is_same<InsertOpTy, InsertSliceOp>::value)
-      return static_cast<OpFoldResult>(insertOp->getResult(0));
-    else
-      return OpFoldResult();
-  }
-  return failure();
-}
-
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
-  auto maybeOpFoldResult = foldInsertOp(*this, operands);
-  return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
+OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
+      getSourceType() == getType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return this->getSource();
+  if (succeeded(foldInsertAfterInsertSlice(*this)))
+    return getResult();
+  return OpFoldResult();
 }
 
 LogicalResult InsertSliceOp::reifyResultShapes(
@@ -2319,58 +2299,6 @@
   return produceSliceErrorMsg(result, *this, expectedType);
 }
 
-namespace {
-/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
-class ParallelInsertSliceOpConstantArgumentFolder final
-    : public OpRewritePattern<ParallelInsertSliceOp> {
-public:
-  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // No constant operand, just return.
-    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
-          return matchPattern(operand, matchConstantIndex());
-        }))
-      return failure();
-
-    // At least one of offsets/sizes/strides is a new constant.
-    // Form the new list of operands and constant attributes from the
-    // existing.
-    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
-    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
-    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
-    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
-    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
-    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
-
-    // Create the new op in canonical form.
-    auto sourceType =
-        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
-            insertSliceOp.getSourceType().getRank(),
-            insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
-            mixedStrides);
-    Value toInsert = insertSliceOp.getSource();
-    if (sourceType != insertSliceOp.getSourceType()) {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
-      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
-                                                 sourceType, toInsert);
-    }
-    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
-        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
-        mixedSizes, mixedStrides);
-    return success();
-  }
-};
-} // namespace
-
-LogicalResult
-ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
-                            SmallVectorImpl<OpFoldResult> &results) {
-  return foldInsertOp(*this, operands);
-}
-
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1466,3 +1466,24 @@
   }
   return %2 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
+// CHECK-SAME:    %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
+// CHECK-SAME:    %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
+func.func @dont_fold_parallel_insert_slice(
+    %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: scf.foreach_thread () in () -> (tensor<1x5xf32>) {
+  // CHECK-NEXT: scf.foreach_thread.perform_concurrently {
+  // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
+  %2 = scf.foreach_thread () in () -> (tensor<1x5xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %arg0 into %arg1[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
+    }
+  }
+  return %2 : tensor<1x5xf32>
+}
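
Illustrative note (not part of the patch; the function and value names below are hypothetical): with the folder removed from tensor.parallel_insert_slice, only tensor.insert_slice keeps its two folds, the identity insert of a statically shaped source into an identically typed destination, and the insert-after-insert case described in the comment above foldInsertAfterInsertSlice. A minimal MLIR sketch of inputs those retained folds would still simplify:

```mlir
// Hypothetical example only: a full-size static insert. InsertSliceOp::fold
// returns %src because offsets are 0, strides are 1, and the sizes match the
// identical static source/destination type.
func.func @fold_identity(%src: tensor<8x16xf32>, %dst: tensor<8x16xf32>) -> tensor<8x16xf32> {
  %0 = tensor.insert_slice %src into %dst[0, 0] [8, 16] [1, 1]
      : tensor<8x16xf32> into tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}

// Hypothetical example only: two consecutive inserts of the same slice. The
// fold redirects the second insert's destination to %dst, leaving %0 dead.
func.func @fold_insert_after_insert(%src: tensor<4xf32>, %dst: tensor<128xf32>) -> tensor<128xf32> {
  %0 = tensor.insert_slice %src into %dst[0] [4] [1] : tensor<4xf32> into tensor<128xf32>
  %1 = tensor.insert_slice %src into %0[0] [4] [1] : tensor<4xf32> into tensor<128xf32>
  return %1 : tensor<128xf32>
}
```

tensor.parallel_insert_slice hits neither path anymore: it produces no OpResult to fold to, so cleanup of its constant offset/size/stride operands is left to the canonicalization patterns registered in ParallelInsertSliceOp::getCanonicalizationPatterns, which is what the new dont_fold_parallel_insert_slice test checks.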