diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -494,6 +494,7 @@ scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx); ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx); tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx); + tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx); memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx); tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); memref::ViewOp::getCanonicalizationPatterns(patterns, ctx); @@ -513,7 +514,17 @@ #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" >::insert(patterns, options); patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options); +} + +/// Applies ExtractSliceOfPadTensorSwapPattern to `funcOp`, then re-runs the +/// Linalg tiling canonicalization patterns on the result. +static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext()); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsAndFoldGreedily( + funcOp, getLinalgTilingCanonicalizationPatterns(ctx)); } static void @@ -527,6 +538,7 @@ MLIRContext *ctx = funcOp.getContext(); RewritePatternSet patterns(ctx); insertTilingPatterns(patterns, options); + patterns.add<AffineMinSCFCanonicalizationPattern>(patterns.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); (void)applyPatternsAndFoldGreedily( funcOp, getLinalgTilingCanonicalizationPatterns(ctx)); @@ -534,6 +546,10 @@ funcOp.walk([](LinalgOp op) { op->removeAttr(LinalgTransforms::kLinalgTransformMarker); }); + + // Apply swap pattern after generating loop nest and running + // canonicalizations.
+ applyExtractSliceOfPadTensorSwapPattern(funcOp); } namespace { diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir --- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir +++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir @@ -92,3 +92,33 @@ } : tensor<7x9xf32> to tensor<15x16xf32> return %0 : tensor<15x16xf32> } + +// ----- + +// TILE1-LABEL: func @static_pad_tile_evenly( +// TILE1-SAME: %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<14x15xf32> +// TILE1-DAG: %[[C0:.*]] = constant 0 : index +// TILE1-DAG: %[[C3:.*]] = constant 3 : index +// TILE1-DAG: %[[C15:.*]] = constant 15 : index +// TILE1: %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C15]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] = +// TILE1: %[[R2:.*]] = scf.if +// TILE1: %[[GEN:.*]] = tensor.generate +// TILE1: scf.yield %[[GEN]] : tensor<14x3xf32> +// TILE1: else +// TILE1: %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32> +// TILE1: %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}] +// TILE1: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32> +// TILE1: scf.yield %[[CAST]] : tensor<14x3xf32> +// TILE1: %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32> +// TILE1: scf.yield %[[R3]] : tensor<14x15xf32> +// TILE1: return %[[RESULT]] : tensor<14x15xf32> +func @static_pad_tile_evenly(%input_tensor: tensor<7x9xf32>, + %output_tensor: tensor<14x15xf32>, + %pad_value: f32) -> tensor<14x15xf32> { + %0 = linalg.pad_tensor %input_tensor + low[0, 0] high[7, 6] into %output_tensor { + ^bb0(%arg1: index, %arg2: index): + linalg.yield %pad_value : f32 + } : tensor<7x9xf32> to tensor<14x15xf32> + return %0 : tensor<14x15xf32> +} diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++
b/mlir/test/Dialect/Linalg/tile.mlir @@ -20,10 +20,6 @@ // TILE-234-DAG: #[[$bound_map_3:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)> // TILE-234-DAG: #[[$bound_map_4:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> -// TILE-2-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 10)> -// TILE-02-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 12)> -// TILE-002-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 16)> - // TILE-2-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> // TILE-02-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> // TILE-234-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> @@ -132,10 +128,8 @@ // TILE-2-DAG: %[[C2:.*]] = constant 2 : index // TILE-2-DAG: %[[M:.*]] = constant 10 : index // TILE-2: scf.for %[[I:.*]] = %{{.*}} to %[[M]] step %{{.*}} { -// TILE-2: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[I]]) -// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref -// TILE-2: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]]) -// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref +// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x16xf32, #[[$strided2D]]> +// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]> // TILE-2: linalg.matmul ins(%[[sAi]], %{{.*}}{{.*}} outs(%[[sCi]] // TILE-02-LABEL: func @matmul_static( @@ -143,10 +137,8 @@ // TILE-02-DAG: %[[C2:.*]] = constant 2 : index // TILE-02-DAG: %[[N:.*]] = constant 12 : index // TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} { -// TILE-02: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[J]]) -// TILE-02: 
%[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]> -// TILE-02: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]]) -// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]> +// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, 2] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x2xf32, #[[$strided2D]]> +// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, 2] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]> // TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]] // TILE-002-LABEL: func @matmul_static( @@ -154,10 +146,8 @@ // TILE-002-DAG: %[[C2:.*]] = constant 2 : index // TILE-002-DAG: %[[C16:.*]] = constant 16 : index // TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} { -// TILE-002: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[K]]) -// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]> -// TILE-002: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]]) -// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref +// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, 2] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]> +// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [2, 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]> // TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}} // TILE-234-LABEL: func @matmul_static( @@ -171,9 +161,9 @@ // TILE-234: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[C10]] step %{{.*}} { // TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[C12]] step %{{.*}} { // 
TILE-234: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} { -// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref -// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref -// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref +// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [2, 4] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x4xf32, #[[$strided2D]]> +// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [4, 3] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<4x3xf32, #[[$strided2D]]> +// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x3xf32, #[[$strided2D]]> // // TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]] @@ -312,7 +302,7 @@ // TILE-234: for // TILE-234-NOT: for // TILE-234: memref.subview{{.*}} : memref<127x99xf32> -// TILE-234: linalg.fill{{.*}} : f32, memref +// TILE-234: linalg.fill{{.*}} : f32, memref func @fill(%arg0: memref, %arg1: f32) {