diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -256,12 +256,8 @@
   if (!res)
     return failure();
 
-  // Setup RAII guard to return properly.
-  bool succeeded = true;
   LinalgOp tiledOp = res->op;
   auto guard = llvm::make_scope_exit([&]() {
-    if (!succeeded)
-      return;
     // Return relevant information to derived pattern.
     result = *res;
     // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
@@ -278,7 +274,6 @@
   // Try to pad on the fly by rewriting res->op as a padded op.
   if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
     // Set so RAII guard does not propagate TiledLinalgOp to `result`.
-    succeeded = false;
     return failure();
   }
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -134,6 +134,17 @@
                          .getResult(0)
                          .dyn_cast<AffineConstantExpr>())
       boundingConst = cExpr.getValue();
+  } else if (auto dimOp = size.getDefiningOp<DimOp>()) {
+    // This dimension isn't tiled, so set the bounding box to its static
+    // size.
+    auto shape = dimOp.memrefOrTensor().getType().dyn_cast<ShapedType>();
+    if (auto constOp = dimOp.index().getDefiningOp<ConstantOp>()) {
+      if (auto indexAttr = constOp.value().dyn_cast<IntegerAttr>()) {
+        auto dimIndex = indexAttr.getInt();
+        if (!shape.isDynamicDim(dimIndex))
+          boundingConst = shape.getShape()[dimIndex];
+      }
+    }
   }
   if (boundingConst && *boundingConst >= 0)
     return Builder(size.getContext()).getIndexAttr(*boundingConst);
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -1,22 +1,28 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-and-pad-pattern -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-and-pad-pattern tile-sizes-for-padding=2,3,4" -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-and-pad-pattern tile-sizes-for-padding=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
+
+func @matmul_tensors(
+  %arg0: tensor<?x?xi8>, %arg1: tensor<?x?xi8>, %arg2: tensor<?x?xi32>)
+    -> tensor<?x?xi32> {
+  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile-and-pad"}
+      ins(%arg0, %arg1: tensor<?x?xi8>, tensor<?x?xi8>)
+     outs(%arg2: tensor<?x?xi32>)
+    -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
 
 // CHECK-LABEL: func @matmul_tensors(
 // CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
 // CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xi8>
 // CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
-func @matmul_tensors(
-  %arg0: tensor<?x?xi8>, %arg1: tensor<?x?xi8>, %arg2: tensor<?x?xi32>)
-    -> tensor<?x?xi32> {
 // CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
 // CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
 // CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xi32>) {
 // CHECK: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xi8> to tensor<?x?xi8>
 // CHECK: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xi8> to tensor<?x?xi8>
 // CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-
 // Dynamic op has been canonicalized away.
 // CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xi8>
-
 // Padding injects static information.
 // CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
 // CHECK:   : tensor<?x?xi8> to tensor<2x4xi8>
@@ -31,11 +37,62 @@
 // CHECK: scf.yield %[[TD]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[TD2]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[TD1]] : tensor<?x?xi32>
+// CHECK: return %[[TD0]] : tensor<?x?xi32>
+
+// CHECK-1DIM-TILE: func @matmul_tensors(
+// CHECK-1DIM-TILE: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
+// CHECK-1DIM-TILE: %[[TB:[0-9a-z]+]]: tensor<?x?xi8>
+// CHECK-1DIM-TILE: %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
+// CHECK-1DIM-TILE-NOT: scf.for
+// CHECK-1DIM-TILE: linalg.matmul_i8_i8_i32 ins(%[[TA]], %[[TB]] : tensor<?x?xi8>, tensor<?x?xi8>) outs(%[[TC]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func @matmul_partially_static_tensors(
+  %arg0: tensor<?x8xi8>, %arg1: tensor<8x?xi8>, %arg2: tensor<?x?xi32>)
+    -> tensor<?x?xi32> {
   %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile-and-pad"}
-      ins(%arg0, %arg1: tensor<?x?xi8>, tensor<?x?xi8>)
+      ins(%arg0, %arg1: tensor<?x8xi8>, tensor<8x?xi8>)
      outs(%arg2: tensor<?x?xi32>)
     -> tensor<?x?xi32>
-
-// CHECK: return %[[TD0]] : tensor<?x?xi32>
   return %0 : tensor<?x?xi32>
 }
+// CHECK-LABEL: func @matmul_partially_static_tensors(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x8xi8>
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<8x?xi8>
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x?xi8>
+// CHECK: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<?x?xi8>
+// CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
+// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK:   : tensor<?x?xi8> to tensor<2x4xi8>
+// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK:   : tensor<?x?xi8> to tensor<4x3xi8>
+// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK:   : tensor<?x?xi32> to tensor<2x3xi32>
+// CHECK: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
+// CHECK-SAME:   outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
+// CHECK: return %[[TD0]] : tensor<?x?xi32>
+
+
+// CHECK-1DIM-TILE: func @matmul_partially_static_tensors(
+// CHECK-1DIM-TILE: %[[TA:[0-9a-z]+]]: tensor<?x8xi8>
+// CHECK-1DIM-TILE: %[[TB:[0-9a-z]+]]: tensor<8x?xi8>
+// CHECK-1DIM-TILE: %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
+// CHECK-1DIM-TILE: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
+// CHECK-1DIM-TILE: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
+// CHECK-1DIM-TILE: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
+// CHECK-1DIM-TILE: %[[scTA:.*]] = tensor.cast %[[sTA]] : tensor<?x8xi8> to tensor<?x?xi8>
+// CHECK-1DIM-TILE: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
+// CHECK-1DIM-TILE: %[[scTB:.*]] = tensor.cast %[[sTB]] : tensor<8x?xi8> to tensor<?x?xi8>
+// CHECK-1DIM-TILE: %[[sTC:.*]] = subtensor %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
+// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[scTA]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE:   : tensor<?x?xi8> to tensor<2x8xi8>
+// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[scTB]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE:   : tensor<?x?xi8> to tensor<8x3xi8>
+// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE:   : tensor<?x?xi32> to tensor<2x3xi32>
+// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
+// CHECK-1DIM-TILE:   outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
+// CHECK-1DIM-TILE: return %[[TD0]] : tensor<?x?xi32>
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -83,6 +83,10 @@
   Option<bool> testTileAndPadPattern{
       *this, "test-tile-and-pad-pattern",
       llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
+  ListOption<int64_t> tileSizesForPadding{
+      *this, "tile-sizes-for-padding",
+      llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
+      llvm::cl::MiscFlags::CommaSeparated};
   Option<bool> testHoistPadding2Levels{*this, "test-hoist-padding-2-level",
                                        llvm::cl::desc("Test hoist padding"),
                                        llvm::cl::init(false)};
@@ -514,12 +518,12 @@
   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
 }
 
-static void applyTileAndPadPattern(FuncOp funcOp) {
+static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
   MLIRContext *context = funcOp.getContext();
   OwningRewritePatternList tilingPattern;
   auto linalgTilingOptions =
       linalg::LinalgTilingOptions()
-          .setTileSizes({2, 3, 4})
+          .setTileSizes(tileSizes)
           .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
   tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
       context, linalgTilingOptions,
@@ -562,7 +566,7 @@
   if (testAffineMinSCFCanonicalizationPatterns)
     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
   if (testTileAndPadPattern)
-    return applyTileAndPadPattern(getFunction());
+    return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
   if (testHoistPadding2Levels) {
     getFunction().walk([](linalg::PadTensorOp padTensorOp) {
       (void)linalg::hoistPaddingOnTensors(padTensorOp, 2);