diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -561,15 +561,17 @@
        TransformEachOpTrait, TransformOpInterface]> {
   let description = [{
-    Tile a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
-    tiling by `num_threads`.
-    If non-empty, the `thread_dim_mapping` is added as an attribute to the
+    Tile a TilingInterface op to a tiled `scf.foreach_thread`. Tiling is
+    applied by specifying either `num_threads` or `tile_sizes`. If
+    `num_threads` is specified, then the tile size for each dimension `i` is
+    calculated dynamically via `ceilDiv(dimSize[i], num_threads[i])`.
+    If non-empty, the `thread_dim_mapping` is added as an attribute to the
     resulting `scf.foreach_thread`.
-    Zero tile sizes indicate that the dimension is not tiled, and can be thought
-    of as tiling by the full size of data.
-    It is the user's responsibility to ensure that `num_threads` is a valid
-    tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-    Linalg case).
+    A tile size of zero indicates that the dimension is not tiled and can be
+    thought of as tiling by the full size of data.
+    It is the user's responsibility to ensure that `num_threads`/`tile_sizes`
+    is a valid tiling specification (i.e. that it only tiles parallel
+    dimensions, e.g. in the Linalg case).
 
     #### Return modes
 
@@ -580,24 +582,39 @@
     successfully, the transform succeeds. Otherwise the transform silently
     fails.
-    The 2 returned handles point to only the subset of successfully produced
+    The two returned handles point to only the subset of successfully produced
     tiled operations, which can all be empty.
-    These 2 returned handles point to:
+    These two returned handles point to:
     - the new scf.foreach_thread op,
     - the tiled op that implements TilingInterface.
+
+    #### Example using `num_threads`
+
+    ```
+    %0 = pdl_match @match_matmul in %arg1
+    %3:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
+    ```
+
+    #### Example using `tile_sizes`
+
+    ```
+    %0 = pdl_match @match_matmul in %arg1
+    %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20, 0]
+    ```
   }];
 
   let arguments = (ins PDL_Operation:$target,
                    // TODO: dynamic number of threads.
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
+                   OptionalAttr<I64ArrayAttr>:$num_threads,
+                   OptionalAttr<I64ArrayAttr>:$tile_sizes,
                    OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$tiled_op);
 
   let assemblyFormat = [{
-    $target $num_threads (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)?
-    attr-dict
+    $target (`num_threads` $num_threads^) : (`tile_sizes` $tile_sizes)?
+    (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
   }];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -531,10 +531,17 @@
   Operation *tiledOp;
 };
 FailureOr<ForeachThreadTilingResult>
-tileToForeachThreadOp(OpBuilder &builder, TilingInterface op,
+tileToForeachThreadOp(RewriterBase &builder, TilingInterface op,
                       ArrayRef<OpFoldResult> numThreads,
                       ArrayRef<int64_t> threadDimMapping = {});
 
+/// Same as `tileToForeachThreadOp`, but calculates the number of threads
+/// required from the given `tileSizes`.
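+/// The number of threads per dimension is derived from the iteration domain
+/// as `ceilDiv(dimSize[i], tileSizes[i])`; a tile size of zero leaves the
+/// corresponding dimension untiled.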
+FailureOr<ForeachThreadTilingResult>
+tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
+                                    ArrayRef<OpFoldResult> tileSizes,
+                                    ArrayRef<int64_t> threadDimMapping = {});
+
 /// All indices returned by IndexOp should be invariant with respect to tiling.
 /// Therefore, if an operation is tiled, we have to transform the indices
 /// accordingly, i.e. offset them by the values of the corresponding induction
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -721,8 +721,13 @@
       actualValues.push_back(value);
       continue;
     }
-    constants.push_back(dialect->materializeConstant(b, ofr.get<Attribute>(),
-                                                     b.getIndexType(), loc));
+    // Since we are directly specifying `index` as the result type, we need to
+    // ensure the provided attribute is also an index type. Otherwise, the
+    // AffineDialect materializer will create invalid `arith.constant`
+    // operations if the provided Attribute is any other kind of integer.
+    constants.push_back(dialect->materializeConstant(
+        b, b.getIndexAttr(ofr.get<Attribute>().cast<IntegerAttr>().getInt()),
+        b.getIndexType(), loc));
     actualValues.push_back(constants.back()->getResult(0));
   }
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -800,12 +800,20 @@
   IRRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   auto maybeThreadDimMappingAttr = getThreadDimMapping();
-  FailureOr<ForeachThreadTilingResult> tilingResult =
-      linalg::tileToForeachThreadOp(
-          rewriter, target, getAsOpFoldResult(getNumThreads()),
-          maybeThreadDimMappingAttr
-              ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
-              : ArrayRef<int64_t>{});
+  auto dimMapping =
+      llvm::to_vector(maybeThreadDimMappingAttr
+                          ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
+                          : ArrayRef<int64_t>{});
+
+  FailureOr<ForeachThreadTilingResult> tilingResult = failure();
+  if (Optional<ArrayAttr> numThreads = getNumThreads())
+    tilingResult = linalg::tileToForeachThreadOp(
+        rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
+
+  if (Optional<ArrayAttr> tileSizes = getTileSizes())
+    tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
+        rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
+
   if (failed(tilingResult))
     return emitDefaultSilenceableFailure(target);
   rewriter.replaceOp(target, tilingResult->tileOp->getResults());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
 #include <utility>
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -182,23 +183,43 @@
 }
 
 /// Build an `affine_max` of all the `vals`.
-static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) {
+static OpFoldResult buildMax(OpBuilder &b, Location loc,
+                             ArrayRef<OpFoldResult> vals) {
+  SmallVector<Value> args = getValueOrCreateConstantIndexOp(b, loc, vals);
   return b.createOrFold<AffineMaxOp>(
       loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
-      vals);
+      args);
 }
 
-/// Build an `affine_min` of all the `vals`.
-static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) {
-  return b.createOrFold<AffineMinOp>(
-      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
-      vals);
+/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
+/// less than `iterationSize`.
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+                                           OpFoldResult numThreads,
+                                           OpFoldResult iterationSize) {
+  Optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+  Optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
+  Optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
+  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
+    return false;
+  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
 }
 
-FailureOr<ForeachThreadTilingResult>
-linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
-                              ArrayRef<OpFoldResult> numThreads,
-                              ArrayRef<int64_t> threadDimMapping) {
+/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The tiling
+/// is specified by the number of tiles/threads `numThreads` and the optional
+/// nominal tile size `nominalTileSizes`. If `nominalTileSizes` is not
+/// specified, then it is derived from `numThreads` as
+/// `ceilDiv(dimSize[i], numThreads[i])`. If non-empty, the `threadDimMapping`
+/// is added as an attribute to the resulting `scf.foreach_thread`. A tile
+/// size of zero indicates that the dimension is not tiled, and can be thought
+/// of as tiling by the full size of data.
+/// It is the user's responsibility to ensure that `numThreads` is a valid
+/// tiling specification (i.e. that it only tiles parallel dimensions, e.g. in
+/// the Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function
+/// will assume that `tileSize[i] * (numThreads[i] - 1) <= dimSize[i]` holds.
+static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
+    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
+    Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+    ArrayRef<int64_t> threadDimMapping, bool omitTileOffsetBoundsCheck) {
   Location loc = op->getLoc();
   OpBuilder::InsertionGuard g(b);
   SmallVector<Range> loopRanges = op.getIterationDomain(b);
@@ -224,80 +245,128 @@
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
   Operation *tiledOp = nullptr;
+
+  // Create the ForeachThreadOp. We don't use the lambda body-builder
+  // version because we require the use of RewriterBase in the body, so we
+  // manually move the insertion point to the body below.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, materializedNonZeroNumThreads, threadDimMapping,
-      [&](OpBuilder &b, Location loc, ValueRange threadIds) {
-        int64_t nLoops = loopRanges.size();
-        SmallVector<Value> tiledOffsets, tiledSizes;
-        tiledOffsets.reserve(nLoops);
-        tiledSizes.reserve(nLoops);
-        for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops;
-             ++loopIdx) {
-          bool overflow = loopIdx >= numThreads.size();
-          bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
-          // Degenerate case: take the whole domain.
-          if (overflow || isZero) {
-            tiledOffsets.push_back(loopRanges[loopIdx].offset);
-            tiledSizes.push_back(loopRanges[loopIdx].size);
-            continue;
-          }
-
-          // Tiled case: compute the offset and size.
-          AffineExpr i, j, M, N, O;
-          bindDims(b.getContext(), i, j);
-          bindSymbols(b.getContext(), M, N, O);
-          Value size = loopRanges[loopIdx].size;
-          Value offset = loopRanges[loopIdx].offset;
-          Value threadId = threadIds[threadIdIdx];
-          // TODO: more aggressive foldings.
-          // Symbolic fixed max size per thread.
-          // TODO: floor + 0/1 depending on case for better load-balancing.
-          Value maxSizePerThread = b.createOrFold<AffineApplyOp>(
-              loc, M.ceilDiv(N),
-              ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]});
-          // Dynamic offset shifted by threadId * maxSizePerThread.
-          Value offsetPerThread = b.createOrFold<AffineApplyOp>(
-              loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread});
-          // Dynamic upper-bound depending on the threadId.
-          Value sizeMinusOffsetPerThread = b.createOrFold<AffineApplyOp>(
-              loc, -i + M, ValueRange{offsetPerThread, size});
-          Value tileSizePerThread = buildMin(
-              b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread});
-          tiledOffsets.push_back(offsetPerThread);
-          // TODO: if tileSizePerThread <= 0 early exit.
-          tiledSizes.push_back(
-              buildMax(b, loc, ValueRange{zero, tileSizePerThread}));
-          ++threadIdIdx;
-        }
-
-        SmallVector<Operation *> tiledOps =
-            op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
-                                      /*tileDestOperands=*/true);
-        assert(tiledOps.size() == 1 && "expected a single produced tiled op");
-        tiledOp = tiledOps.front();
-
-        auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
-        assert(tilingInterfaceOp &&
-               "Tiled op does not implement TilingInterface");
-
-        auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
-
-        // Create terminator with parallel subset insert operations.
-        auto performConcurrentlyOp = b.create<scf::PerformConcurrentlyOp>(loc);
-        OpBuilder::InsertionGuard g(b);
-        b.setInsertionPointToStart(performConcurrentlyOp.getBody());
-        for (auto it :
-             llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
-                       destOperands)) {
-          createMatchingParallelSubsetInsertOp(
-              b, loc,
-              cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
-              std::get<1>(it), std::get<2>(it));
-        }
-      });
+      loc, op->getResultTypes(), ValueRange(materializedNonZeroNumThreads),
+      threadDimMapping);
+
+  // Fill out the ForeachThreadOp body.
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  ValueRange threadIds = foreachThreadOp.getThreadIndices();
+  int64_t nLoops = loopRanges.size();
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  tiledOffsets.reserve(nLoops);
+  tiledSizes.reserve(nLoops);
+  for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
+    bool overflow = loopIdx >= numThreads.size();
+    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+    // Degenerate case: take the whole domain.
+    if (overflow || isZero) {
+      tiledOffsets.push_back(loopRanges[loopIdx].offset);
+      tiledSizes.push_back(loopRanges[loopIdx].size);
+      continue;
+    }
+
+    // Tiled case: compute the offset and size.
+    AffineExpr i, j, M, N, O;
+    bindDims(b.getContext(), i, j);
+    bindSymbols(b.getContext(), M, N, O);
+    Value size = loopRanges[loopIdx].size;
+    Value offset = loopRanges[loopIdx].offset;
+    Value threadId = threadIds[threadIdIdx];
+    // Symbolic fixed max size per thread.
+    // TODO: floor + 0/1 depending on case for better load-balancing.
+    OpFoldResult tileSizePerThread =
+        nominalTileSizes.hasValue()
+            ? (*nominalTileSizes)[loopIdx]
+            : makeComposedFoldedAffineApply(
+                  b, loc, M.ceilDiv(N),
+                  ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
+
+    // Dynamic offset shifted by threadId * maxSizePerThread.
+    OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
+        b, loc, i + j * M, {offset, threadId, tileSizePerThread});
+    // Dynamic upper-bound depending on the threadId.
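+    // The residual tile size measures how far the end of the last tile,
+    // i.e. `offset + numThreads * tileSizePerThread`, overshoots the loop
+    // range `size`. Only when it may be nonzero do we need to clamp the
+    // per-thread tile size with an affine.min below.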
+    OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
+        b, loc, i + j * M - N,
+        {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
+    if (!isConstantIntValue(residualTileSize, 0)) {
+      OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
+          b, loc, -i + M, {offsetPerThread, size});
+      tileSizePerThread = makeComposedFoldedAffineMin(
+          b, loc, AffineMap::getMultiDimIdentityMap(2, b.getContext()),
+          ArrayRef<OpFoldResult>{sizeMinusOffsetPerThread, tileSizePerThread});
+    }
+
+    tiledOffsets.push_back(offsetPerThread);
+    // TODO: if tileSizePerThread <= 0 early exit.
+    if (!omitTileOffsetBoundsCheck &&
+        !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
+                                        nonZeroNumThreads[threadIdIdx], size))
+      tileSizePerThread = buildMax(b, loc, {zero, tileSizePerThread});
+
+    tiledSizes.push_back(tileSizePerThread);
+    ++threadIdIdx;
+  }
+
+  SmallVector<Operation *> tiledOps =
+      op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
+                                /*tileDestOperands=*/true);
+  assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+  tiledOp = tiledOps.front();
+
+  auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
+  assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
+
+  auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
+
+  // Create terminator with parallel subset insert operations.
+  b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
+  for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
+                           destOperands)) {
+    createMatchingParallelSubsetInsertOp(
+        b, loc, cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
+        std::get<1>(it), std::get<2>(it));
+  }
   return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
 }
 
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op,
+                              ArrayRef<OpFoldResult> numThreads,
+                              ArrayRef<int64_t> threadDimMapping) {
+  return tileToForeachThreadOpImpl(b, op, numThreads,
+                                   /*nominalTileSizes=*/None, threadDimMapping,
+                                   /*omitTileOffsetBoundsCheck=*/false);
+}
+
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOpUsingTileSizes(
+    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> tileSizes,
+    ArrayRef<int64_t> threadDimMapping) {
+  SmallVector<Range> loopRanges = op.getIterationDomain(b);
+  unsigned nLoops = loopRanges.size();
+  SmallVector<OpFoldResult> numThreads;
+  numThreads.reserve(nLoops);
+  AffineExpr s0, s1;
+  bindSymbols(b.getContext(), s0, s1);
+  AffineExpr divExpr = s0.ceilDiv(s1);
+  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
+    // For a non-zero tile size, derive the number of threads as
+    // ceilDiv(loopSize, tileSize); a zero tile size is forwarded as-is and
+    // marks the dimension as untiled.
+    OpFoldResult numTiles = std::get<0>(it);
+    if (!isConstantIntValue(numTiles, 0))
+      numTiles = makeComposedFoldedAffineApply(
+          b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
+    numThreads.push_back(numTiles);
+  }
+  return tileToForeachThreadOpImpl(b, op, numThreads,
+                                   /*nominalTileSizes=*/tileSizes,
+                                   threadDimMapping,
+                                   /*omitTileOffsetBoundsCheck=*/true);
+}
+
 // Insert a tile `source` into the destination tensor `dest`. The position at
 // which the tile is inserted (as well as size of tile) is taken from a given
 // ExtractSliceOp `sliceOp`.
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s
 
 // Offset per thread:
 //   CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@@ -43,7 +43,161 @@
 transform.sequence %arg0 {
 ^bb1(%arg1: !pdl.operation):
   %0 = pdl_match @match_linalg_matmul in %arg1
-  %1:2 = transform.structured.tile_to_foreach_thread_op %0 [10, 20] (mapped to dims [1, 0])
+  %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims [1, 0])
 }
 }
 }
+
+// -----
+
+// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 15)>
+
+// CHECK-LABEL: matmul_static(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<100x200xf32>
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<200x300xf32>
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<100x300xf32>
+func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
+  // CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
+  // CHECK-DAG: %[[c21:.+]] = arith.constant 21 : index
+  // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]])
+  // CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV1]])
+  // CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
+  // CHECK-NOT: affine.min
+  // CHECK-NOT: affine.max
+  // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+  // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
+  // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+  // CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
+  // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+  // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+  // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
+  // CHECK: linalg.matmul
+  // CHECK: scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT: tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<100x200xf32>, tensor<200x300xf32>)
+                     outs(%C : tensor<100x300xf32>) -> (tensor<100x300xf32>)
+  return %0 : tensor<100x300xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_linalg_matmul : benefit(1) {
+    %0 = operands
+    %1 = types
+    %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    rewrite %2 with "transform.dialect"
+  }
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_linalg_matmul in %arg1
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 21]
+  }
+}
+
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+  // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+  // CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+  // CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+  // CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+  // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+  // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
+  // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+  // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  // CHECK: tensor.extract_slice %[[A]]
+  // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+  // CHECK: tensor.extract_slice %[[B]]
+  // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+  // CHECK: tensor.extract_slice %[[C]]
+  // CHECK: linalg.matmul
+  // CHECK: scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT: tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_linalg_matmul : benefit(1) {
+    %0 = operands
+    %1 = types
+    %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    rewrite %2 with "transform.dialect"
+  }
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_linalg_matmul in %arg1
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20]
+  }
+}
+
+// -----
+
+// Tests that dimension 0 can eliminate the affine.min, while dimension 1
+// cannot; the affine.max is elided for both dimensions on the tile_sizes path.
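+// With tile_sizes [10, 21] on the static 100x300 iteration space: dimension 0
+// uses 100 / 10 = 10 threads and divides evenly, so no affine.min is needed;
+// dimension 1 uses ceilDiv(300, 21) = 15 threads with 15 * 21 = 315 > 300, so
+// an affine.min clamps the last tile. The affine.max against 0 is omitted
+// because the tile_sizes path assumes
+// `tileSize * (numThreads - 1) <= dimSize`.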
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 21)>
+
+// CHECK-LABEL: matmul_tile_size_static(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<100x200xf32>
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<200x300xf32>
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<100x300xf32>
+func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
+  // CHECK-DAG: %[[c10:.+]] = arith.constant 10 :
+  // CHECK-DAG: %[[c15:.+]] = arith.constant 15 :
+  // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]])
+  // CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
+  // CHECK-NOT: affine.max
+  // CHECK-NOT: affine.min
+  // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+  // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
+  // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+  // CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
+  // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+  // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+  // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
+  // CHECK: linalg.matmul
+  // CHECK: scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT: tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<100x200xf32>, tensor<200x300xf32>)
+                     outs(%C : tensor<100x300xf32>) -> (tensor<100x300xf32>)
+  return %0 : tensor<100x300xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_linalg_matmul : benefit(1) {
+    %0 = operands
+    %1 = types
+    %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    rewrite %2 with "transform.dialect"
+  }
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_linalg_matmul in %arg1
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 21]
+  }
+}