diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -134,6 +134,37 @@
   TileUsingSCFForOp tilingPattern;
 };
 
+/// Transformation information returned after tiling with
+/// `scf.foreach_thread`.
+struct SCFTileForeachResult {
+  Operation *tiledOp;
+  SmallVector<Operation *> loops;
+};
+
+/// Pattern to tile an op that implements the `TilingInterface` using
+/// `scf.foreach_thread`.
+struct TileUsingSCFForeachOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  /// Construct a generic pattern applied to all TilingInterface ops.
+  TileUsingSCFForeachOp(MLIRContext *context, SCFTilingOptions options,
+                        PatternBenefit benefit = 1);
+
+  /// Construct a generic pattern applied to `opName`.
+  TileUsingSCFForeachOp(StringRef opName, MLIRContext *context,
+                        SCFTilingOptions options, PatternBenefit benefit = 1);
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<SCFTileForeachResult>
+  returningMatchAndRewrite(TilingInterface op,
+                           PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+
+private:
+  /// Options to control tiling.
+  SCFTilingOptions options;
+};
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
@@ -403,3 +404,280 @@
   }
   return tileAndFuseResult;
 }
+
+//===----------------------------------------------------------------------===//
+// TileUsingSCFForeachOp
+//===----------------------------------------------------------------------===//
+
+/// Construct an AffineMap such that the result at position `i` contains the
+/// AffineExpr for computing the tile offset via
+/// `tileOffset = loopLB + tileIdx * tileSize`.
+/// `tileSize` is expected to be a symbol. In the resulting AffineMap, the
+/// first `numTiledLoops` dimension variables represent the `tileIdx`s, the
+/// second `numTiledLoops` dimension variables represent the `loopLB`s, and
+/// the `numTiledLoops` symbol variables represent the tile sizes.
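+/// For example, with `numTiledLoops == 2` the constructed map is:
+///   (tileIdx0, tileIdx1, lb0, lb1)[ts0, ts1]
+///     -> (lb0 + tileIdx0 * ts0, lb1 + tileIdx1 * ts1)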
+static AffineMap getTileOffsetsMap(OpBuilder &b, unsigned numTiledLoops) {
+  SmallVector<AffineExpr> tileOffsetExprs;
+  tileOffsetExprs.reserve(numTiledLoops);
+  auto getDim = [&](unsigned dimIdx) { return b.getAffineDimExpr(dimIdx); };
+  for (unsigned i = 0; i < numTiledLoops; i++) {
+    auto s0 = b.getAffineSymbolExpr(i);
+    tileOffsetExprs.push_back(getDim(numTiledLoops + i) + getDim(i) * s0);
+  }
+  auto tileOffsetsMap = AffineMap::get(
+      /*dimCount=*/2 * numTiledLoops, /*symbolCount=*/numTiledLoops,
+      /*results=*/tileOffsetExprs, b.getContext());
+  return tileOffsetsMap;
+}
+
+/// Calculate `min(tileSize, ub - tileOffset)` for each tiled loop.
+static SmallVector<Value> getTileSizeBounds(OpBuilder &b, Location loc,
+                                            unsigned numTiledLoops,
+                                            ArrayRef<Value> tileSizes,
+                                            ArrayRef<Value> tileOffsets,
+                                            ArrayRef<Value> ubs) {
+  // The tile size to use (to avoid out of bounds access) is the minimum of
+  // `tileSize` and `ub - tileOffset`.
+  AffineExpr d0, s0, s1;
+  bindSymbols(b.getContext(), s0, s1);
+  bindDims(b.getContext(), d0);
+  AffineMap tileSizeMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
+
+  return llvm::to_vector(
+      llvm::map_range(llvm::zip(tileOffsets, tileSizes, ubs), [&](auto it) {
+        return b.createOrFold<AffineMinOp>(
+            loc, tileSizeMap,
+            ValueRange{std::get<0>(it), std::get<1>(it), std::get<2>(it)});
+      }));
+}
+
+/// Calculate `ceilDiv(ub - lb, tileSize)` for each tiled loop.
+static SmallVector<Value> getNumTiles(OpBuilder &b, Location loc,
+                                      ArrayRef<Range> tiledLoopRanges,
+                                      ArrayRef<Value> tiledLoopTileSizes) {
+  assert(tiledLoopRanges.size() == tiledLoopTileSizes.size());
+  // Create an affine expression representing the number of tiles as
+  // `ceilDiv(ub - lb, tileSize)`.
+  AffineExpr s0, d0, d1;
+  bindSymbols(b.getContext(), s0);
+  bindDims(b.getContext(), d0, d1);
+  AffineExpr numTilesExpr = (d0 - d1).ceilDiv(s0);
+  // Apply this expression to each of the tiled loops.
+  return llvm::to_vector(
+      llvm::map_range(llvm::enumerate(tiledLoopRanges), [&](auto it) -> Value {
+        return makeComposedAffineApply(b, loc, numTilesExpr,
+                                       {it.value().size, it.value().offset,
+                                        tiledLoopTileSizes[it.index()]});
+      }));
+}
+
+/// Gather `values[indices]` and store them in `dest` in the order given by
+/// `indices`.
+template <typename T>
+static void gather(ArrayRef<T> values, ArrayRef<unsigned> indices,
+                   SmallVector<T> &dest) {
+  dest.resize(indices.size());
+  unsigned destIdx = 0;
+  for (auto idx : indices)
+    dest[destIdx++] = values[idx];
+}
+
+/// For each value `v` at position `i` in `values`, store `v` into
+/// `dest[indices[i]]`.
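+/// For example, `gather(values, /*indices=*/{0, 2}, dest)` copies elements 0
+/// and 2 of `values` into `dest`, and `scatter(dest, /*indices=*/{0, 2},
+/// values)` writes them back to positions 0 and 2.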
+template <typename SrcContainerTy, typename DstContainerTy>
+static void scatter(SrcContainerTy &&values, ArrayRef<unsigned> indices,
+                    DstContainerTy &dest) {
+  unsigned srcIdx = 0;
+  for (auto val : values)
+    dest[indices[srcIdx++]] = val;
+}
+
+/// Generate a single `scf.foreach_thread` operation that represents the tiled
+/// loop nest. In `tileOffsets` and `tileSizes`, return the multi-dimensional
+/// offset and size of the tile processed within the innermost loop. Upon
+/// returning, the insertion point of `builder` is positioned within the loop
+/// body, just before the terminator.
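+/// For example, with two tiled loops the generated IR roughly has the shape
+/// (see the FOREACH checks in the tests below):
+///   %nt0 = affine.apply #map0()[%ub0]
+///   %nt1 = affine.apply #map1()[%ub1]
+///   scf.foreach_thread (%iv0, %iv1) in (%nt0, %nt1) {
+///     ...
+///   }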
+static scf::ForeachThreadOp generateForeachLoopNest(
+    OpBuilder &builder, Location loc, TypeRange resultTypes,
+    ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
+    ArrayRef<unsigned> tiledLoops, ArrayRef<int64_t> threadDimMapping,
+    SmallVector<Value> &tileOffsets, SmallVector<Value> &tileSizes) {
+  size_t nTiledLoops = tiledLoops.size();
+  assert(!loopRanges.empty() && "expected at least one loop range");
+  assert(loopRanges.size() == tileSizeVals.size() &&
+         "expected as many tile sizes as loop ranges");
+  assert(nTiledLoops > 0 && "expected at least one tiled loop");
+
+  // Initialize the outputs using the loop range parameters.
+  tileOffsets.resize(loopRanges.size());
+  tileSizes.resize(loopRanges.size());
+  llvm::copy(llvm::map_range(loopRanges, [](Range r) { return r.offset; }),
+             tileOffsets.begin());
+  llvm::copy(llvm::map_range(loopRanges, [](Range r) { return r.size; }),
+             tileSizes.begin());
+
+  // Select the information corresponding to the tiled loops.
+  SmallVector<Range> tiledLoopRanges;
+  gather(loopRanges, tiledLoops, /*dest=*/tiledLoopRanges);
+  SmallVector<Value> tiledLoopTileSizes;
+  gather(tileSizeVals, tiledLoops, /*dest=*/tiledLoopTileSizes);
+
+  // Create a single scf.foreach_thread operation for all tiled loops.
+  // "numThreads" here actually means "number of subsets required".
+  auto loop = builder.create<scf::ForeachThreadOp>(
+      loc, resultTypes, /*numThreads=*/
+      getNumTiles(builder, loc, tiledLoopRanges, tiledLoopTileSizes),
+      threadDimMapping);
+
+  // Inside the loop, create new variables for lb, ub, and step.
+  builder.setInsertionPointToStart(loop.getBody());
+  auto iterVals = loop.getBody()->getArguments();
+
+  // Pack the induction variables, tiled-loop lower bounds, tile sizes, and
+  // tiled-loop upper bounds into a single operand vector.
+  SmallVector<Value> vars(iterVals.begin(), iterVals.end());
+  vars.reserve(nTiledLoops * 4);
+  auto tiledLoopRangeOffsets =
+      llvm::map_range(tiledLoopRanges, [](Range r) { return r.offset; });
+  vars.append(tiledLoopRangeOffsets.begin(), tiledLoopRangeOffsets.end());
+  vars.append(tiledLoopTileSizes.begin(), tiledLoopTileSizes.end());
+  auto tiledLoopRangeSizes =
+      llvm::map_range(tiledLoopRanges, [](Range r) -> Value { return r.size; });
+  vars.append(tiledLoopRangeSizes.begin(), tiledLoopRangeSizes.end());
+  AffineMap tileOffsetMap = getTileOffsetsMap(builder, nTiledLoops);
+  SmallVector<Value> tileLb =
+      applyMapToValues(builder, loc, tileOffsetMap,
+                       makeArrayRef(vars).slice(0, nTiledLoops * 3));
+  auto tileSizeBounds = getTileSizeBounds(
+      builder, loc, nTiledLoops,
+      /*tileSizes=*/makeArrayRef(vars).slice(nTiledLoops * 2, nTiledLoops),
+      /*tileOffsets=*/tileLb,
+      /*ubs=*/makeArrayRef(vars).slice(nTiledLoops * 3, nTiledLoops));
+  scatter(tileLb, tiledLoops, tileOffsets);
+  scatter(tileSizeBounds, tiledLoops, tileSizes);
+
+  return loop;
+}
+
+/// Tile the parallel dimensions of `op` using `scf.foreach_thread`.
+static FailureOr<scf::SCFTileForeachResult>
+tileWithForEach(PatternRewriter &rewriter, TilingInterface op,
+                ArrayRef<Value> tileSizes, ArrayRef<Range> loopRanges,
+                ArrayRef<unsigned> tiledLoops,
+                ArrayRef<int64_t> threadDimMapping) {
+  Location loc = op->getLoc();
+  scf::SCFTileForeachResult result;
+  OpBuilder::InsertionGuard g(rewriter);
+  if (loopRanges.empty() || tiledLoops.empty())
+    return failure();
+
+  SmallVector<Value> tileOffsets;
+  SmallVector<Value> adjustedTileSizes;
+  scf::ForeachThreadOp foreachThreadOp = generateForeachLoopNest(
+      rewriter, loc, op->getResultTypes(), loopRanges, tileSizes, tiledLoops,
+      threadDimMapping, tileOffsets, adjustedTileSizes);
+  result.loops.push_back(foreachThreadOp);
+
+  // We should now be inside the scf.foreach_thread body.
+  SmallVector<Value> destOperands = op.getDestinationOperands(rewriter);
+  SmallVector<Operation *> tiledOps = op.getTiledImplementation(
+      rewriter, destOperands, tileOffsets, adjustedTileSizes,
+      /*tileDestOperands=*/true);
+  if (tiledOps.size() != 1)
+    return rewriter.notifyMatchFailure(
+        op, "expected tiled implementation to return a single op");
+  result.tiledOp = tiledOps.front();
+
+  // Populate the terminator.
+  TilingInterface tiledOp = dyn_cast<TilingInterface>(result.tiledOp);
+  SmallVector<Value> tiledDestOperands =
+      tiledOp.getDestinationOperands(rewriter);
+  rewriter.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
+  for (const auto &it : llvm::enumerate(tiledDestOperands)) {
+    Operation *definingOp = it.value().getDefiningOp();
+    if (auto subsetExtractOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
+      rewriter.create<tensor::ParallelInsertSliceOp>(
+          loc, tiledOp->getResult(it.index()), destOperands[it.index()],
+          subsetExtractOp.getMixedOffsets(), subsetExtractOp.getMixedSizes(),
+          subsetExtractOp.getMixedStrides());
+    }
+    // Other operations, e.g. `memref.subview`, do not need operations in the
+    // terminator.
+  }
+  return result;
+}
+
+scf::TileUsingSCFForeachOp::TileUsingSCFForeachOp(MLIRContext *context,
+                                                  scf::SCFTilingOptions options,
+                                                  PatternBenefit benefit)
+    : OpInterfaceRewritePattern(context, benefit),
+      options(std::move(options)) {}
+
+scf::TileUsingSCFForeachOp::TileUsingSCFForeachOp(StringRef opName,
+                                                  MLIRContext *context,
+                                                  scf::SCFTilingOptions options,
+                                                  PatternBenefit benefit)
+    : OpInterfaceRewritePattern(context, benefit),
+      options(std::move(options)) {}
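+
+// Sketch of a typical registration of this pattern (illustrative only; the
+// test pass wraps the pattern in a filter-checking subclass instead):
+//
+//   scf::SCFTilingOptions options;
+//   options.setTileSizes({10, 20});
+//   RewritePatternSet patterns(context);
+//   patterns.add<scf::TileUsingSCFForeachOp>(context, options);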
+
+FailureOr<scf::SCFTileForeachResult>
+scf::TileUsingSCFForeachOp::returningMatchAndRewrite(
+    TilingInterface op, PatternRewriter &rewriter) const {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(op);
+
+  if (!options.tileSizeComputationFunction)
+    return rewriter.notifyMatchFailure(
+        op, "missing tile size computation function");
+
+  // 1. Get the range of the loops that are represented by the operation.
+  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
+  size_t numLoops = iterationDomain.size();
+  if (numLoops == 0)
+    return rewriter.notifyMatchFailure(
+        op, "unable to tile op with no iteration domain");
+
+  // 2. Materialize the tile sizes. Enforce the convention that "tiling by
+  // zero" skips tiling a particular dimension. This convention is
+  // significantly simpler to handle than adjusting affine maps to account
+  // for missing dimensions.
+  SmallVector<Value> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+  if (tileSizeVector.size() < iterationDomain.size())
+    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+  // Only parallel dimensions are tiled; force a zero tile size for all other
+  // iterator types.
+  for (const auto &it : llvm::enumerate(op.getLoopIteratorTypes())) {
+    if (it.value() != getParallelIteratorTypeName())
+      tileSizeVector[it.index()] = zero;
+  }
+
+  // Create the list of tiled loop indices. If no loops are tiled, do nothing.
+  SmallVector<unsigned> tiledLoops;
+  for (unsigned i = 0; i < iterationDomain.size(); i++) {
+    if (!matchPattern(tileSizeVector[i], m_Zero()))
+      tiledLoops.push_back(i);
+  }
+  if (tiledLoops.empty())
+    return failure();
+
+  FailureOr<scf::SCFTileForeachResult> tilingResult =
+      tileWithForEach(rewriter, op, tileSizeVector, iterationDomain,
+                      tiledLoops, /*threadDimMapping=*/{});
+  if (failed(tilingResult))
+    return failure();
+
+  // 3. If the original operation has results, replace them with the results
+  // of the loop. If there are no loops, the tiled op's results are the
+  // replacements.
+  mlir::Operation::result_range replacements =
+      tilingResult->loops.empty() ? tilingResult->tiledOp->getResults()
+                                  : tilingResult->loops.front()->getResults();
+  if (op->getNumResults() > 0)
+    rewriter.replaceOp(op, replacements);
+  else
+    rewriter.eraseOp(op);
+
+  return tilingResult;
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-foreach -split-input-file %s | FileCheck %s --check-prefix=FOREACH
 
 func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -41,6 +42,44 @@
 // CHECK: scf.yield %[[INNER]]
 // CHECK: return %[[OUTER]]
 
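+// The FOREACH checks below correspond to the {10, 20} tile sizes that the
+// test pass registers for "simple_gemm": MAP0/MAP1 compute the number of
+// tiles per dimension, MAP2/MAP3 compute the tile offsets from the loop
+// induction variables, and MAP4/MAP5 clamp each tile size at the bound of
+// the iteration domain.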
+// FOREACH-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// FOREACH-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// FOREACH-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 10)>
+// FOREACH-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * 20)>
+// FOREACH-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// FOREACH-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// FOREACH: func.func @simple_matmul(
+// FOREACH-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// FOREACH-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// FOREACH-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// FOREACH-DAG: %[[C0:.+]] = arith.constant 0 : index
+// FOREACH-DAG: %[[C1:.+]] = arith.constant 1 : index
+// FOREACH-DAG: %[[C10:.+]] = arith.constant 10 : index
+// FOREACH-DAG: %[[C20:.+]] = arith.constant 20 : index
+// FOREACH-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// FOREACH-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// FOREACH-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// FOREACH: %[[NUM_TILES0:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+// FOREACH: %[[NUM_TILES1:.+]] = affine.apply #[[MAP1]]()[%[[N]]]
+// FOREACH: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NUM_TILES0]], %[[NUM_TILES1]])
+// FOREACH: %[[LB0:.+]] = affine.apply #[[MAP2]](%[[IV0]])
+// FOREACH: %[[LB1:.+]] = affine.apply #[[MAP3]](%[[IV1]])
+// FOREACH: %[[TSIZE0:.+]] = affine.min #[[MAP4]](%[[LB0]])[%[[C10]], %[[M]]]
+// FOREACH: %[[TSIZE1:.+]] = affine.min #[[MAP5]](%[[LB1]])[%[[C20]], %[[N]]]
+// FOREACH: %[[LB0:.+]] = affine.apply #[[MAP2]](%[[IV0]])
+// FOREACH: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[LB0]], 0] [%[[TSIZE0]], %[[K]]]
+// FOREACH: %[[LB1:.+]] = affine.apply #[[MAP3]](%[[IV1]])
+// FOREACH: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[LB1]]] [%[[K]], %[[TSIZE1]]]
+// FOREACH-DAG: %[[LB0:.+]] = affine.apply #[[MAP2]](%[[IV0]])
+// FOREACH-DAG: %[[LB1:.+]] = affine.apply #[[MAP3]](%[[IV1]])
+// FOREACH: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[LB0]], %[[LB1]]] [%[[TSIZE0]], %[[TSIZE1]]]
+// FOREACH: %[[GEMM_TILE:.+]] = linalg.matmul
+// FOREACH-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// FOREACH-SAME: outs(%[[INIT_TILE]] :
+// FOREACH: scf.{{.+}}.perform_concurrently {
+// FOREACH-NEXT: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[ARG2]]
+// FOREACH-SAME: [%[[LB0]], %[[LB1]]] [%[[TSIZE0]], %[[TSIZE1]]]
+
 // -----
 
 func.func @simple_matmul_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
@@ -81,6 +120,12 @@
 // CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
 // CHECK-SAME: outs(%[[OUT_TILE]] :
 
+// FOREACH-LABEL: func.func @simple_matmul_memref
+// FOREACH: %[[NUM_TILES0:.+]] = affine.apply
+// FOREACH: %[[NUM_TILES1:.+]] = affine.apply
+// FOREACH: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NUM_TILES0]], %[[NUM_TILES1]])
+// FOREACH-NOT: {{perform_concurrently}}
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@@ -134,6 +179,45 @@
 // CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1
 // CHECK: return %[[OUTER]]#0, %[[OUTER]]#1
 
+// FOREACH-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 10)>
+// FOREACH-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 20)>
+// FOREACH-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// FOREACH-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// FOREACH: func.func @multi_result
+// FOREACH-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
+// FOREACH-DAG: %[[NUM_TILE0:.+]] = arith.constant 13 : index
+// FOREACH-DAG: %[[NUM_TILE1:.+]] = arith.constant 15 : index
+// FOREACH-DAG: %[[C10:.+]] = arith.constant 10 : index
+// FOREACH-DAG: %[[C20:.+]] = arith.constant 20 : index
+// FOREACH-DAG: %[[C128:.+]] = arith.constant 128 : index
+// FOREACH-DAG: %[[C300:.+]] = arith.constant 300 : index
+// FOREACH-DAG: %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200]
+// FOREACH-DAG: %[[INIT1:.+]] = linalg.init_tensor [300, 128, 200]
+// FOREACH: %[[RESULT:.+]]:2 = scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NUM_TILE0]], %[[NUM_TILE1]])
+// FOREACH-DAG: %[[LB0:.+]] = affine.apply #[[MAP0]](%[[IV0]])
+// FOREACH-DAG: %[[LB1:.+]] = affine.apply #[[MAP1]](%[[IV1]])
+// FOREACH: %[[TS0:.+]] = affine.min #[[MAP2]](%[[LB0]])[%[[C10]], %[[C128]]]
+// FOREACH: %[[TS1:.+]] = affine.min #[[MAP3]](%[[LB1]])[%[[C20]], %[[C300]]]
+// FOREACH-DAG: %[[LB0:.+]] = affine.apply #[[MAP0]](%[[IV0]])
+// FOREACH-DAG: %[[LB1:.+]] = affine.apply #[[MAP1]](%[[IV1]])
+// FOREACH: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// FOREACH-SAME: [%[[LB0]], 0, %[[LB1]]] [%[[TS0]], 200, %[[TS1]]]
+// FOREACH-DAG: %[[LB00:.+]] = affine.apply #[[MAP0]]
+// FOREACH-DAG: %[[LB01:.+]] = affine.apply #[[MAP1]]
+// FOREACH: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]]
+// FOREACH-SAME: [%[[LB00]], %[[LB01]], 0] [%[[TS0]], %[[TS1]], 200] [1, 1, 1]
+// FOREACH-DAG: %[[LB11:.+]] = affine.apply #[[MAP1]]
+// FOREACH-DAG: %[[LB10:.+]] = affine.apply #[[MAP0]]
+// FOREACH: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]]
+// FOREACH-SAME: [%[[LB11]], %[[LB10]], 0] [%[[TS1]], %[[TS0]], 200] [1, 1, 1]
+// FOREACH: %[[RESULT_TILE:.+]]:2 = linalg.generic
+// FOREACH-SAME: ins(%[[ARG_TILE]] :
+// FOREACH-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+// FOREACH: perform_concurrently
+// FOREACH-NEXT: parallel_insert_slice %[[RESULT_TILE]]#0 into %[[INIT0]][%[[LB00]], %[[LB01]], 0] [%[[TS0]], %[[TS1]], 200]
+// FOREACH-NEXT: parallel_insert_slice %[[RESULT_TILE]]#1 into %[[INIT1]][%[[LB11]], %[[LB10]], 0] [%[[TS1]], %[[TS0]], 200]
+// FOREACH: return %[[RESULT]]#0, %[[RESULT]]#1
+
 // -----
 
 func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
     %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
@@ -192,3 +276,59 @@
 // CHECK-SAME: outs(%[[INIT_TILE]] :
 // CHECK: tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]]
 // CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
+
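+// With tile sizes {10, 20, 30, 0, 0, 0, 0}, only the parallel N, R (output
+// height), and S (output width) dimensions of the convolution are tiled; the
+// remaining loops get a zero tile size and stay untiled. The input tile
+// bounds (MAP9-MAP12) account for the strides of the convolution.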
+// FOREACH-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// FOREACH-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// FOREACH-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 ceildiv 30)>
+// FOREACH-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * 10)>
+// FOREACH-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 20)>
+// FOREACH-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 * 30)>
+// FOREACH-DAG: #[[MAP6:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// FOREACH-DAG: #[[MAP7:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// FOREACH-DAG: #[[MAP8:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+// FOREACH-DAG: #[[MAP9:.+]] = affine_map<(d0) -> (d0 * 40)>
+// FOREACH-DAG: #[[MAP10:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0 - 2)>
+// FOREACH-DAG: #[[MAP11:.+]] = affine_map<(d0) -> (d0 * 90)>
+// FOREACH-DAG: #[[MAP12:.+]] = affine_map<(d0)[s0] -> (d0 * 3 + s0 - 3)>
+// FOREACH: func.func @conv2D
+// FOREACH-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// FOREACH-SAME: %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// FOREACH-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// FOREACH-DAG: %[[C0:.+]] = arith.constant 0 : index
+// FOREACH-DAG: %[[C1:.+]] = arith.constant 1 : index
+// FOREACH-DAG: %[[C3:.+]] = arith.constant 3 : index
+// FOREACH-DAG: %[[C2:.+]] = arith.constant 2 : index
+// FOREACH-DAG: %[[C10:.+]] = arith.constant 10 : index
+// FOREACH-DAG: %[[C20:.+]] = arith.constant 20 : index
+// FOREACH-DAG: %[[C30:.+]] = arith.constant 30 : index
+// FOREACH-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
+// FOREACH-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
+// FOREACH-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
+// FOREACH-DAG: %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]]
+// FOREACH-DAG: %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
+// FOREACH-DAG: %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
+// FOREACH-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+// FOREACH-DAG: %[[NT0:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+// FOREACH-DAG: %[[NT1:.+]] = affine.apply #[[MAP1]]()[%[[R]]]
+// FOREACH-DAG: %[[NT2:.+]] = affine.apply #[[MAP2]]()[%[[S]]]
+// FOREACH: %[[RESULT:.+]] = scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (%[[NT0]], %[[NT1]], %[[NT2]])
+// FOREACH-DAG: %[[OFFT0:.+]] = affine.apply #[[MAP3]](%[[IV0]])
+// FOREACH-DAG: %[[OFFT1:.+]] = affine.apply #[[MAP4]](%[[IV1]])
+// FOREACH-DAG: %[[OFFT2:.+]] = affine.apply #[[MAP5]](%[[IV2]])
+// FOREACH: %[[TSIZE0:.+]] = affine.min #[[MAP6]](%[[OFFT0]])[%[[C10]], %[[N]]]
+// FOREACH: %[[TSIZE1:.+]] = affine.min #[[MAP7]](%[[OFFT1]])[%[[C20]], %[[R]]]
+// FOREACH: %[[TSIZE2:.+]] = affine.min #[[MAP8]](%[[OFFT2]])[%[[C30]], %[[S]]]
+// FOREACH: %[[LB0:.+]] = affine.apply #[[MAP3]](%[[IV0]])
+// FOREACH: %[[LB1:.+]] = affine.apply #[[MAP9]](%[[IV1]])
+// FOREACH: %[[UB1:.+]] = affine.apply #[[MAP10]](%[[TSIZE1]])[%[[P]]]
+// FOREACH: %[[LB2:.+]] = affine.apply #[[MAP11]](%[[IV2]])
+// FOREACH: %[[UB2:.+]] = affine.apply #[[MAP12]](%[[TSIZE2]])[%[[Q]]]
+// FOREACH: %[[LHS:.+]] = tensor.extract_slice %[[INPUT]][%[[LB0]], %[[LB1]], %[[LB2]], 0] [%[[TSIZE0]], %[[UB1]], %[[UB2]], %[[C]]]
+// FOREACH: %[[RHS:.+]] = tensor.extract_slice %[[FILTER]][0, 0, 0, 0] [%[[P]], %[[Q]], %[[C]], %[[F]]]
+// FOREACH: %[[LB0:.+]] = affine.apply #[[MAP3]](%[[IV0]])
+// FOREACH: %[[LB1:.+]] = affine.apply #[[MAP4]](%[[IV1]])
+// FOREACH: %[[LB2:.+]] = affine.apply #[[MAP5]](%[[IV2]])
+// FOREACH: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[LB0]], %[[LB1]], %[[LB2]], 0] [%[[TSIZE0]], %[[TSIZE1]], %[[TSIZE2]], %[[F]]]
+// FOREACH: %[[TILE:.+]] = linalg.conv_2d_nhwc_hwcf
+// FOREACH: perform_concurrently
+// FOREACH-NEXT: parallel_insert_slice %[[TILE]] into %[[INIT]][%[[LB0]], %[[LB1]], %[[LB2]], 0] [%[[TSIZE0]], %[[TSIZE1]], %[[TSIZE2]], %[[F]]]
\ No newline at end of file
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -108,6 +108,42 @@
   linalg::LinalgTransformationFilter filter;
 };
 
+/// Construct a generic pattern applied to all TilingInterface ops that
+/// satisfy `filter`.
+struct TestTileUsingSCFForEachOpWithFilter
    : public scf::TileUsingSCFForeachOp {
+  TestTileUsingSCFForEachOpWithFilter(
+      MLIRContext *context, scf::SCFTilingOptions options,
+      linalg::LinalgTransformationFilter filter =
+          linalg::LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : scf::TileUsingSCFForeachOp(context, options, benefit),
+        filter(filter) {}
+
+  /// Construct a generic pattern applied to `opName`.
+  TestTileUsingSCFForEachOpWithFilter(
+      StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+      linalg::LinalgTransformationFilter filter =
+          linalg::LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : scf::TileUsingSCFForeachOp(opName, context, options, benefit),
+        filter(filter) {}
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter.checkAndNotify(rewriter, op)))
+      return failure();
+
+    FailureOr<scf::SCFTileForeachResult> tilingResult =
+        returningMatchAndRewrite(op, rewriter);
+    if (failed(tilingResult))
+      return failure();
+    filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
+    return success();
+  }
+
+private:
+  linalg::LinalgTransformationFilter filter;
+};
+
 /// Test pass for testing the use of `TilingInterface`.
 struct TestTilingInterfacePass
     : public PassWrapper<TestTilingInterfacePass, OperationPass<>> {
@@ -126,6 +162,10 @@
     return "Test tiling using TilingInterface";
   }
 
+  Option<bool> useForeach{
+      *this, "tile-using-scf-foreach",
+      llvm::cl::desc("Use the scf.foreach_thread variant"),
+      llvm::cl::init(false)};
+
   Option<bool> testTiling{
       *this, "tile-using-scf-for",
      llvm::cl::desc(
@@ -183,6 +223,22 @@
         context, {10}, "gemm_fusion", patterns);
     return;
   }
+
+  if (useForeach) {
+    // 1. Tiling the M and N dims of `linalg.matmul` on tensors.
+    addPatternForTiling<TestTileUsingSCFForEachOpWithFilter>(
+        context, {10, 20}, "simple_gemm", patterns);
+    // 2. Tiling the M, N and K dims of `linalg.matmul` on buffers.
+    addPatternForTiling<TestTileUsingSCFForEachOpWithFilter>(
+        context, {10, 20, 30}, "simple_gemm_memref", patterns);
+    // 3. Tiling a 3D parallel generic op that implements a transpose.
+    addPatternForTiling<TestTileUsingSCFForEachOpWithFilter>(
+        context, {10, 0, 20}, "parallel_generic_transpose", patterns);
+    // 4. Tiling a 2D conv op. Here we tile the parallel dimensions, whereas
+    //    the `testTiling` case tiles the reduction dimensions.
+    addPatternForTiling<TestTileUsingSCFForEachOpWithFilter>(
+        context, {10, 20, 30, 0, 0, 0, 0}, "simple_conv", patterns);
+  }
 }
 
 void TestTilingInterfacePass::runOnOperation() {