diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -341,6 +341,90 @@ }]; } +def PackOp : Op,]> { + let description = [{ + Pack a LinalgOp by applying a data tiling transformation on the op and + packing the operands according to the `packed_sizes` specification. + + Iterator dimensions are tiled in their canonical order in the op spec. + Operands are packed according to the same canonical order of the op iterator + dimensions. + + Specifying a packed size of 0 for an iterator removes it from consideration + for packing. + + `tensor.pack` (resp. `tensor.unpack`) operations are inserted for the operands + (resp. results) that need to be packed (resp. unpacked) according to the + `packed_sizes` specification. + + #### Example + + Consider a `linalg.matmul` with indexing maps: + ``` + // M N K M K + // affine_map<(d0, d1, d2) -> (d0, d2)> + // K N + // affine_map<(d0, d1, d2) -> (d2, d1)> + // M N + // affine_map<(d0, d1, d2) -> (d0, d1)> + %0 = linalg.matmul ins(%A, %B: tensor, tensor) + outs( %C: tensor) + ``` + + Specifying packed_sizes [2, 3, 4] results in tiling the iterator dimensions + M, N and K, in this order, in both the op and its operands. + ``` + // M N K m n k M K m k + // affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> + // K N n k + // affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> + // M N m n + // affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> + %0 = linalg.generic_representing_some_higher_d_matmul + ins(%A, %B: tensor, tensor) + outs( %C: tensor) + ``` + In particular, note that the second operand `B` has shape `KxNxnxk` (and not + `KxNxkxn` as one could expect by looking **only** at the operand). 
+ + Other layouts can be obtained unsurprisingly from this canonical + transformation by composing the resulting operation with a (future) + `transform.structured.pack_transpose` op. + This composition allows separating concerns and composes better compared + to adding additional permutation attributes to this transform op. + + #### Return modes + + This operation applies to a single Linalg op, otherwise it fails. + This operation may produce a definiteFailure if the packing fails for any + reason. + + The returned handle points to the packed LinalgOp. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + Variadic:$packed_sizes, + DefaultValuedAttr:$static_packed_sizes); + let results = (outs TransformHandleTypeInterface:$packed_op); + let assemblyFormat = [{ + $target + `packed_sizes` `=` custom($packed_sizes, + $static_packed_sizes) + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + transform::TransformResults &transformResults, + transform::TransformState &state); + + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedPackedSizes(); + }]; +} + //===----------------------------------------------------------------------===// // PadOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/ScopeExit.h" @@ -35,6 +36,8 @@ using namespace mlir::transform; #define DEBUG_TYPE "linalg-transforms" +#define DBGS() 
(llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define DBGSNL() (llvm::dbgs() << "\n") /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` @@ -60,6 +63,67 @@ return cast(result->getOperation()); } +/// Append to `result`, for each entry of `ofrs`, either its IntegerAttr or the +/// single index result of the one payload op that its handle is mapped to. +static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations( + transform::TransformState &state, TransformOpInterface transformOp, + SmallVector &result, ArrayRef ofrs) { + for (OpFoldResult ofr : ofrs) { + if (ofr.is()) { + if (!ofr.get().isa()) + return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; + result.push_back(ofr); + continue; + } + ArrayRef payloadOps = state.getPayloadOps(ofr.get()); + if (payloadOps.size() != 1) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "handle must be mapped to exactly one payload op"; + diag.attachNote(ofr.get().getLoc()) + << "mapped to " << payloadOps.size() << " payload ops"; + return diag; + } + + Operation *op = payloadOps[0]; + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "payload op must have exactly 1 index result"; + diag.attachNote(op->getLoc()) + << "has " << op->getNumResults() << " results"; + return diag; + } + result.push_back(op->getResult(0)); + } + + return DiagnosedSilenceableFailure::success(); +} + +// Append to `result` the index result of every payload op mapped to +// `packedHandle`. In contrast with the list-based overload, the handle may be +// mapped to any number of payload ops. Each payload op must have exactly one +// result, of index type; otherwise a silenceable error is returned and +// `result` keeps the entries appended so far.
+static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations( + transform::TransformState &state, TransformOpInterface transformOp, + SmallVector &result, Value packedHandle) { + ArrayRef payloadOps = state.getPayloadOps(packedHandle); + for (Operation *op : payloadOps) { + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "payload op must have exactly 1 index result"; + diag.attachNote(op->getLoc()) + << "has " << op->getNumResults() << " results"; + return diag; + } + result.push_back(op->getResult(0)); + } + + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // DecomposeOp //===----------------------------------------------------------------------===// @@ -743,6 +807,334 @@ modifiesPayload(effects); } +//===---------------------------------------------------------------------===// +// PackOp +//===---------------------------------------------------------------------===// + +SmallVector transform::PackOp::getMixedPackedSizes() { + Builder b(getContext()); + return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b); +} + +/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim). +static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) { + bool found = false; + for (AffineExpr e : map.getResults()) { + if (!e.isFunctionOfDim(dim)) + continue; + if (found) + return false; + found = true; + } + return true; +} + +/// Return the index of the first result of `map` that is a function of +/// AffineDimExpr(dim), std::nullopt otherwise. 
+static std::optional getFirstResultIndexFunctionOf(AffineMap map, + int64_t dim) { + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + AffineExpr expr = map.getResult(i); + if (!expr.isFunctionOfDim(dim)) + continue; + return i; + } + return std::nullopt; +} + +/// Perform one step of packing of a LinalgOp's metadata along `dim` into the +/// `newDim` at `iteratorTypes.size()` by: +/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`. +/// 2. Appending a `newDim` to the domain of every indexing map. +/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing +/// by potentially adding a `newDim` result to `map`. +/// The preserved invariant is that `iteratorTypes.size()` is always equal to +/// `map.getNumDims()` for every map in `indexingMaps`. +/// +/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update. +/// Return a vector that records the optional packing for each operand. +/// Return failure if the packed indexing cannot be represented with a LinalgOp. +/// +/// Further details: +/// ================ +/// The current implementation of packing (i.e. data tiling) consists of +/// rewriting a linearized strip-mined form into a higher-dimensional access. +/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite +/// `I` into `4 * i + ii`, where `0 <= ii < 4`. +/// The access is further rewritten as `A[i][f(j, k, l)][ii]`. +/// +/// This rewrite into higher dimensional access is not possible for general +/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr: +/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we +/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`. +/// The rewrite of the access would be a form not representable in Linalg: +/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`. 
+/// Note however that as `J` and `ii` iterate, the accesses do not have a +/// particular alignment, so packing does not achieve alignment in this case. +/// +/// In the future, we may want to consider a mixed-form that allows some +/// alignment in the presence of multiple accesses: +/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]` +/// And would rewrite accesses as: +/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]` +static FailureOr>> +packLinalgMetadataOnce(SmallVectorImpl &indexingMaps, + SmallVectorImpl &iteratorTypes, + int64_t dim) { + int64_t newDim = iteratorTypes.size(); + iteratorTypes.push_back(iteratorTypes[dim]); + + SmallVector> packedDimPerIndexingMap( + indexingMaps.size(), std::nullopt); + SmallVector newMaps; + for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e; + ++operandIdx) { + AffineMap map = indexingMaps[operandIdx]; + + // Always add `newDim` to the domain of the map, packed or not. + assert(map.getNumDims() == newDim && "num dims invariant violation"); + map = map.shiftDims(1, newDim); + + // Get the at-most-1 index of the result that is a function of `dim`. + // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which + // logically chunks dimension `dim` into `K * dim + newDim`, where the + // packing factor `K` is specified separately. + assert(hasAtMostOneResultFunctionOfDim(map, dim) && + "num results invariant violation"); + auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim); + if (!maybeOperandDimensionToPack.has_value()) { + newMaps.push_back(map); + continue; + } + + // We can only pack AffineDimExpr atm. + if (!map.getResult(maybeOperandDimensionToPack.value()) + .isa()) + return failure(); + + // Add `newDim` to the results of the map. + map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim), + map.getNumResults()); + newMaps.push_back(map); + + // Record the packed result index for operand `operandIdx`.
+ packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack; + } + indexingMaps = newMaps; + + return packedDimPerIndexingMap; +} + +namespace { + +/// Helper struct to encode packing along one dimension of a LinalgOp. +struct PackedOperandsDim { + OpFoldResult packedSize; + SmallVector> packedDimForEachOperand; +}; + +/// Helper struct to encode packing along all dimensions of a LinalgOp. +struct PackedOperandsDimList { + void push_back(PackedOperandsDim &&packedOperandsDims) { + spec.emplace_back(packedOperandsDims); + } + /// Return all the dims that have been packed for operand @ `operandPos`. + SmallVector extractPackedDimsForOperand(int64_t operandPos); + /// Return all the pack sizes by which an operand @ `operandPos` is packed. + SmallVector extractPackSizesForOperand(int64_t operandPos); + +private: + SmallVector spec; +}; + +} // namespace + +SmallVector +PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) { + SmallVector res; + for (int64_t i = 0, e = spec.size(); i < e; ++i) { + if (!spec[i].packedDimForEachOperand[operandPos].has_value()) + continue; + res.push_back(spec[i].packedDimForEachOperand[operandPos].value()); + } + return res; +} + +SmallVector +PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) { + SmallVector res; + for (int64_t i = 0, e = spec.size(); i < e; ++i) { + if (!spec[i].packedDimForEachOperand[operandPos].has_value()) + continue; + res.push_back(spec[i].packedSize); + } + return res; +} + +/// Implement packing of a single LinalgOp by performing packing by +/// `packedSizeHandles`. There must be one packedSizeHandles entry per +/// `linalgOp` iterator. Return the packed Linalg op on success, failure +/// otherwise. 
+static FailureOr +packOneLinalgOp(RewriterBase &rewriter, transform::TransformState &state, + TransformOpInterface transformOp, linalg::LinalgOp linalgOp, + ArrayRef packedSizeHandles) { + assert(packedSizeHandles.size() == linalgOp.getNumLoops() && + "incorrect number of pack sizes"); + + Location loc = linalgOp->getLoc(); + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"; + llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL(); + llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); + DBGSNL();); + + // Unpack handles to constants or actual SSA index values. + SmallVector packedSizes; + DiagnosedSilenceableFailure status = unpackSingleIndexResultPDLOperations( + state, transformOp, packedSizes, packedSizeHandles); + + // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i]. + PackedOperandsDimList listOfPackedOperandsDim; + for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) { + std::optional maybeConstant = getConstantIntValue(packedSizes[i]); + // Skip tile sizes explicitly set to 0. 
+ if (maybeConstant.has_value() && maybeConstant.value() == 0) + continue; + + PackedOperandsDim packedOperandsDims; + packedOperandsDims.packedSize = packedSizes[i]; + FailureOr>> + maybePackedDimForEachOperand = + packLinalgMetadataOnce(indexingMaps, iteratorTypes, i); + if (failed(maybePackedDimForEachOperand)) + return failure(); + packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand; + listOfPackedOperandsDim.push_back(std::move(packedOperandsDims)); + + LLVM_DEBUG( + DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i] + << "\n"; + llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL(); + llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL(); + llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand, + DBGS() << "packedDimForEachOperand: "); + DBGSNL();); + } + + // Step 2. Propagate packing to all LinalgOp operands. + SmallVector inputsAndInits, results; + for (auto operandsList : + {linalgOp.getDpsInputOperands(), linalgOp.getDpsInitOperands()}) { + for (OpOperand *opOperandPtr : operandsList) { + int64_t pos = opOperandPtr->getOperandNumber(); + Value operand = opOperandPtr->get(); + SmallVector innerPos = + listOfPackedOperandsDim.extractPackedDimsForOperand(pos); + SmallVector innerPackSizes = + listOfPackedOperandsDim.extractPackSizesForOperand(pos); + LLVM_DEBUG( + DBGS() << "operand: " << operand << "\n"; + llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL(); + llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: "); + DBGSNL();); + if (innerPackSizes.empty()) { + inputsAndInits.push_back(operand); + continue; + } + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, operand, innerPackSizes, innerPos, + /*outerDimsPerm=*/{}); + // TODO: value of the padding attribute should be determined by consumers. 
+ Attribute zeroAttr = + rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + inputsAndInits.push_back(rewriter.create( + loc, operand, dest, innerPos, innerPackSizes, zero)); + } + } + + // Step 3. Build the packed op, use the type of `inits` as result types. + ValueRange inputs = + ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs()); + ValueRange inits = + ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits()); + auto packedLinalgOp = rewriter.create( + linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps, + iteratorTypes); + packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0)); + + // Step 4. Propagate packing to all the op results. + for (OpResult result : packedLinalgOp->getResults()) { + int64_t resultNum = result.getResultNumber(); + tensor::PackOp maybePackedInit = + inits[resultNum].getDefiningOp(); + if (!maybePackedInit) { + results.push_back(result); + continue; + } + // Build the symmetrical UnPackOp to the existing PackOp. + results.push_back(rewriter.create( + packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), + maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); + } + + // Step 5. Replace `linalgOp`. + rewriter.replaceOp(linalgOp, results); + + // Return packedLinalgOp. + return cast(packedLinalgOp.getOperation()); +} + +DiagnosedSilenceableFailure +transform::PackOp::apply(transform::TransformResults &transformResults, + transform::TransformState &state) { + ArrayRef targetOps = state.getPayloadOps(getTarget()); + // If nothing to pack, propagate success. + if (targetOps.empty()) { + transformResults.set(getPackedOp().cast(), {}); + return DiagnosedSilenceableFailure::success(); + } + // Fail on multi-op handles. + auto linalgOp = dyn_cast(targetOps.front()); + if (targetOps.size() != 1 || !linalgOp) { + // TODO: remove this unnecessary set to empty once crashes are fixed. 
+ transformResults.set(getPackedOp().cast(), {}); + return emitSilenceableError() + << "requires target to map to exactly 1 LinalgOp (got " + << targetOps.size() << ")"; + } + // Fail on mismatched number of pack sizes. + if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) { + // TODO: remove this unnecessary set to empty once crashes are fixed. + transformResults.set(getPackedOp().cast(), {}); + return emitSilenceableError() + << "requires number of packed sizes match the number of loops (" + << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops() + << ")"; + } + + IRRewriter rewriter(linalgOp->getContext()); + rewriter.setInsertionPoint(linalgOp); + FailureOr maybeResult = + packOneLinalgOp(rewriter, state, *this, linalgOp, getMixedPackedSizes()); + if (failed(maybeResult)) + return emitDefiniteFailure("data tiling failed"); + + transformResults.set(getPackedOp().cast(), + maybeResult->getOperation()); + return DiagnosedSilenceableFailure::success(); +} + +void transform::PackOp::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getTarget(), effects); + transform::onlyReadsHandle(getPackedSizes(), effects); + transform::producesHandle(getPackedOp(), effects); +} + //===---------------------------------------------------------------------===// // PadOp //===---------------------------------------------------------------------===// @@ -1608,68 +2000,6 @@ /*mapping=*/mapping); } -/// Assuming that `ofr` is an index attr or a transform dialect handle mapped -/// to exactly one op with one index result, return that value. 
-static DiagnosedSilenceableFailure unpackPDLOperations( - transform::TransformState &state, TransformOpInterface transformOp, - SmallVector &result, ArrayRef ofrs) { - for (OpFoldResult ofr : ofrs) { - if (ofr.is()) { - if (!ofr.get().isa()) - return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; - result.push_back(ofr); - continue; - } - ArrayRef payloadOps = state.getPayloadOps(ofr.get()); - if (payloadOps.size() != 1) { - DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() - << "handle must be mapped to exactly one payload op"; - diag.attachNote(ofr.get().getLoc()) - << "mapped to " << payloadOps.size() << " payload ops"; - return diag; - } - - Operation *op = payloadOps[0]; - if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { - DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() - << "payload op must have exactly 1 index result"; - diag.attachNote(op->getLoc()) - << "has " << op->getNumResults() << " results"; - return diag; - } - result.push_back(op->getResult(0)); - } - - return DiagnosedSilenceableFailure::success(); -} - -// Given a list of OpFoldResults that are either index attrs or op -// handles, return a list of OpFoldResults where all op handles are -// replaced with the first (and only) OpResult of that payload op. (There -// must be exactly one mapped payload op and it must have exactly one -// index result.) 
-static DiagnosedSilenceableFailure -unpackPDLOperations(transform::TransformState &state, - TransformOpInterface transformOp, - SmallVector &result, Value packedHandle) { - ArrayRef payloadOps = state.getPayloadOps(packedHandle); - for (Operation *op : payloadOps) { - if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { - DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() - << "payload op must have exactly 1 index result"; - diag.attachNote(op->getLoc()) - << "has " << op->getNumResults() << " results"; - return diag; - } - result.push_back(op->getResult(0)); - } - - return DiagnosedSilenceableFailure::success(); -} - DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef targets, @@ -1724,18 +2054,18 @@ SmallVector mixedNumThreads; DiagnosedSilenceableFailure status = getPackedNumThreads() - ? unpackPDLOperations(state, transformOp, mixedNumThreads, - getPackedNumThreads()) - : unpackPDLOperations(state, transformOp, mixedNumThreads, - getMixedNumThreads()); + ? unpackSingleIndexResultPDLOperations( + state, transformOp, mixedNumThreads, getPackedNumThreads()) + : unpackSingleIndexResultPDLOperations( + state, transformOp, mixedNumThreads, getMixedNumThreads()); if (!status.succeeded()) return status; SmallVector mixedTileSizes; status = getPackedTileSizes() - ? unpackPDLOperations(state, transformOp, mixedTileSizes, - getPackedTileSizes()) - : unpackPDLOperations(state, transformOp, mixedTileSizes, - getMixedTileSizes()); + ? 
unpackSingleIndexResultPDLOperations( + state, transformOp, mixedTileSizes, getPackedTileSizes()) + : unpackSingleIndexResultPDLOperations( + state, transformOp, mixedTileSizes, getMixedTileSizes()); if (!status.succeeded()) return status; diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir @@ -0,0 +1,406 @@ +// RUN: mlir-opt -test-transform-dialect-interpreter -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> +#reduction_2d_trait = { + indexing_maps = [#map, #map1], + iterator_types = ["parallel", "reduction"] +} + +// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d0)> + +// CHECK-LABEL: @reduction_2d_static +// CHECK-SAME: %[[T0:.+]]: tensor<3x7xf16>, +// CHECK-SAME: %[[T1:.+]]: tensor<3xf16> +func.func @reduction_2d_static(%t0: tensor<3x7xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> { + // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<3x2x4xf16> + // CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16) + // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] : tensor<3x7xf16> -> tensor<3x2x4xf16> + // CHECK-NOT: tensor.pack + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]] + // CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"] + // CHECK-SAME: ins(%{{.*}} : tensor<3x2x4xf16>) + // CHECK-SAME: outs(%{{.*}} : tensor<3xf16>) + %2 = linalg.generic #reduction_2d_trait ins(%t0 : tensor<3x7xf16>) outs(%t1 : tensor<3xf16>) { + ^bb0(%in: f16, %out: f16): + %3 = arith.addf %in, %out : f16 + linalg.yield %3 : f16 + } -> tensor<3xf16> + + // CHECK-NOT: tensor.unpack + return %2 : tensor<3xf16> +} + +transform.sequence 
failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + transform.structured.pack %0 packed_sizes = [0, 4] + : (!pdl.operation) -> (!transform.op<"linalg.generic">) +} + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d1)> +#col_reduction_2d_trait = { + indexing_maps = [#map, #map1], + iterator_types = ["reduction", "parallel"] +} + +// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @col_reduction_2d_static +// CHECK-SAME: %[[T0:.+]]: tensor<7x3xf16>, +// CHECK-SAME: %[[T1:.+]]: tensor<3xf16> +func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> { + // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf16> + // CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16) + // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<2x3x4xf16> + // CHECK-NOT: tensor.pack + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]] + // CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"] + // CHECK-SAME: ins(%{{.*}} : tensor<2x3x4xf16>) + // CHECK-SAME: outs(%{{.*}} : tensor<3xf16>) + %2 = linalg.generic #col_reduction_2d_trait ins(%t0 : tensor<7x3xf16>) outs(%t1 : tensor<3xf16>) { + ^bb0(%in: f16, %out: f16): + %3 = arith.addf %in, %out : f16 + linalg.yield %3 : f16 + } -> tensor<3xf16> + + // CHECK-NOT: tensor.unpack + return %2 : tensor<3xf16> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + transform.structured.pack %0 packed_sizes = [4, 0] + : (!pdl.operation) -> (!transform.op<"linalg.generic">) +} + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> 
+#reduction_2d_trait = { + indexing_maps = [#map, #map1], + iterator_types = ["parallel", "reduction"] +} + +// CHECK-DAG: #[[$DIV4:.*]] = affine_map<()[s0] -> (s0 ceildiv 4)> +// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d0)> + +// CHECK-LABEL: @reduction_2d_dynamic +// CHECK-SAME: %[[T0:.+]]: tensor, +// CHECK-SAME: %[[T1:.+]]: tensor +func.func @reduction_2d_dynamic(%t0: tensor, %t1: tensor) -> tensor { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor + // CHECK-DAG: %[[D1:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor + // CHECK: %[[D1B4:.*]] = affine.apply #[[$DIV4]]()[%[[D1]]] + // CHECK: %[[EMPTY:.*]] = tensor.empty(%[[D0]], %[[D1B4]]) : tensor + // CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16) + // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] : tensor -> tensor + // CHECK-NOT: tensor.pack + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]] + // CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"] + // CHECK-SAME: ins(%{{.*}} : tensor) + // CHECK-SAME: outs(%{{.*}} : tensor) + %2 = linalg.generic #reduction_2d_trait ins(%t0 : tensor) outs(%t1 : tensor) { + ^bb0(%in: f16, %out: f16): + %3 = arith.addf %in, %out : f16 + linalg.yield %3 : f16 + } -> tensor + + // CHECK-NOT: tensor.unpack + return %2 : tensor +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + transform.structured.pack %0 packed_sizes = [0, 4] + : (!pdl.operation) -> (!transform.op<"linalg.generic">) +} + + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> +#reduction_2d_trait = { + indexing_maps = [#map, #map1], + iterator_types = 
["parallel", "reduction"]
+}
+
+// NOTE(review): several tensor types below appear as a bare `tensor` with no
+// shape or element type; this looks like text lost in this copy of the patch.
+// Confirm against the upstream test before relying on these lines.
+
+// Packs the 2-D reduction generic with static sizes [3, 4]: both operands are
+// packed with a padding value, and the result is unpacked.
+// CHECK-DAG: #[[$DIV3:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
+// CHECK-DAG: #[[$DIV4:.*]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+
+// CHECK-LABEL: @reduction_2d_dynamic
+// CHECK-SAME: %[[T0:.+]]: tensor,
+// CHECK-SAME: %[[T1:.+]]: tensor
+func.func @reduction_2d_dynamic(%t0: tensor, %t1: tensor) -> tensor {
+  // CHECK: %[[PACKED_0:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16)
+  // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 4] into %{{.*}} : tensor -> tensor
+  // CHECK: %[[PACKED_1:.*]] = tensor.pack %[[T1]] padding_value(%{{.*}} : f16)
+  // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [3] into %{{.*}} : tensor -> tensor
+  // CHECK-NOT: tensor.pack
+  // CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
+  // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel", "reduction"]
+  // CHECK-SAME: ins(%{{.*}} : tensor)
+  // CHECK-SAME: outs(%{{.*}} : tensor)
+  %2 = linalg.generic #reduction_2d_trait ins(%t0 : tensor) outs(%t1 : tensor) {
+  ^bb0(%in: f16, %out: f16):
+    %3 = arith.addf %in, %out : f16
+    linalg.yield %3 : f16
+  } -> tensor
+
+  // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0] inner_tiles = [3] into %{{.*}} : tensor -> tensor
+  return %2 : tensor
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.pack %0 packed_sizes = [3, 4]
+    : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// Packs all three matmul iterators (M, N, K) with sizes [2, 3, 4]; the op
+// becomes a 6-D linalg.generic.
+// Iterators (M, N, K, m, n, k) -> operand dims (M, K, m, k)
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// Operand dims (K, N, n, k)
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+// Operand dims (M, N, m, n)
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: @matmul
+// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: tensor,
+// CHECK-SAME: %[[B:[0-9a-zA-Z]+]]: tensor,
+// CHECK-SAME: %[[C:[0-9a-zA-Z]+]]: tensor
+func.func @matmul(%A: tensor, %B: tensor, %C: tensor)
+    -> tensor {
+
+  // CHECK: %[[PACK_A:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 4]
+  // CHECK-SAME: : tensor -> tensor
+  // CHECK: %[[PACK_B:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [3, 4]
+  // CHECK-SAME: : tensor -> tensor
+  // CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
+  // CHECK-SAME: : tensor -> tensor
+
+  // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+  // CHECK-SAME: ins(%{{.*}} : tensor, tensor)
+  // CHECK-SAME: outs(%{{.*}} : tensor)
+  %0 = linalg.matmul ins(%A, %B: tensor, tensor)
+                     outs(%C: tensor)
+    -> tensor
+
+  // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
+  // CHECK-SAME: : tensor -> tensor
+  return %0 : tensor
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    // Iterators: M N K
+    %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// Packs only the F and C iterators (sizes 4 and 8) of a statically shaped
+// conv_2d_nchw_fchw; a packed size of 0 leaves an iterator unpacked.
+// Iterators: N F H W C KH KW f c
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d4, d2 + d5, d3 + d6, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d4, d5, d6, d7, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d7)>
+
+// CHECK-LABEL: @conv_2d_nchw_fchw
+// CHECK-SAME: %[[INPUT:.+]]: tensor<14x512x28x28xf32>,
+// CHECK-SAME: %[[FILTER:.+]]: tensor<1024x512x1x1xf32>
+// CHECK-SAME: %[[INIT:.+]]: tensor<14x1024x28x28xf32>
+func.func @conv_2d_nchw_fchw(%i: tensor<14x512x28x28xf32>, %f: tensor<1024x512x1x1xf32>,
+                             %o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32> {
+
+  // CHECK: %[[PACK_INPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1] inner_tiles = [8]
+  // CHECK-SAME: : tensor<14x512x28x28xf32> -> tensor<14x64x28x28x8xf32>
+  // CHECK: %[[PACK_FILTER:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+  // CHECK-SAME: : tensor<1024x512x1x1xf32> -> tensor<256x64x1x1x4x8xf32>
+  // CHECK: %[[PACK_OUTPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1] inner_tiles = [4]
+  // CHECK-SAME: : tensor<14x1024x28x28xf32> -> tensor<14x256x28x28x4xf32>
+  // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "reduction"]}
+  // CHECK-SAME: ins(%{{.*}} : tensor<14x64x28x28x8xf32>, tensor<256x64x1x1x4x8xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<14x256x28x28x4xf32>)
+  %0 = linalg.conv_2d_nchw_fchw ins(%i, %f: tensor<14x512x28x28xf32>, tensor<1024x512x1x1xf32>)
+                                outs(%o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32>
+
+  // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [1] inner_tiles = [4]
+  // CHECK-SAME: : tensor<14x256x28x28x4xf32> -> tensor<14x1024x28x28xf32>
+  return %0: tensor<14x1024x28x28xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match interface{LinalgOp} in %arg1
+  // Iterators: N F H W C KH KW
+  %1 = transform.structured.pack %0 packed_sizes = [0, 4, 0, 0, 8, 0, 0]
+    : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// Packs only the F and C iterators (sizes 4 and 6) of conv_2d_nhwc_hwcf.
+// Iterators: N H W F KH KW C f c
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d4, d2 + d5, d6, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d4, d5, d6, d3, d7, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d7)>
+
+// CHECK-LABEL: @conv_2d_nhwc_hwcf
+// CHECK-SAME: %[[INPUT:.+]]: tensor,
+// CHECK-SAME: %[[FILTER:.+]]: tensor<1x?x?x?xf32>
+// CHECK-SAME: %[[INIT:.+]]: tensor
+func.func @conv_2d_nhwc_hwcf(%input: tensor, %filter: tensor<1x?x?x?xf32>, %init: tensor) -> tensor {
+
+  // CHECK: %[[PACK_INPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [3] inner_tiles = [6]
+  // CHECK-SAME: : tensor -> tensor
+  // CHECK: %[[PACK_FILTER:.*]] = tensor.pack %{{.*}} inner_dims_pos = [3, 2] inner_tiles = [4, 6]
+  // CHECK-SAME: : tensor<1x?x?x?xf32> -> tensor<1x?x?x?x4x6xf32>
+  // CHECK: %[[PACK_OUTPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [3] inner_tiles = [4]
+  // CHECK-SAME: : tensor -> tensor
+
+  // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "reduction"]}
+  // CHECK-SAME: ins(%{{.*}} : tensor, tensor<1x?x?x?x4x6xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor)
+  %0 = linalg.conv_2d_nhwc_hwcf
+    ins (%input, %filter: tensor, tensor<1x?x?x?xf32>)
+    outs (%init: tensor) -> tensor
+
+  // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [3] inner_tiles = [4]
+  // CHECK-SAME: : tensor -> tensor
+  return %0 : tensor
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match interface{LinalgOp} in %arg1
+  // Iterators: N H W F KH KW C
+  %1 = transform.structured.pack %0 packed_sizes = [0, 0, 0, 4, 0, 0, 6]
+    : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// Packs N and K of a matmul with a tile size that is an SSA value (the result
+// of "some_tile_size") instead of a constant; M is left unpacked.
+// CHECK-DAG: affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+// Iterators (M, N, K, n, k) -> operand dims (M, K, k)
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
+// Operand dims (K, N, n, k)
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1, d3, d4)>
+// Operand dims (M, N, n)
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+
+// CHECK-LABEL: @matmul_dynamic_pack_size
+// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: tensor,
+// CHECK-SAME: %[[B:[0-9a-zA-Z]+]]: tensor,
+// CHECK-SAME: %[[C:[0-9a-zA-Z]+]]: tensor
+func.func @matmul_dynamic_pack_size(%A: tensor, %B: tensor, %C: tensor)
+    -> tensor {
+  // CHECK: %[[TS:.*]] = "some_tile_size"() : () -> index
+  %sz = "some_tile_size"() : () -> (index)
+
+  // CHECK: %[[PACK_A:.*]] = tensor.pack %[[A]] {{.*}} inner_dims_pos = [1] inner_tiles = [%[[TS]]]
+  // CHECK-SAME: : tensor -> tensor
+  // CHECK: %[[PACK_B:.*]] = tensor.pack %[[B]] {{.*}} inner_dims_pos = [1, 0] inner_tiles = [%[[TS]], %[[TS]]]
+  // CHECK-SAME: : tensor -> tensor
+  // CHECK: %[[PACK_C:.*]] = tensor.pack %[[C]] {{.*}} inner_dims_pos = [1] inner_tiles = [%[[TS]]]
+  // CHECK-SAME: : tensor -> tensor
+  // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]}
+  // CHECK-SAME: ins(%{{.*}} : tensor, tensor)
+  // CHECK-SAME: outs(%{{.*}} : tensor)
+  %0 = linalg.matmul ins(%A, %B: tensor, tensor)
+                     outs(%C: tensor)
+    -> tensor
+
+  // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [1] inner_tiles = [%[[TS]]] into %[[C]]
+  // CHECK-SAME: : tensor -> tensor
+  return %0 : tensor
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %sz = transform.structured.match ops{["some_tile_size"]} in %arg1
+    %1 = transform.structured.pack %0 packed_sizes = [0, %sz, %sz]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// Negative test: requesting a packed size for the H iterator alone is expected
+// to fail with the diagnostic in the transform sequence below.
+func.func @conv_cant_pack(%i: tensor<14x512x28x28xf32>, %f: tensor<1024x512x1x1xf32>,
+                          %o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32> {
+  %0 = linalg.conv_2d_nchw_fchw ins(%i, %f: tensor<14x512x28x28xf32>, tensor<1024x512x1x1xf32>)
+                                outs(%o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32>
+  return %0: tensor<14x1024x28x28xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match interface{LinalgOp} in %arg1
+  // Iterators: N F H W C KH KW
+  // expected-error @below {{data tiling failed}}
+  %1 = transform.structured.pack %0 packed_sizes = [0, 0, 4, 0, 0, 0, 0]
+    : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// Negative test: the matched handle maps to two linalg.matmul ops, but the pack
+// transform accepts exactly one target.
+func.func @matmul(%A: tensor, %B: tensor, %C: tensor)
+    -> (tensor, tensor) {
+  %0 = linalg.matmul ins(%A, %B: tensor, tensor)
+                     outs(%C: tensor)
+    -> tensor
+  %1 = linalg.matmul ins(%A, %B: tensor, tensor)
+                     outs(%C: tensor)
+    -> tensor
+  return %0, %1 : tensor, tensor
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    // expected-error @below {{requires target to map to exactly 1 LinalgOp (got 2)}}
+    %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// Negative test: only two packed sizes are given for a matmul, which has three
+// loops.
+func.func @matmul(%A: tensor, %B: tensor, %C: tensor)
+    -> tensor {
+  %0 = linalg.matmul ins(%A, %B: tensor, tensor)
+                     outs(%C: tensor)
+    -> tensor
+  return %0 : tensor
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    // expected-error @below {{requires number of packed sizes match the number of loops (2 vs 3)}}
+    %1 = transform.structured.pack %0 packed_sizes = [2, 3]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}