diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -457,7 +457,7 @@
 
 namespace {
 /// Convert `extract_slice` operations to rank-reduced versions.
-struct UseRankReducedExtractSliceOp
+struct RankReducedExtractSliceOp
     : public OpRewritePattern<tensor::ExtractSliceOp> {
   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
 
@@ -487,27 +487,37 @@
 };
 
 /// Convert `insert_slice` operations to rank-reduced versions.
-struct UseRankReducedInsertSliceOp
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType sourceType = insertOp.getSourceType();
-    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
+    RankedTensorType sourceType = insertSliceOp.getSourceType();
+    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
-    Location loc = insertOp.getLoc();
-    auto reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
-        loc, insertOp.getSource(), *reassociation);
-    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        insertOp, reshapedSource, insertOp.getDest(),
-        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
-        insertOp.getMixedStrides());
+    Location loc = insertSliceOp.getLoc();
+    tensor::CollapseShapeOp reshapedSource;
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      // The only difference between InsertSliceOp and ParallelInsertSliceOp
+      // is that the insertion point is just before the ParallelCombiningOp
+      // in the parallel case.
+      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
+        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
+          loc, insertSliceOp.getSource(), *reassociation);
+    }
+    rewriter.replaceOpWithNewOp<InsertOpTy>(
+        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
+        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+        insertSliceOp.getMixedStrides());
     return success();
   }
 };
@@ -518,8 +528,9 @@
 
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
+  patterns.add<RankReducedExtractSliceOp, RankReducedInsertSliceOp<tensor::InsertSliceOp>,
+               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
       context);
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
   linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -827,3 +827,23 @@
 // CHECK-LABEL: func @sparse_case
 // CHECK-NEXT: linalg.init_tensor
 // CHECK-NEXT: linalg.generic
+
+// -----
+
+func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.init_tensor [4, 2] : tensor<4x2xf32>
+  %res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) -> (tensor<4x2xf32>) {
+    %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
+    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+    scf.foreach_thread.perform_concurrently {
+      // CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
+      // CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
+      tensor.parallel_insert_slice %2 into %0[%arg0, %arg1] [1, 1] [1, 1] :
+        tensor<1x1xf32> into tensor<4x2xf32>
+    }
+  }
+  return %res: tensor<4x2xf32>
+}
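For reference, the IR below is an illustrative sketch (not part of the diff) of roughly what the new RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp> pattern produces for the body of @reduce_dispatch_0: the tensor<1x1xf32> source is collapsed to a rank-0 tensor, and the tensor.collapse_shape is created just before scf.foreach_thread.perform_concurrently so that it dominates the parallel insert. SSA names here are made up, and the other patterns populated by populateFoldUnitExtentDimsPatterns may simplify the surrounding init_tensor/fill further.

    %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
    // Unit dims of the source are folded away; an empty reassociation ([])
    // collapses tensor<1x1xf32> to the rank-0 tensor<f32>.
    %collapsed = tensor.collapse_shape %2 [] : tensor<1x1xf32> into tensor<f32>
    scf.foreach_thread.perform_concurrently {
      // The insert keeps its offsets/sizes/strides; only the source type is
      // rank-reduced, matching the CHECK lines in the test above.
      tensor.parallel_insert_slice %collapsed into %0[%arg0, %arg1] [1, 1] [1, 1]
        : tensor<f32> into tensor<4x2xf32>
    }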