diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -479,6 +479,10 @@
   /// proper interaction with folding.
   LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
 
+  /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions.
+  /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together.
+  LinalgTilingOptions &scalarizeDynamicDims();
+
   /// The interchange vector to reorder the tiled loops.
   SmallVector<unsigned, 4> interchangeVector = {};
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -107,6 +107,7 @@
 
 LinalgTilingOptions &
 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
+  assert(!tileSizeComputationFunction && "tile sizes already set");
   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
     OpBuilder::InsertionGuard guard(b);
@@ -120,6 +121,30 @@
   return *this;
 }
 
+LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
+  assert(!tileSizeComputationFunction && "tile sizes already set");
+  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
+    SmallVector<Value, 4> tileSizes;
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return tileSizes;
+    Location loc = linalgOp.getLoc();
+    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
+    AffineMap map = linalgOp.getShapesToLoopsMap();
+    if (!map)
+      return tileSizes;
+    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
+    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
+    // size 0).
+    for (Value shapeSize : shapeSizes)
+      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
+                              ? b.create<ConstantIndexOp>(loc, 0)
+                              : b.create<ConstantIndexOp>(loc, 1));
+    return tileSizes;
+  };
+  return *this;
+}
+
 /// Try to compute a static bounding box for `operand`
 /// Return success if either:
 /// 1. The operand is already statically shaped, `result` is left unchanged.
diff --git a/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir b/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-scalarize-dynamic-dims" -for-loop-canonicalization -canonicalize -split-input-file | \
+// RUN: FileCheck %s
+
+// CHECK-LABEL: func @matmul_partly_dynamic_tensor(
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x2000xf32>
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = constant 1 : index
+//       CHECK:   tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+//       CHECK:   %[[UB1:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+//       CHECK:   %[[UB2:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+//       CHECK:   scf.for %[[IV0:.*]] = %[[C0]] to %[[UB1]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:.*]] = %[[C0]] to %[[UB2]] step %[[C1]]
+//       CHECK:       %[[S1:.*]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<1x1xf32>
+//       CHECK:       %[[S2:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [1, 2000] [1, 1] : tensor<?x2000xf32> to tensor<1x2000xf32>
+//       CHECK:       %[[S3:.*]] = tensor.extract_slice %{{.*}}[%[[IV0]], 0] [1, 2000] [1, 1] : tensor<?x2000xf32> to tensor<1x2000xf32>
+//       CHECK:       linalg.matmul ins(%[[S1]], %[[S2]] : tensor<1x1xf32>, tensor<1x2000xf32>) outs(%[[S3]] : tensor<1x2000xf32>) -> tensor<1x2000xf32>
+func @matmul_partly_dynamic_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x2000xf32>)
+    -> tensor<?x2000xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %out = linalg.init_tensor [%d0, 2000] : tensor<?x2000xf32>
+  %r = linalg.matmul {__internal_linalg_transform__ = "tile"}
+      ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x2000xf32>)
+      outs(%out: tensor<?x2000xf32>) -> tensor<?x2000xf32>
+  return %r : tensor<?x2000xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -87,6 +87,10 @@
   Option<bool> testTilePattern{*this, "test-tile-pattern",
                                llvm::cl::desc("Test tile pattern"),
                                llvm::cl::init(false)};
+  Option<bool> testTileScalarizeDynamicDims{
+      *this, "test-tile-scalarize-dynamic-dims",
+      llvm::cl::desc("Test tiling of dynamic dims by 1"),
+      llvm::cl::init(false)};
   Option<int> testHoistPadding{*this, "test-hoist-padding",
                                llvm::cl::desc("Test hoist padding"),
                                llvm::cl::init(0)};
@@ -566,12 +570,19 @@
 }
 
 static void applyTilePattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes,
-                             bool padTiles, ArrayRef<int64_t> peeledLoops) {
+                             bool padTiles, ArrayRef<int64_t> peeledLoops,
+                             bool scalarizeDynamicDims) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet tilingPattern(context);
   auto linalgTilingOptions =
-      linalg::LinalgTilingOptions().setTileSizes(tileSizes).setPeeledLoops(
-          peeledLoops);
+      linalg::LinalgTilingOptions().setPeeledLoops(peeledLoops);
+  if (scalarizeDynamicDims) {
+    assert(tileSizes.empty() &&
+           "tileSizes and scalarizeDynamicDims is mutually exclusive");
+    linalgTilingOptions.scalarizeDynamicDims();
+  } else {
+    linalgTilingOptions.setTileSizes(tileSizes);
+  }
   if (padTiles)
     linalgTilingOptions.setPaddingValueComputationFunction(
         getNeutralOfLinalgOp);
@@ -709,7 +720,11 @@
     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
                                         skipPartial);
   if (testTilePattern)
-    return applyTilePattern(getFunction(), tileSizes, padTiles, peeledLoops);
+    return applyTilePattern(getFunction(), tileSizes, padTiles, peeledLoops,
+                            /*scalarizeDynamicDims=*/false);
+  if (testTileScalarizeDynamicDims)
+    return applyTilePattern(getFunction(), tileSizes, padTiles,
+                            /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testHoistPadding) {
     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);