diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -482,6 +482,7 @@
   scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
   ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
+  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
@@ -532,7 +533,7 @@
   MLIRContext *ctx = funcOp.getContext();
   RewritePatternSet patterns(ctx);
   patterns.add(patterns.getContext(), options);
-  patterns.add(patterns.getContext());
+  patterns.add(patterns.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   (void)applyPatternsAndFoldGreedily(
       funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
@@ -540,6 +541,13 @@
   funcOp.walk([](PadTensorOp op) {
     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
   });
+  // Apply swap pattern after generating loop nest and running
+  // canonicalizations.
+  patterns.clear();
+  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  (void)applyPatternsAndFoldGreedily(
+      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
--- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -92,3 +92,33 @@
   } : tensor<7x9xf32> to tensor<15x16xf32>
   return %0 : tensor<15x16xf32>
 }
+
+// -----
+
+// TILE1-LABEL: func @static_pad_tile_evenly(
+// TILE1-SAME:    %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<14x15xf32>
+// TILE1-DAG:     %[[C0:.*]] = constant 0 : index
+// TILE1-DAG:     %[[C3:.*]] = constant 3 : index
+// TILE1-DAG:     %[[C15:.*]] = constant 15 : index
+// TILE1:         %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C15]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+// TILE1:           %[[R2:.*]] = scf.if
+// TILE1:             %[[GEN:.*]] = tensor.generate
+// TILE1:             scf.yield %[[GEN]] : tensor<14x3xf32>
+// TILE1:           else
+// TILE1:             %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
+// TILE1:             %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
+// TILE1:             %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
+// TILE1:             scf.yield %[[CAST]] : tensor<14x3xf32>
+// TILE1:           %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
+// TILE1:           scf.yield %[[R3]] : tensor<14x15xf32>
+// TILE1:         return %[[RESULT]] : tensor<14x15xf32>
+func @static_pad_tile_evenly(%input_tensor: tensor<7x9xf32>,
+                             %output_tensor: tensor<14x15xf32>,
+                             %pad_value: f32) -> tensor<14x15xf32> {
+  %0 = linalg.pad_tensor %input_tensor
+      low[0, 0] high[7, 6] into %output_tensor {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad_value : f32
+  } : tensor<7x9xf32> to tensor<14x15xf32>
+  return %0 : tensor<14x15xf32>
+}