diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/TilingInterface.h" namespace mlir { namespace linalg { diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -595,4 +595,126 @@ }]; } +def GetNumThreadsOp : Op, + DeclareOpInterfaceMethods]> { + + let description = [{ + Given a TilingInterface operation and a set of tile sizes that can either + be static index values or dynamic SSA values, compute the number of + threads required for tiling the TilingInterface op using + `scf.foreach_thread`. This is equivalent to the number of tiles required, + and is computed by emitting the IR for calculating + `ceilDiv(ub[i]-lb[i],tileSize[i])`, where the `i` stands for the iteration + space dimension index of the `TilingInterface` op. Therefore, the caller + should provide as many `tile_size` values as the rank of the target op's + iteration space. The returned `num_threads[i]` value will be zero if + `tile_sizes[i]` is zero. 
+ + Example: + ``` + %0 = pdl_match @match_matmul in %arg1 + %1:3 = transform.structured.get_num_threads %0 tile_sizes [10, 20, 0] + %3:2 = transform.structured.tile_to_foreach %0 num_threads [%1#0, %1#1, 0] tile_sizes [10, 20, 0] + ``` + }]; + + let arguments = (ins PDL_Operation:$target, + Variadic:$dynamic_tile_sizes, + DefaultValuedAttr:$static_tile_sizes); + + let results = (outs Variadic:$num_threads); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + /// Return the tile sizes as a vector of OpFoldResult's. + SmallVector getMixedTileSizes(); + }]; +} + + +def TileToForeachThreadOp : + Op, + DeclareOpInterfaceMethods, + AttrSizedOperandSegments]> { + let description = [{ + Tile a TilingInterface `op` to a tiled `scf.foreach_thread`. The number + of threads and tile sizes are directly specified by the caller and + can be static index values or dynamic SSA values. The caller must ensure + that these values are sufficient to tile the entire iteration space. The + caller can utilize the transform operations `structured.get_num_threads` + or `structured.get_num_tiles` to compute the number of threads from the + tile sizes or vice-versa (see below examples). + + The `tile_sizes` and `num_threads` arrays should be equal in length. + + In the case that `structured.get_num_threads` is used to determine + the value for `num_threads[i]`, it is guaranteed that the offset of each + tile is within the bounds of the iteration space of dimension `i`. However, + when `structured.get_tile_size` is used, additional checks must be inserted + to ensure that tile lower bounds do not exceed the iteration space bounds. + + If non-empty, the `threadDimMapping` is added as an attribute to the + resulting `scf.foreach_thread`. + + A tile size of `0` is used as a sentinel value to indicate that a dimension + should not be tiled. It is the user's responsibility to ensure that + `num_threads` and `tile_sizes` form a valid tiling specification (i.e. 
that + only parallel dimensions are tiled in the `linalg` case). + + ### Return modes: + This operation ignores non-TilingInterface ops and drops them in the return. + + If all the operations referred to by the `target` PDLOperation tile + properly, the transform succeeds. Otherwise the transform silently fails. + + The 2 returned handles point to only the subset of successfully produced + tiled operations, which can all be empty. + These 2 returned handles point to: + - the new `scf.foreach_thread` op + - the tiled TilingInterface op instance + + ### Example using `structured.get_num_threads`: + ``` + %0 = pdl_match @match_matmul in %arg1 + %1:3 = transform.structured.get_num_threads %0 tile_sizes [10, 20, 0] + %3:2 = transform.structured.tile_to_foreach %0 num_threads [%1#0, %1#1, 0] tile_sizes [10, 20, 0] + ``` + + ### Example using `structured.get_tile_size`: + ``` + %0 = pdl_match @match_matmul in %arg1 + %1:3 = transform.structured.get_tile_size %0 num_threads [10, 20] + %3:2 = transform.structured.tile_to_foreach %0 num_threads [10, 20, 0] tile_sizes [%1#0, %1#1, 0] + ``` + }]; + + let arguments = (ins PDL_Operation:$target, + Variadic:$dynamic_num_threads, + Variadic:$dynamic_tile_sizes, + DefaultValuedAttr:$static_tile_sizes, + DefaultValuedAttr:$static_num_threads, + DefaultValuedAttr:$thread_dim_mapping); + let results = (outs PDL_Operation:$foreach_thread_op, + PDL_Operation:$tiled_op); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + /// Returns the list of number of threads, which may be static (Attribute) or + /// dynamic (Value). + SmallVector getMixedNumThreads(); + + /// Determine the `inBounds` properties of each tiled dimension + /// based on inspecting the producers of the dynamic `num_threads` operands. + SmallVector deriveInBounds(); + + /// Returns the list of tile sizes, which may be static (Attribute) or + /// dynamic (Value). 
+ SmallVector getMixedTileSizes(); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -134,6 +134,32 @@ TileUsingSCFForOp tilingPattern; }; +/// Tile a TilingInterface `op` to a tiled `scf.foreach_thread`. The number +/// of threads and tile sizes are directly specified by the caller. The caller +/// must ensure that these values are sufficient to tile the entire iteration +/// space. The `tile_sizes` and `num_threads` arrays should be equal in length. +/// If non-empty, the `threadDimMapping` is added as an attribute to the +/// resulting `scf.foreach_thread`. +/// The `inBounds` array specifies whether we can assume that all tile offsets +/// for each iteration domain index will be within the iteration domain bounds. +/// By default, `inBounds` is assumed to be false for each dimension index. The +/// caller can set `inBounds` to be equal to true if `numThreads` is derived +/// from the iteration bounds of `op` as well as the `tileSizes` (rather than +/// `tileSize` being determined from the other two variables). +/// If `inBounds` is given, it should be of the same length as `tileSize`. +/// A tile size of `0` is used as a sentinel value to indicate that a +/// dimension is not tiled. It is the user's responsibility to +/// ensure that `numThreads` is a valid tiling specification (i.e. that only +/// parallel dimensions have non-zero tile sizes in the `linalg` case). 
+struct SCFTileForeachResult { + Operation *tiledOp; + scf::ForeachThreadOp loop; +}; +FailureOr tileUsingSCFForeachOp( + OpBuilder &b, TilingInterface op, ArrayRef tileSize, + ArrayRef numThreads, ArrayRef threadDimMapping = {}, + ArrayRef inBounds = {}); + } // namespace scf } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Parser/Parser.h" @@ -829,6 +830,365 @@ return DiagnosedSilenceableFailure(success()); } +//===----------------------------------------------------------------------===// +// GetNumThreadsOp +//===----------------------------------------------------------------------===// + +/// For the given transform operation `op` and all its `targets`, +/// fill the `producers` vector with all the producers of the `variadicOperand` +/// vector for each target. Verify that there are enough producers for each +/// position in the variadic operand vector and that each producer produces one +/// index result. 
+static Optional +getAndVerifyVariadicDynamicIndexValueSingleResultProducers( + Operation *op, ArrayRef targets, + Operation::operand_range variadicOperand, transform::TransformState &state, + SmallVector> &producers) { + + auto emitSilenceableError = [&]() -> DiagnosedSilenceableFailure { + Diagnostic diag(op->getLoc(), DiagnosticSeverity::Error); + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + }; + + for (Value producerHandle : variadicOperand) { + producers.push_back(state.getPayloadOps(producerHandle)); + if (producers.back().size() != targets.size()) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "expected as many dynamic value producing operations(" + << producers.back().size() << ") as target ops (" << targets.size() + << ")"; + diag.attachNote(producerHandle.getLoc()) << "for this handle"; + return diag; + } + + for (Operation *op : producers.back()) { + if (op->getNumResults() == 1 && + op->getResult(0).getType().isa()) + continue; + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "expected dynamic values to be produced by ops " + "with a single index-type result"; + diag.attachNote(op->getLoc()) << " producer op"; + diag.attachNote(producerHandle.getLoc()) << "for this handle"; + return diag; + } + } + return None; +} + +/// Given an array of OpFoldResult's, representing index Value's or attributes, +/// create an output array of equivalent Value's. For static integer attributes, +/// return the result of a `arith::ConstantIndexOp`, and pass through dynamic +/// Values unchanged. 
+static SmallVector +getValueOrConstant(OpBuilder &b, Location loc, ArrayRef mixed, + ArrayRef> dynamicValueProducers, + unsigned producerIndex) { + unsigned dynamicIdx = 0; + SmallVector result; + result.reserve(mixed.size()); + for (auto ofr : mixed) { + Optional staticValue = getConstantIntValue(ofr); + if (staticValue) { + result.push_back(b.create(loc, *staticValue)); + continue; + } + result.push_back( + dynamicValueProducers[dynamicIdx++][producerIndex]->getResult(0)); + } + return result; +} + +DiagnosedSilenceableFailure +transform::GetNumThreadsOp::apply(TransformResults &transformResults, + TransformState &state) { + + ArrayRef targets = state.getPayloadOps(getTarget()); + SmallVector> dynamicTileSizesProducers; + dynamicTileSizesProducers.reserve(getStaticTileSizes().size()); + + if (Optional diag = + getAndVerifyVariadicDynamicIndexValueSingleResultProducers( + *this, targets, getDynamicTileSizes(), state, + dynamicTileSizesProducers)) + return std::move(*diag); + + SmallVector numThreadsOps; + for (const auto &en : llvm::enumerate(targets)) { + auto target = dyn_cast(en.value()); + if (!target) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "only TilingInterface ops are supported"; + diag.attachNote(en.value()->getLoc()) << "target op"; + return diag; + } + + OpBuilder builder(target->getContext()); + builder.setInsertionPoint(target); + SmallVector ranges = target.getIterationDomain(builder); + + SmallVector tileSizes = + getValueOrConstant(builder, getLoc(), getMixedTileSizes(), + dynamicTileSizesProducers, en.index()); + + unsigned dynamicIdx = 0; + for (const auto &it : llvm::enumerate(getStaticTileSizes())) { + if (it.index() >= ranges.size()) + return DiagnosedSilenceableFailure::definiteFailure(); + int64_t staticTileSize = it.value().cast().getInt(); + if (staticTileSize == 0) { + numThreadsOps.push_back( + builder.create(target->getLoc(), 0)); + continue; + } + Value tileSize = + staticTileSize == ShapedType::kDynamicSize 
+ ? state.getPayloadOps(getDynamicTileSizes()[dynamicIdx++])[0] + ->getResult(0) + : builder + .create(target->getLoc(), + staticTileSize) + .getResult(); + + Value size = ranges[it.index()].size; + AffineExpr s0 = builder.getAffineSymbolExpr(0); + AffineExpr d0 = builder.getAffineDimExpr(0); + Operation *numThreadsOp = makeComposedAffineApply( + builder, en.value()->getLoc(), s0.ceilDiv(d0), {tileSize, size}); + numThreadsOps.push_back(numThreadsOp); + } + } + + for (const auto &en : llvm::enumerate(numThreadsOps)) + transformResults.set(getNumThreads()[en.index()].cast(), + en.value()); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::GetNumThreadsOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getTarget(), effects); + onlyReadsHandle(getDynamicTileSizes(), effects); + producesHandle(getResults(), effects); + modifiesPayload(effects); +} + +ParseResult transform::GetNumThreadsOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand target; + SmallVector dynamicTileSizes; + ArrayAttr staticTileSizes; + auto pdlOperationType = pdl::OperationType::get(parser.getContext()); + if (parser.parseOperand(target) || + parser.resolveOperand(target, pdlOperationType, result.operands) || + parser.parseKeyword("tile_sizes") || + parseOperandsOrIntegersSizesList(parser, dynamicTileSizes, + staticTileSizes) || + parser.resolveOperands(dynamicTileSizes, pdlOperationType, + result.operands) || + parser.parseOptionalAttrDict(result.attributes)) + return ParseResult::failure(); + + result.addAttribute(getStaticTileSizesAttrName(result.name), staticTileSizes); + result.addTypes(SmallVector(staticTileSizes.size(), pdlOperationType)); + return success(); +} + +void transform::GetNumThreadsOp::print(OpAsmPrinter &p) { + p << ' ' << getTarget(); + p << " " + << "tile_sizes"; + printOperandsOrIntegersSizesList(p, getOperation(), getDynamicTileSizes(), + getStaticTileSizes()); + 
p.printOptionalAttrDict((*this)->getAttrs(), {getStaticTileSizesAttrName()}); +} + +static SmallVector +getOpFoldResultVector(MLIRContext *ctx, Operation::operand_range dynamicValues, + ArrayAttr staticValues) { + SmallVector results; + results.reserve(staticValues.size()); + unsigned dynamicPos = 0; + Builder builder(ctx); + for (auto attr : staticValues) { + int64_t val = attr.cast().getInt(); + if (val == ShapedType::kDynamicSize) { + results.push_back(dynamicValues[dynamicPos++]); + continue; + } + results.push_back(builder.getIndexAttr(val)); + } + return results; +} + +SmallVector transform::GetNumThreadsOp::getMixedTileSizes() { + return getOpFoldResultVector(getContext(), getDynamicTileSizes(), + getStaticTileSizes()); +} + +//===----------------------------------------------------------------------===// +// TileToForeachThreadOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::TileToForeachThreadOp::apply(TransformResults &transformResults, + TransformState &state) { + + ArrayRef targets = state.getPayloadOps(getTarget()); + + // Populate the producers for the variadic dynamic tile size / num_threads + // arguments. + SmallVector> dynamicTileSizesProducers, + dynamicNumThreadsProducers; + dynamicTileSizesProducers.reserve(getStaticTileSizes().size()); + dynamicNumThreadsProducers.reserve(getStaticNumThreads().size()); + if (Optional diag = + getAndVerifyVariadicDynamicIndexValueSingleResultProducers( + *this, targets, getDynamicTileSizes(), state, + dynamicTileSizesProducers)) + return std::move(*diag); + + if (Optional diag = + getAndVerifyVariadicDynamicIndexValueSingleResultProducers( + *this, targets, getDynamicNumThreads(), state, + dynamicNumThreadsProducers)) + return std::move(*diag); + + // Determine which tile_size/num_threads indices are "in bounds". Here this + // means that num_threads is determined from tile_sizes instead of the other + // way around. 
+ SmallVector inBounds = deriveInBounds(); + + // For each target, perform the tiling transformation. + SmallVector tiled, foreachThreadOp; + for (auto &en : llvm::enumerate(targets)) { + auto tilingIfaceOp = dyn_cast(en.value()); + if (!tilingIfaceOp) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "only TilingInterface ops are supported"; + diag.attachNote(en.value()->getLoc()) << "target op"; + return diag; + } + + SimpleRewriter rewriter(tilingIfaceOp.getContext()); + rewriter.setInsertionPoint(tilingIfaceOp); + SmallVector tileSizes = + getValueOrConstant(rewriter, getLoc(), getMixedTileSizes(), + dynamicTileSizesProducers, en.index()); + SmallVector numThreads = + getValueOrConstant(rewriter, getLoc(), getMixedNumThreads(), + dynamicNumThreadsProducers, en.index()); + FailureOr result = scf::tileUsingSCFForeachOp( + rewriter, tilingIfaceOp, getAsOpFoldResult(tileSizes), + getAsOpFoldResult(numThreads), {}, ArrayRef(inBounds)); + if (failed(result)) + return DiagnosedSilenceableFailure::definiteFailure(); + + tiled.push_back(result->tiledOp); + foreachThreadOp.push_back(result->loop); + rewriter.replaceOp(tilingIfaceOp, result->loop->getResults()); + } + + transformResults.set(getTiledOp().cast(), tiled); + transformResults.set(getForeachThreadOp().cast(), foreachThreadOp); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::TileToForeachThreadOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getTarget(), effects); + onlyReadsHandle(getDynamicNumThreads(), effects); + onlyReadsHandle(getDynamicTileSizes(), effects); + producesHandle(getForeachThreadOp(), effects); + producesHandle(getTiledOp(), effects); + modifiesPayload(effects); +} + +ParseResult transform::TileToForeachThreadOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand target; + SmallVector dynamicNumThreads; + SmallVector dynamicTileSizes; + ArrayAttr staticNumThreads, staticTileSizes; + auto 
pdlOperationType = pdl::OperationType::get(parser.getContext()); + if (parser.parseOperand(target) || + parser.resolveOperand(target, pdlOperationType, result.operands) || + parser.parseKeyword("num_threads") || + parseOperandsOrIntegersSizesList(parser, dynamicNumThreads, + staticNumThreads) || + parser.parseKeyword("tile_sizes") || + parseOperandsOrIntegersSizesList(parser, dynamicTileSizes, + staticTileSizes) || + parser.resolveOperands(dynamicNumThreads, pdlOperationType, + result.operands) || + parser.resolveOperands(dynamicTileSizes, pdlOperationType, + result.operands) || + parser.parseOptionalAttrDict(result.attributes)) + return ParseResult::failure(); + + result.addAttribute(getStaticTileSizesAttrName(result.name), staticTileSizes); + result.addAttribute(getStaticNumThreadsAttrName(result.name), + staticNumThreads); + result.addTypes(SmallVector(2, pdlOperationType)); + + result.addAttribute(getOperandSegmentSizeAttr(), + parser.getBuilder().getI32VectorAttr( + {1, static_cast(dynamicNumThreads.size()), + static_cast(dynamicTileSizes.size())})); + return success(); +} + +void transform::TileToForeachThreadOp::print(OpAsmPrinter &p) { + p << ' ' << getTarget(); + p << " " + << "num_threads"; + printOperandsOrIntegersSizesList(p, getOperation(), getDynamicNumThreads(), + getStaticNumThreads()); + p << " " + << "tile_sizes"; + printOperandsOrIntegersSizesList(p, getOperation(), getDynamicTileSizes(), + getStaticTileSizes()); + p.printOptionalAttrDict((*this)->getAttrs(), + {getStaticNumThreadsAttrName(), + getStaticTileSizesAttrName(), + getOperandSegmentSizesAttrName()}); +} + +SmallVector +transform::TileToForeachThreadOp::getMixedNumThreads() { + return getOpFoldResultVector(getContext(), getDynamicNumThreads(), + getStaticNumThreads()); +} + +SmallVector +transform::TileToForeachThreadOp::getMixedTileSizes() { + return getOpFoldResultVector(getContext(), getDynamicTileSizes(), + getStaticTileSizes()); +} + +SmallVector 
transform::TileToForeachThreadOp::deriveInBounds() { + unsigned dynamicIdx = 0; + Operation::operand_range dynamicNumThreads = getDynamicNumThreads(); + SmallVector inBounds; + inBounds.reserve(getStaticNumThreads().size()); + for (const auto &numThreads : llvm::enumerate(getStaticNumThreads())) { + if (numThreads.value().cast().getInt() == + ShapedType::kDynamicSize) { + if (isa( + dynamicNumThreads[dynamicIdx++].getDefiningOp())) { + inBounds.push_back(true); + continue; + } + } + inBounds.push_back(false); + } + return inBounds; +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -14,9 +14,12 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/TilingInterface.h" @@ -403,3 +406,220 @@ } return tileAndFuseResult; } + +//===----------------------------------------------------------------------===// +// TileUsingSCFForEachOp +//===----------------------------------------------------------------------===// + +/// Construct an AffineMap such that result at position i contains the +/// AffineExpr for calculating the tileOffset via +/// `tileOffset = loopLB + tileIdx * tileSize`. +/// `tileSize` is expected to be a symbol. 
In the resulting AffineMap, the first +/// `numTiledLoops` dimension variables represent the `tileIdx`s, the second +/// `numTiledLoops` dimension variables represent the `loopLB`s. There should be +/// `numTiledLoops` symbol variables representing the tile sizes. +static AffineMap getTileOffsetsMap(OpBuilder &b, unsigned numTiledLoops) { + SmallVector tileOffsetExprs; + tileOffsetExprs.reserve(numTiledLoops); + auto getDim = [&](unsigned dimIdx) { return b.getAffineDimExpr(dimIdx); }; + for (unsigned i = 0; i < numTiledLoops; i++) { + auto s0 = b.getAffineSymbolExpr(i); + tileOffsetExprs.push_back(getDim(numTiledLoops + i) + getDim(i) * s0); + } + auto tileOffsetMaps = AffineMap::get( + /*dimCount=*/2 * numTiledLoops, /*symbolCount=*/numTiledLoops, + /*results=*/tileOffsetExprs, b.getContext()); + return tileOffsetMaps; +} + +/// Calculate `min(tileSize, ub - tileOffset)` for each tiled loop. +static SmallVector +getTileSizeBounds(OpBuilder &b, Location loc, unsigned numTiledLoops, + ArrayRef tileSizes, ArrayRef tileOffsets, + ArrayRef ubs, ArrayRef inBounds) { + // The tile size to use (to avoid out of bounds access) is minimum of + // `tileSize` and `ub - tileOffset`. + AffineExpr d0, s0, s1; + bindSymbols(b.getContext(), s0, s1); + bindDims(b.getContext(), d0); + AffineMap tileSizeMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); + AffineMap tileSizeMaxMap = + AffineMap::get(1, 0, {d0, b.getAffineConstantExpr(0)}, b.getContext()); + + return llvm::to_vector(llvm::map_range( + llvm::zip(tileOffsets, tileSizes, ubs, inBounds), [&](auto it) -> Value { + Value minResult = b.create( + loc, tileSizeMap, + ValueRange{std::get<0>(it), std::get<1>(it), std::get<2>(it)}); + if (std::get<3>(it)) + return minResult; + return b.create( + loc, tileSizeMaxMap, + ValueRange{getValueOrCreateConstantIndexOp(b, loc, minResult)}); + })); +} + +/// Gather `values[indices]` and store them in `dest` in the same order given by +/// `indices`. 
+template +void gather(ArrayRef values, ArrayRef indices, + SmallVector &dest) { + dest.resize(indices.size()); + unsigned destIdx = 0; + for (auto idx : indices) + dest[destIdx++] = values[idx]; +} + +/// For each value `v` at position `i` in `values`, store `v` into +/// `dest[indices[i]]`. +template +static void scatter(SrcContainerTy &&values, ArrayRef indices, + DstContainerTy &dest) { + unsigned srcIdx = 0; + for (auto val : values) + dest[indices[srcIdx++]] = val; +} + +/// Generate a single `scf.foreach_thread` operation that represents the tiled +/// loop nest. In `offsets` and `sizes`, return the multi-dimensional offset and +/// size of the tile processed within the inner most loop. Upon returning, the +/// insertion point for `builder` will be positioned within the loop body just +/// before the terminator. +static scf::ForeachThreadOp generateForeachLoopNest( + OpBuilder &builder, Location loc, TypeRange resultTypes, + ArrayRef loopRanges, ArrayRef tileSizes, + ArrayRef numThreads, ArrayRef tiledLoops, + ArrayRef threadDimMapping, ArrayRef inBounds, + SmallVector &tileOffsets, + SmallVector &adjustedTileSizes) { + + size_t nTiledLoops = tiledLoops.size(); + assert(!loopRanges.empty() && "expected at least one loop range"); + assert(loopRanges.size() == tileSizes.size() && + "expected as many tile sizes as loop ranges"); + assert(nTiledLoops > 0 && " expected ata least one tiled loop"); + assert(inBounds.empty() || + (inBounds.size() == tileSizes.size() && + " inBounds should be the same length as tileSizes")); + + // Initialize the outputs using the loop range parameters. 
+ tileOffsets.resize(loopRanges.size()); + adjustedTileSizes.resize(loopRanges.size()); + tileOffsets = llvm::to_vector(llvm::map_range( + loopRanges, [](Range r) { return getAsOpFoldResult(r.offset); })); + adjustedTileSizes = llvm::to_vector(llvm::map_range( + loopRanges, [](Range r) { return getAsOpFoldResult(r.size); })); + + // Select out information corresponding to the tiled loops. + SmallVector tiledLoopRanges; + gather(loopRanges, tiledLoops, tiledLoopRanges); + SmallVector tiledLoopTileSizes; + gather(tileSizes, tiledLoops, tiledLoopTileSizes); + SmallVector nonZeroNumThreads; + gather(numThreads, tiledLoops, nonZeroNumThreads); + SmallVector relevantInBoundsFlags(tileSizes.size(), false); + if (!inBounds.empty()) + gather(inBounds, tiledLoops, relevantInBoundsFlags); + + // Create a single scf.foreach_thread operation for all tiled loops. + // "numThreads" here actually means "number of subsets required". + auto loop = builder.create( + loc, resultTypes, nonZeroNumThreads, threadDimMapping); + + // Inside the loop, create new variables for lb, ub, and step. 
+ builder.setInsertionPointToStart(loop.getBody()); + auto iterVals = loop.getBody()->getArguments(); + + SmallVector vars(iterVals.begin(), iterVals.end()); + vars.reserve(nTiledLoops * 4); + auto tiledLoopRangeOffsets = + llvm::map_range(tiledLoopRanges, [](Range r) { return r.offset; }); + vars.append(tiledLoopRangeOffsets.begin(), tiledLoopRangeOffsets.end()); + vars.append(tiledLoopTileSizes.begin(), tiledLoopTileSizes.end()); + auto tiledLoopRangeSizes = + llvm::map_range(tiledLoopRanges, [](Range r) -> Value { return r.size; }); + vars.append(tiledLoopRangeSizes.begin(), tiledLoopRangeSizes.end()); + AffineMap tileOffsetMap = getTileOffsetsMap(builder, nTiledLoops); + SmallVector tileLb = + applyMapToValues(builder, loc, tileOffsetMap, + makeArrayRef(vars).slice(0, nTiledLoops * 3)); + auto tileSizeBounds = getTileSizeBounds( + builder, loc, nTiledLoops, + /*tileSizes=*/makeArrayRef(vars).slice(nTiledLoops * 2, nTiledLoops), + /*tileOffsets=*/tileLb, + /*ubs=*/ + makeArrayRef(vars).slice(nTiledLoops * 3, nTiledLoops), + relevantInBoundsFlags); + scatter(tileLb, tiledLoops, tileOffsets); + scatter(tileSizeBounds, tiledLoops, adjustedTileSizes); + + return loop; +} + +/// Tile `op`'s parallel dimensions using SCF foreach. +FailureOr mlir::scf::tileUsingSCFForeachOp( + OpBuilder &builder, TilingInterface op, ArrayRef tileSizes, + ArrayRef numThreads, ArrayRef threadDimMapping, + ArrayRef inBounds) { + Location loc = op->getLoc(); + scf::SCFTileForeachResult result; + OpBuilder::InsertionGuard g(builder); + + if (tileSizes.size() != numThreads.size()) + return failure(); + + if (!inBounds.empty() && inBounds.size() != tileSizes.size()) + return failure(); + + // Get the range of the loops that are represented by the operation. + SmallVector iterationDomain = op.getIterationDomain(builder); + size_t numLoops = iterationDomain.size(); + if (numLoops == 0) + return failure(); + + // Create list of tiled indices. If no loops are tiled, do nothing. 
+ SmallVector tiledLoops; + for (unsigned i = 0; i < iterationDomain.size(); i++) { + if (!matchPattern( + getValueOrCreateConstantIndexOp(builder, loc, tileSizes[i]), + m_Zero())) + tiledLoops.push_back(i); + } + if (tiledLoops.empty()) + return failure(); + + SmallVector tileOffsets; + SmallVector adjustedTileSizes; + result.loop = generateForeachLoopNest( + builder, loc, op->getResultTypes(), iterationDomain, + getValueOrCreateConstantIndexOp(builder, loc, tileSizes), + getValueOrCreateConstantIndexOp(builder, loc, numThreads), tiledLoops, + threadDimMapping, inBounds, tileOffsets, adjustedTileSizes); + + // We should now be inside the scf::ForeachThreadOp body. + SmallVector destOperands = op.getDestinationOperands(builder); + SmallVector tiledOps = op.getTiledImplementation( + builder, destOperands, tileOffsets, adjustedTileSizes, + /*tileDestOperands=*/true); + if (tiledOps.size() != 1) + return failure(); + result.tiledOp = tiledOps.front(); + + // Populate the terminator. + TilingInterface tiledOp = dyn_cast(result.tiledOp); + SmallVector tiledDestOperands = + tiledOp.getDestinationOperands(builder); + builder.setInsertionPointToStart(result.loop.getTerminator().getBody()); + for (const auto &it : llvm::enumerate(tiledDestOperands)) { + Operation *definingOp = it.value().getDefiningOp(); + if (auto subsetExtractOp = dyn_cast(definingOp)) { + builder.create( + loc, tiledOp->getResult(it.index()), destOperands[it.index()], + subsetExtractOp.getMixedOffsets(), subsetExtractOp.getMixedSizes(), + subsetExtractOp.getMixedStrides()); + } + // Other operations, e.g. `memref.subview`, do not need operations in the + // terminator. 
+ } + return result; +} diff --git a/mlir/test/Dialect/Linalg/transform-ops-tile-to-foreach.mlir b/mlir/test/Dialect/Linalg/transform-ops-tile-to-foreach.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-ops-tile-to-foreach.mlir @@ -0,0 +1,299 @@ +// RUN: mlir-opt -test-transform-dialect-interpreter -canonicalize -split-input-file %s | FileCheck %s + +func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_matmul : benefit(1) { + %args = operands + %results = types + %op = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_matmul in %arg1 + %1:3 = transform.structured.get_num_threads %0 tile_sizes [10, 20, 0] + %3:2 = transform.structured.tile_to_foreach %0 num_threads [%1#0, %1#1, 0] tile_sizes [10, 20, 0] + } +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 10)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * 20)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)> +// CHECK: func.func @simple_matmul( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: 
%[[NUM_TILES0:.+]] = affine.apply #[[MAP0]]()[%[[M]]] +// CHECK: %[[NUM_TILES1:.+]] = affine.apply #[[MAP1]]()[%[[N]]] +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NUM_TILES0]], %[[NUM_TILES1]]) +// CHECK-DAG: %[[TSIZE0:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M]]] +// CHECK-DAG: %[[TSIZE1:.+]] = affine.min #[[MAP5]](%[[IV1]])[%[[N]]] +// CHECK-DAG: %[[LB0:.+]] = affine.apply #[[MAP2]](%[[IV0]]) +// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[LB0]], 0] [%[[TSIZE0]], %[[K]]] +// CHECK: %[[LB1:.+]] = affine.apply #[[MAP3]](%[[IV1]]) +// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[LB1]]] [%[[K]], %[[TSIZE1]]] +// CHECK: %[[LB0:.+]] = affine.apply #[[MAP2]](%[[IV0]]) +// CHECK: %[[LB1:.+]] = affine.apply #[[MAP3]](%[[IV1]]) +// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[LB0]], %[[LB1]]] [%[[TSIZE0]], %[[TSIZE1]]] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: scf.{{.+}}.perform_concurrently { +// CHECK-NEXT: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[ARG2]] +// CHECK-SAME: [%[[LB0]], %[[LB1]]] [%[[TSIZE0]], %[[TSIZE1]]] + +// ----- + +func.func @simple_matmul_memref(%arg0 : memref, %arg1 : memref, + %arg2 : memref) { + linalg.matmul {__internal_linalg_transform__ = "simple_gemm_memref"} + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_matmul : benefit(1) { + %args = operands + %results = types + %op = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_matmul in %arg1 + %1:3 =
transform.structured.get_num_threads %0 tile_sizes [10, 20, 0] + %3:2 = transform.structured.tile_to_foreach %0 num_threads [%1#0, %1#1, 0] tile_sizes [10, 20, 0] + } +} + +// CHECK-LABEL: func.func @simple_matmul_memref +// CHECK: %[[NUM_TILES0:.+]] = affine.apply +// CHECK: %[[NUM_TILES1:.+]] = affine.apply +// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NUM_TILES0]], %[[NUM_TILES1]]) +// CHECK-NOT: {{perform_concurrently}} + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> +func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) { + %init0 = linalg.init_tensor [128, 300, 200] : tensor<128x300x200xf32> + %init1 = linalg.init_tensor [300, 128, 200] : tensor<300x128x200xf32> + %0:2 = linalg.generic { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel"]} + {__internal_linalg_transform__ = "parallel_generic_transpose"} + ins(%arg0 : tensor<128x200x300xf32>) + outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + linalg.yield %b0, %b0 : f32, f32 + } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) + return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32> +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_matmul : benefit(1) { + %args = operands + %results = types + %op = operation "linalg.generic"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_matmul in %arg1 + %1:3 = transform.structured.get_num_threads %0 tile_sizes [10, 0, 20] + %3:2 = transform.structured.tile_to_foreach %0 num_threads [%1#0, 0, %1#2] tile_sizes [10, 0, 20] + } +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 10)> +// CHECK-DAG: 
#[[MAP1:.+]] = affine_map<(d0) -> (d0 * 20)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -10 + 128, 10)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -20 + 300, 20)> +// CHECK: func.func @multi_result +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>) +// CHECK-DAG: %[[NUM_TILE0:.+]] = arith.constant 13 : index +// CHECK-DAG: %[[NUM_TILE1:.+]] = arith.constant 15 : index +// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200] +// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [300, 128, 200] +// CHECK: %[[RESULT:.+]]:2 = scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NUM_TILE0]], %[[NUM_TILE1]]) +// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP2]](%[[IV0]]) +// CHECK-DAG: %[[TS1:.+]] = affine.min #[[MAP3]](%[[IV1]]) +// CHECK-DAG: %[[LB0:.+]] = affine.apply #[[MAP0]](%[[IV0]]) +// CHECK-DAG: %[[LB1:.+]] = affine.apply #[[MAP1]](%[[IV1]]) +// CHECK: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[LB0]], 0, %[[LB1]]] [%[[TS0]], 200, %[[TS1]]] +// CHECK-DAG: %[[LB00:.+]] = affine.apply #[[MAP0]](%[[IV0]]) +// CHECK-DAG: %[[LB01:.+]] = affine.apply #[[MAP1]](%[[IV1]]) +// CHECK: %[[INIT0_TILE:.+]] = linalg.init_tensor [%[[TS0]], %[[TS1]], 200] : +// CHECK-DAG: %[[LB11:.+]] = affine.apply #[[MAP1]](%[[IV1]]) +// CHECK-DAG: %[[LB10:.+]] = affine.apply #[[MAP0]](%[[IV0]]) +// CHECK: %[[INIT1_TILE:.+]] = linalg.init_tensor [%[[TS1]], %[[TS0]], 200] : +// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[ARG_TILE]] : +// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] : +// CHECK: perform_concurrently +// CHECK-NEXT: parallel_insert_slice %[[RESULT_TILE]]#0 into %[[INIT0]][%[[LB00]], %[[LB01]], 0] [%[[TS0]], %[[TS1]], 200] +// CHECK-NEXT: parallel_insert_slice %[[RESULT_TILE]]#1 into %[[INIT1]][%[[LB11]], %[[LB10]], 0] [%[[TS1]], %[[TS0]], 200] +// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 + +// ----- + +func.func @conv2D(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + 
%0 = linalg.conv_2d_nhwc_hwcf { + strides = dense<[2, 3]> : tensor<2xi64>, + dilation = dense<[4, 5]> : tensor<2xi64>, + __internal_linalg_transform__ = "simple_conv"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_matmul : benefit(1) { + %args = operands + %results = types + %op = operation "linalg.conv_2d_nhwc_hwcf"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_matmul in %arg1 + %1:7 = transform.structured.get_num_threads %0 tile_sizes [10, 20, 30, 0, 0, 0, 0] + %3:2 = transform.structured.tile_to_foreach %0 num_threads [%1#0, %1#1, %1#2, 0, 0, 0, 0] + tile_sizes [10, 20, 30, 0, 0, 0, 0] + } +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 ceildiv 30)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * 10)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 20)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 * 30)> +// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)> +// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)> +// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0)[s0] -> (d0 * -30 + s0, 30)> +// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0) -> (d0 * 40)> +// CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0 - 2)> +// CHECK-DAG: #[[MAP11:.+]] = affine_map<(d0) -> (d0 * 90)> +// CHECK-DAG: #[[MAP12:.+]] = affine_map<(d0)[s0] -> (d0 * 3 + s0 - 3)> +// CHECK: func.func @conv2D +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[FILTER:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 
1 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]] +// CHECK-DAG: %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]] +// CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]] +// CHECK: %[[NT0:.+]] = affine.apply #[[MAP0]]()[%[[N]]] +// CHECK: %[[NT1:.+]] = affine.apply #[[MAP1]]()[%[[R]]] +// CHECK: %[[NT2:.+]] = affine.apply #[[MAP2]]()[%[[S]]] +// CHECK-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]] +// CHECK-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]] +// CHECK-DAG: %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]] +// CHECK-DAG: %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]] +// CHECK-DAG: %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]] +// CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]] +// CHECK: %[[RESULT:.+]] = scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (%[[NT0]], %[[NT1]], %[[NT2]]) +// CHECK-DAG: %[[TSIZE0:.+]] = affine.min #[[MAP6]](%[[IV0]])[%[[N]]] +// CHECK-DAG: %[[TSIZE1:.+]] = affine.min #[[MAP7]](%[[IV1]])[%[[R]]] +// CHECK-DAG: %[[TSIZE2:.+]] = affine.min #[[MAP8]](%[[IV2]])[%[[S]]] +// CHECK-DAG: %[[LB0:.+]] = affine.apply #[[MAP3]](%[[IV0]]) +// CHECK-DAG: %[[LB3:.+]] = affine.apply #[[MAP9]](%[[IV1]]) +// CHECK-DAG: %[[LB4:.+]] = affine.apply #[[MAP11]](%[[IV2]]) +// CHECK-DAG: %[[TSIZE3:.+]] = affine.apply #[[MAP10]](%[[TSIZE1]])[%[[P]]] +// CHECK-DAG: %[[TSIZE4:.+]] = affine.apply #[[MAP12]](%[[TSIZE2]])[%[[Q]]] +// CHECK: %[[LHS:.+]] = tensor.extract_slice %[[INPUT]][%[[LB0]], %[[LB3]], %[[LB4]], 0] [%[[TSIZE0]], %[[TSIZE3]], %[[TSIZE4]], %[[C]]] [1, 1, 1, 1] : +// CHECK: %[[LHS:.+]] = tensor.extract_slice %[[FILTER]][0, 0, 0, 0] [%[[P]], %[[Q]], %[[C]], %[[F]]] +// CHECK-DAG: %[[LB0:.+]] = affine.apply #[[MAP3]](%[[IV0]]) +// CHECK-DAG: %[[LB1:.+]] = affine.apply #[[MAP4]](%[[IV1]]) +// CHECK-DAG: %[[LB2:.+]] = affine.apply #[[MAP5]](%[[IV2]]) +// CHECK: 
%[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[LB0]], %[[LB1]], %[[LB2]], 0] [%[[TSIZE0]], %[[TSIZE1]], %[[TSIZE2]], %[[F]]] +// CHECK: %[[TILE:.+]] = linalg.conv_2d_nhwc_hwcf +// CHECK: perform_concurrently +// CHECK-NEXT: parallel_insert_slice %[[TILE]] into %[[INIT]][%[[LB0]], %[[LB1]], %[[LB2]], 0] [%[[TSIZE0]], %[[TSIZE1]], %[[TSIZE2]], %[[F]]] + +// ----- + +// CHECK-DAG: #[[$MAP_ADD_I0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 10)> +// CHECK-DAG: #[[$MAP_ADD_I1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 20)> + +// CHECK-LABEL: @indexed_semantics +func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> tensor { + // Check that we correctly amend "linalg.index" results. + + // CHECK: scf.foreach_thread (%[[I0:.+]], %[[I1:.+]]) in (%{{.+}}, %{{.+}}) -> + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + {__internal_linalg_transform__ = "indexed_semantics"} + ins(%arg0: tensor) + outs(%arg1: tensor) { + ^bb0(%arg2: f32, %arg3: f32): + // CHECK: %[[INDEX0:.+]] = linalg.index 0 + // CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD_I0]](%[[INDEX0]], %[[I0]]) + %1 = linalg.index 0 : index + // CHECK: %[[INDEX1:.+]] = linalg.index 1 + // CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD_I1]](%[[INDEX1]], %[[I1]]) + %2 = linalg.index 1 : index + // CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]] + %3 = arith.addi %1, %2 : index + %4 = arith.index_cast %3 : index to i64 + %5 = arith.uitofp %4 : i64 to f32 + %6 = arith.addf %5, %arg2 : f32 + linalg.yield %6 : f32 + } -> (tensor) + return %0 : tensor +} + + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_matmul : benefit(1) { + %args = operands + %results = types + %op = operation "linalg.generic"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %op with "transform.dialect" + } + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = 
pdl_match @match_matmul in %arg1 + %1:2 = transform.structured.get_num_threads %0 tile_sizes [10, 20] + %3:2 = transform.structured.tile_to_foreach %0 num_threads [%1#0, %1#1] + tile_sizes [10, 20] + } +} \ No newline at end of file