diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -370,7 +370,7 @@ def PackOp : Op, - DeclareOpInterfaceMethods,]> { + DeclareOpInterfaceMethods]> { let description = [{ Pack a LinalgOp by applying a data tiling transformation on the op and packing the operands according to the `packed_sizes` specification. @@ -453,6 +453,84 @@ }]; } +//===----------------------------------------------------------------------===// +// PackGreedilyOp +//===----------------------------------------------------------------------===// +def PackGreedilyOp : Op, + DeclareOpInterfaceMethods]> { + let description = [{ + Target a Linalg op and rewrite it into packed LinalgOp form by trying to + infer whether a known suboperation is embedded. + + Different packing strategies are applied in order; as soon as one applies + successfully, the transform returns. The available strategies are: + 1. Gemm packing: Try to infer a gemm operation embedded in the target op. + Specifically, this looks for 2 parallel dimensions that participate in + an outer-product and 1 reduction dimension. + These dimensions are referred to as (m, n, k) to match canonical gemm + terminology. + The packed sizes for (m, n, k) are specified by `gemm_packed_sizes`. + The ordering of the packed dimensions (mm, nn, kk) is specified by the + `gemm_inner_dims_order` attribute. + + Packing occurs as follows: + 1. Find the dimensions to pack according to the strategy. + 2. The target is converted to linalg.generic form. + 3. An interchange transform is applied to isolate the dimensions to pack as + the most minor indexing dimensions of the linalg.generic. The most minor + dimensions are themselves ordered according to `gemm_inner_dims_order`. + 4. Packing is performed by `gemm_packed_sizes`, following `gemm_inner_dims_order`.
+ + By normalizing the most minor dimensions to `gemm_inner_dims_order`, the + transform guarantees that packing immediately generates inner dimensions in a + desirable layout. + + Outer dimension layout permutations are not controlled by this transform op + at the moment and can be obtained by composing with the pack_transpose + transformation. + + #### Return modes + + This operation ignores non-Linalg ops and drops them in the return. + It returns the list of packed Linalg ops or the original op when all available + packing strategies failed to apply. + }]; + + // TODO: Transform_ConcreteOpType needs interface. + let arguments = (ins TransformHandleTypeInterface:$target, + Variadic:$gemm_packed_sizes, + DefaultValuedAttr + :$static_gemm_packed_sizes, + DefaultValuedAttr + :$gemm_inner_dims_order); + let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op); + + let builders = [ + OpBuilder<(ins "Value":$target, + "ArrayRef":$mixedGemmPackedSizes, + CArg<"ArrayRef", "{}">:$gemmDimsInnerDimsOrder)> + ]; + + let assemblyFormat = [{ + $target + oilist( + `gemm_packed_sizes` `=` custom($gemm_packed_sizes, + $static_gemm_packed_sizes) + `gemm_inner_dims_order` `=` $gemm_inner_dims_order + ) + attr-dict + `:` functional-type($target, results) + }]; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Returns the list of gemm packed sizes, which may be static (Attribute) + /// or dynamic (Value).
+ SmallVector getMixedGemmPackedSizes(); + }]; +} + //===----------------------------------------------------------------------===// // PackTransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -30,9 +30,11 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" @@ -134,7 +136,7 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::DecomposeOp::applyToOne(linalg::LinalgOp target, +transform::DecomposeOp::applyToOne(LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { #define DOWNSCALE(trans) \ @@ -642,7 +644,7 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, +transform::GeneralizeOp::applyToOne(LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Exit early if no transformation is needed.
@@ -663,7 +665,7 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::InterchangeOp::applyToOne(linalg::GenericOp target, +transform::InterchangeOp::applyToOne(GenericOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ArrayRef interchangeVector = getIteratorInterchange(); @@ -730,7 +732,7 @@ if (getInterface().has_value()) { auto iface = getInterface().value(); if (iface == transform::MatchInterfaceEnum::LinalgOp && - !isa(op)) + !isa(op)) return; if (iface == transform::MatchInterfaceEnum::TilingInterface && isa(op)) @@ -885,7 +887,7 @@ // attributes for multiple variadic operands. In the absence of this, horrible // bugs ensue. Type linalgOpHType = transform::OperationType::get( - builder.getContext(), linalg::GenericOp::getOperationName()); + builder.getContext(), GenericOp::getOperationName()); build(builder, result, /*resultType=*/linalgOpHType, /*target=*/target, @@ -908,7 +910,7 @@ return DiagnosedSilenceableFailure::success(); } // Fail on multi-op handles. - auto linalgOp = dyn_cast(targetOps.front()); + auto linalgOp = dyn_cast(targetOps.front()); if (targetOps.size() != 1 || !linalgOp) { return emitSilenceableError() << "requires target to map to exactly 1 LinalgOp (got " @@ -946,6 +948,268 @@ transform::modifiesPayload(effects); } +//===---------------------------------------------------------------------===// +// PackGreedilyOp. +//===---------------------------------------------------------------------===// + +LogicalResult transform::PackGreedilyOp::verify() { + if (!isPermutationVector(getGemmInnerDimsOrder())) { + return emitOpError() << getGemmInnerDimsOrderAttrName() + << " is not a valid permutation"; + } + // TODO: relax to allow empty once we have another strategy than just gemm.
+ if (getGemmInnerDimsOrder().size() != 3 || + getMixedGemmPackedSizes().size() != 3) { + return emitOpError() << "needs 3 entries for gemm_packed_sizes and " + << getGemmInnerDimsOrderAttrName() + << " for the gemm strategy"; + } + return success(); +} + +namespace { +constexpr auto par = utils::IteratorType::parallel; +constexpr auto red = utils::IteratorType::reduction; +} // namespace + +/// Return the positions of the loop dimensions of `linalgOp` whose iterator +/// type is `iter` and that index `opOperand` as a permutation: each such +/// dimension appears as a bare AffineDimExpr in the operand's indexing map and +/// no other result of that map depends on it. +static DenseSet +findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand, + utils::IteratorType iter) { + DenseSet res; + assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + // Hoisted out of the loop: the iterator types do not depend on `e`. + auto iteratorTypes = linalgOp.getIteratorTypesArray(); + for (AffineExpr e : indexingMap.getResults()) { + if (auto d = e.dyn_cast()) { + if (iteratorTypes[d.getPosition()] == iter && + llvm::count_if(indexingMap.getResults(), [d](AffineExpr other) { + return other.isFunctionOfDim(d.getPosition()); + }) == 1) + res.insert(d.getPosition()); + } + } + return res; +} + +struct GemmDimsForPacking { + int64_t mPos, nPos, kPos; +}; +/// Greedily look for 2 parallel (m and n) and 1 reduction (k) dimension that +/// form a gemm. Such dimensions are such that: +/// 1. The m dimension is involved in an outer-product along LHS +/// (i.e. it is a permutation on RES and LHS and does not appear in RHS). +/// 2. The n dimension is involved in an outer-product along RHS +/// (i.e. it is a permutation on RES and RHS and does not appear in LHS). +/// 3. The k dimension appears as a permutation on LHS and RHS. +/// 4. m, n and k appear only once in any given indexing. +/// +/// This allows detecting that some gemm is embedded within `linalgOp`. +/// +/// When multiple possibilities for selecting m, n and k appear, we just pick +/// an arbitrary one (i.e. the first in a DenseSet). +/// +/// Returns failure if `linalgOp` does not have exactly 2 DPS inputs (LHS, RHS) +/// and 1 DPS init (RES), or if no such (m, n, k) triple is found. +// TODO: Better heuristic (e.g pick dims based on packing-based metric).
+static FailureOr getGemmDims(LinalgOp linalgOp) { + // Only ops with exactly 2 inputs (LHS, RHS) and 1 init (RES) can embed a + // gemm. Bail out instead of asserting: this is reached for arbitrary + // LinalgOps, and getDpsInputOperand(1) below would be out of bounds. + if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) + return failure(); + + DenseSet a = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), par); + DenseSet b = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), par); + DenseSet c = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInitOperand(0), par); + + // A & C - B are the iterators involved in an outer-product along A (the LHS). + DenseSet ac = a; + llvm::set_intersect(ac, c); + llvm::set_subtract(ac, b); + // B & C - A are the iterators involved in an outer-product along B (the RHS). + DenseSet bc = b; + llvm::set_intersect(bc, c); + llvm::set_subtract(bc, a); + + // Note: if we ever need them, A & B & C would be "batch" dimensions. + + // A & B red are the reduction dimensions. + DenseSet ra = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), red); + DenseSet rb = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), red); + llvm::set_intersect(ra, rb); + + if (ac.empty() || bc.empty() || ra.empty()) + return failure(); + + // Pick the first one in each set. + // TODO: Better heuristic (e.g pick dims based on packing-based metric). + return GemmDimsForPacking{*ac.begin(), *bc.begin(), *ra.begin()}; +} + +/// Return a permutation vector of size permSize that would result in moving +/// positions into desiredPositions. The remaining entries are filled, in +/// increasing order, with the values not listed in `positions`. +/// +/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0} +/// would result in a {4, 2, 0, 1, 3} permutation vector.
+static SmallVector +computePermutationVector(int64_t permSize, ArrayRef positions, + ArrayRef desiredPositions) { + SmallVector res(permSize, -1); + DenseSet seen; + for (auto [pos, desiredPos] : llvm::zip(positions, desiredPositions)) { + res[desiredPos] = pos; + seen.insert(pos); + } + int64_t nextPos = 0; + for (int64_t &entry : res) { + if (entry != -1) + continue; + while (seen.contains(nextPos)) + ++nextPos; + entry = nextPos; + ++nextPos; + } + return res; +} + +/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) +/// where m and n are proper parallel dimensions and k is a proper reduction +/// dimension. +/// Packing occurs by rewriting the op as a linalg.generic and calling +/// linalg::pack by `mnkPackedSizes`. +/// The order of the packed dimensions is customizable: the `mnkOrder` is a +/// permutation of {0, 1, 2} to reorder {m, n, k} into one of the 6 possible +/// forms. +/// The outer dimensions of the operands are not permuted at this time, this is +/// left for future work. +static FailureOr +packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp, + ArrayRef mnkPackedSizes, + ArrayRef mnkOrder) { + assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes"); + assert(mnkOrder.size() == 3 && "unexpected mnkOrder size"); + assert(isPermutationVector(mnkOrder) && "expected a permutation"); + + int64_t numLoops = linalgOp.getNumLoops(); + if (numLoops <= 2) { + return rewriter.notifyMatchFailure(linalgOp, + "need 3+ loops to find a gemm to pack"); + } + + // Locally adjust the desired iterator position of mnk and packing sizes. + int64_t numPackedDims = mnkPackedSizes.size(); + SmallVector mmnnkkPos(numPackedDims); + for (int64_t i = 0, e = numPackedDims; i < e; ++i) + mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i]; + SmallVector packedSizes(mnkPackedSizes.size()); + for (int64_t i = 0, e = numPackedDims; i < e; ++i) + packedSizes[mnkOrder[i]] = mnkPackedSizes[i]; + + // 1. Infer dims that are important for gemm.
+ FailureOr res = getGemmDims(linalgOp); + if (failed(res)) { + return rewriter.notifyMatchFailure(linalgOp, + "couldn't infer gemm iterators"); + } + + // 2. Normalize linalgOp so that (m, n, k) become its most minor iterators, + // ordered according to `mnkOrder` (e.g. a kmn-matmul-like op with + // [red, par, par] most minor iterators when mnkOrder == {1, 2, 0}). + int64_t mPos = res->mPos, nPos = res->nPos, kPos = res->kPos; + LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); + DBGS() << "Start packing generic op greedily with (m@" << mPos + << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp + << "\n";); + + // 2.a. Rewrite as a generic. + auto genericOp = dyn_cast(linalgOp.getOperation()); + if (!genericOp) { + FailureOr generalizeResult = + generalizeNamedOp(rewriter, linalgOp); + assert(succeeded(generalizeResult) && "unexpected failure generalizing op"); + genericOp = *generalizeResult; + } + + // 2.b. Interchange to move the dimensions (k, m, n) as most-minor iterators. + // Note that this only normalizes the iteration order and does not change the + // indexings of any operand. + SmallVector permutation = + computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos); + LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL();); + // Sign .. unsigned pollution. + SmallVector unsignedPerm(permutation.begin(), permutation.end()); + FailureOr interchangeResult = + interchangeGenericOp(rewriter, genericOp, unsignedPerm); + assert(succeeded(interchangeResult) && "unexpected failure interchanging op"); + genericOp = *interchangeResult; + LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";); + + // At this point, the op iterators are normalized to {leading, k, m, n} + // (shown for mnkOrder == {1, 2, 0}; in general m, n and k land at positions + // mnkOrder[0], mnkOrder[1] and mnkOrder[2] among the 3 most-minor iterators). + // The layouts induced by packing will then be: + // - LHS{leading_lhs, kk, mm} + // - RHS{leading_rhs, kk, nn} + // - RES{leading_res, mm, nn} + // Passing a different `mnkOrder` reorders (k, m, n) above and thus changes + // the packed layouts.
+ // + // Additional permutations of the outer dims of the operands (i.e. + // leading_lhs, leading_rhs and leading_res) could follow by computing the + // desired outerPerm for each operand. + // This is left for future work. + + // Add leading zeros to match numLoops. + SmallVector adjustedPackedSizes(numLoops - packedSizes.size(), + rewriter.getIndexAttr(0)); + llvm::append_range(adjustedPackedSizes, packedSizes); + + // TODO: If we wanted to give the genericOp a name after packing, after + // calling `pack` would be a good time. + return linalg::pack(rewriter, genericOp, adjustedPackedSizes); +} + +DiagnosedSilenceableFailure transform::PackGreedilyOp::apply( + transform::TransformResults &transformResults, + transform::TransformState &state) { + ArrayRef targetOps = state.getPayloadOps(getTarget()); + + // The dynamic entries of `gemm_packed_sizes` are transform handles, not + // payload values: resolve each one to the single index-typed result of the + // single payload op it maps to. This is done once, before any payload op is + // rewritten, so that a malformed size handle leaves the payload untouched. + // NOTE(review): assumes the dynamic sizes are operation handles; TODO confirm + // against the .td operand type. Dominance of the resolved values over the + // targets is not checked here either. + SmallVector<OpFoldResult> gemmPackedSizes; + for (OpFoldResult ofr : getMixedGemmPackedSizes()) { + if (auto attr = ofr.dyn_cast<Attribute>()) { + gemmPackedSizes.push_back(attr); + continue; + } + ArrayRef<Operation *> sizeOps = state.getPayloadOps(ofr.get<Value>()); + if (sizeOps.size() != 1 || sizeOps.front()->getNumResults() != 1 || + !sizeOps.front()->getResult(0).getType().isIndex()) { + return emitSilenceableError() + << "expected each dynamic gemm_packed_sizes handle to map to " + "exactly 1 payload op with a single index result"; + } + gemmPackedSizes.push_back(sizeOps.front()->getResult(0)); + } + + SmallVector results; + IRRewriter rewriter(getContext()); + for (Operation *op : targetOps) { + auto linalgOp = dyn_cast(op); + if (!linalgOp) + continue; + // linalgOp will be replaced and the insertion point may be invalidated if + // we set it before -> set it after. + rewriter.setInsertionPointAfter(linalgOp); + // Failing to pack greedily is perfectly fine. + // In the future we will want to order packings according to some metric.
+ FailureOr gemm = packGemmGreedily( + /*rewriter=*/rewriter, + /*linalgOp=*/linalgOp, + /*mnkPackedSizes=*/gemmPackedSizes, + /*mnkOrder=*/getGemmInnerDimsOrder()); + if (succeeded(gemm)) { + results.push_back(*gemm); + continue; + } + // NOTE(review): `packed_op` is declared as !transform.op<"linalg.generic">, + // yet a named op (e.g. linalg.matmul) that fails to pack is returned + // unchanged here and thus does not match that result type. + results.push_back(linalgOp); + } + transformResults.set(getPackedOp().cast(), results); + return DiagnosedSilenceableFailure::success(); +} + +SmallVector transform::PackGreedilyOp::getMixedGemmPackedSizes() { + Builder b(getContext()); + return getMixedValues(getStaticGemmPackedSizes(), getGemmPackedSizes(), b); +} + +void transform::PackGreedilyOp::build( + OpBuilder &builder, OperationState &result, Value target, + ArrayRef<OpFoldResult> mixedGemmPackedSizes, + ArrayRef<int64_t> gemmDimsInnerDimsOrder) { + // Split the mixed sizes into the dynamic operands and the static attribute. + SmallVector<int64_t> staticGemmPackedSizes; + SmallVector<Value> dynamicGemmPackedSizes; + dispatchIndexOpFoldResults(mixedGemmPackedSizes, dynamicGemmPackedSizes, + staticGemmPackedSizes); + // `packed_op` is always a !transform.op<"linalg.generic"> handle. + Type linalgOpHType = transform::OperationType::get( + builder.getContext(), GenericOp::getOperationName()); + build(builder, result, + /*packedOp=*/linalgOpHType, + /*target=*/target, + /*gemmPackedSizes=*/dynamicGemmPackedSizes, + /*staticGemmPackedSizes=*/staticGemmPackedSizes, + /*gemmInnerDimsOrder=*/gemmDimsInnerDimsOrder); +} + +void transform::PackGreedilyOp::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getTarget(), effects); + transform::onlyReadsHandle(getGemmPackedSizes(), effects); + transform::producesHandle(getPackedOp(), effects); + transform::modifiesPayload(effects); +} + //===---------------------------------------------------------------------===// // PackTransposeOp //===---------------------------------------------------------------------===// @@ -1030,7 +1294,7 @@ return emitSilenceableError() << "requires target to map to a " "tensor.pack or tensor.unpack"; } - LinalgOp linalgOpTarget = dyn_cast(linalgOps.front()); + LinalgOp linalgOpTarget = dyn_cast(linalgOps.front()); if (!linalgOpTarget) return emitSilenceableError() << "requires a LinalgOp target"; @@ -1102,7 +1366,7 @@ //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PadOp::applyToOne(linalg::LinalgOp target, +transform::PadOp::applyToOne(LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Convert the integer packing flags to booleans.
@@ -1214,7 +1478,7 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PromoteOp::applyToOne(linalg::LinalgOp target, +transform::PromoteOp::applyToOne(LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; @@ -1308,7 +1572,7 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, +transform::ScalarizeOp::applyToOne(LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; @@ -1560,7 +1824,7 @@ } DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return linalg::SplitReductionOptions{int64_t(getSplitFactor()), @@ -1605,7 +1869,7 @@ } DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); @@ -1649,7 +1913,7 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForeachThreadOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir new file mode 100644 --- /dev/null +++
b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir @@ -0,0 +1,229 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s + +!A_mk = tensor<1023x255xf32> +!B_kn = tensor<255x127xf32> +!C_mn = tensor<1023x127xf32> + +// Normalized dims are: ( k, m, n)(kk, mm, nn) +// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> +// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)> + +// CHECK-LABEL: @matmul_mk_kn_mn( +func.func @matmul_mk_kn_mn(%A : !A_mk, %B : !B_kn, %C : !C_mn) -> !C_mn { + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] + // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} + // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<128x8x8x16xf32>) + %0 = linalg.matmul ins(%A, %B : !A_mk, !B_kn) outs(%C : !C_mn) -> !C_mn + return %0 : !C_mn +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op + : (!pdl.operation) -> !transform.op<"linalg.matmul"> + transform.structured.pack_greedily %matmul + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic"> +} + +// ----- + +!A_mk = tensor<1023x255xf32> +!B_nk = tensor<127x255xf32> +!C_nm = tensor<127x1023xf32> + +#mkn_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (n, m)> +] +#mkn_trait = { + indexing_maps = #mkn_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// Normalized dims are: ( k, m, n)(kk, mm, nn) +// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)> +// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> + +// CHECK-LABEL: @matmul_mk_nk_nm( +func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm { + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] + // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} + // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<8x128x8x16xf32>) + %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) { + ^bb0(%a: f32, %b: f32, %c: f32): + %d = arith.mulf %a, %b : f32 + %e = arith.addf %c, %d : f32 + linalg.yield %e : f32 + } -> !C_nm + return %0 : !C_nm +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> + transform.structured.pack_greedily %generic + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> +} + +// ----- + +!A_mk = tensor<1023x255xf32> +!B_nk = tensor<127x255xf32> +!C_nm = tensor<127x1023xf32> + +#mkn_accesses = [ + affine_map<(k, m, n) -> (m, k)>, + affine_map<(k, m, n) -> (n, k)>, + affine_map<(k, m, n) -> (n, m)> +] +#mkn_trait = { + indexing_maps = #mkn_accesses, + iterator_types = ["reduction", "parallel", "parallel"] +} + +// Normalized dims are: ( k, m, n)(kk, mm, nn) +// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> +// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)> +// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> + +// CHECK-LABEL: @matmul_mk_nk_nm_transposed( +func.func
@matmul_mk_nk_nm_transposed(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm { + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] + // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} + // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<8x128x8x16xf32>) + %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) { + ^bb0(%a: f32, %b: f32, %c: f32): + %d = arith.mulf %a, %b : f32 + %e = arith.addf %c, %d : f32 + linalg.yield %e : f32 + } -> !C_nm + return %0 : !C_nm +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> + transform.structured.pack_greedily %generic + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> +} + +// ----- + +!A_bmkm2 = tensor<42x1023x255x33xf32> +!B_nkb = tensor<127x255x42xf32> +!C_nbm = tensor<127x42x1023xf32> + +#mkn_accesses = [ + affine_map<(k, m, n, b, m2) -> (b, m, k, m2)>, + affine_map<(k, m, n, b, m2) -> (n, k, b)>, + affine_map<(k, m, n, b, m2) -> (n, b, m)> +] +#mkn_trait = { + indexing_maps = #mkn_accesses, + iterator_types = ["reduction", "parallel", "parallel", "parallel", "parallel"] +} + +// Normalized dims are: ( ?, ?, k, m, n)(kk, mm, nn) +// CHECK-DAG: #[[$bmkm2_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d3, d2, d1, d5, d6)> +// CHECK-DAG: #[[$nkb_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d2, d0, d5, d7)> +// CHECK-DAG: #[[$nbm_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d0, d3, d6, d7)> + +// CHECK-LABEL: @contraction_bmkm2_nkb_nbm( +func.func @contraction_bmkm2_nkb_nbm(%A : !A_bmkm2, %B : !B_nkb, %C : !C_nbm) -> !C_nbm { + // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$bmkm2_kkmm]], #[[$nkb_kknn]], #[[$nbm_mmnn]]] + // CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} + // CHECK-SAME: ins(%{{.*}} : tensor<42x128x8x33x32x8xf32>, tensor<8x8x42x32x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<8x42x128x8x16xf32>) + %0 = linalg.generic #mkn_trait ins(%A, %B : !A_bmkm2, !B_nkb) outs(%C : !C_nbm) { + ^bb0(%a: f32, %b: f32, %c: f32): + %d = arith.mulf %a, %b : f32 + %e = arith.addf %c, %d : f32 + linalg.yield %e : f32 + } -> !C_nbm + return %0 : !C_nbm +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> + transform.structured.pack_greedily %generic + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> +} + +// ----- + +// Conv lingo: h w kh kw c n f cc nn ff +// Normalized dims are: ( ?, ?, ?, ?, k, m, n)(kk, mm, nn) +// n c h + kh w + kw cc nn +// CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d5, d4, d0 + d2, d1 + d3, d7, d8)> +// f c kh kw cc ff +// CHECK-DAG: #[[$M2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d6, d4, d2, d3, d7, d9)> +// n f h w nn ff +// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d5, d6, d0, d1, d8, d9)> + +// CHECK-LABEL: @conv_2d_nchw_fchw +func.func @conv_2d_nchw_fchw(%arg0: tensor, %arg2: tensor) -> tensor { + %c0 = arith.constant dense<0.1> : tensor<16x47x3x3xf32> + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$M1]], #[[$M2]], #[[$M3]]] + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel"] + // CHECK-SAME: ins(%{{.*}} : tensor, tensor<1x2x3x3x32x16xf32>) +
    // CHECK-SAME: outs(%{{.*}} : tensor) + %0 = linalg.conv_2d_nchw_fchw + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %c0: tensor, tensor<16x47x3x3xf32>) + outs(%arg2: tensor) -> tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %conv = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %module_op + : (!pdl.operation) -> !transform.op<"linalg.conv_2d_nchw_fchw"> + transform.structured.pack_greedily %conv + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.conv_2d_nchw_fchw">) -> !transform.op<"linalg.generic"> +} + + +// ----- + +// These should fail to pack for now as they don't contain a contraction. +// NB: the match below finds no linalg.generic here, so the handle is empty. +// CHECK-LABEL: @reduce_and_map +func.func @reduce_and_map(%arg0: tensor<10x100xf32>, + %arg1: tensor<10x100xf32>, %output: tensor<10xf32>) -> tensor<10xf32> { + %map_init = tensor.empty() : tensor<10x100xf32> + // CHECK: linalg.map + %mapped = linalg.map { arith.addf } + ins(%arg0, %arg1 : tensor<10x100xf32>, tensor<10x100xf32>) + outs(%map_init : tensor<10x100xf32>) + // CHECK: linalg.reduce + %res = linalg.reduce { arith.addf } + ins(%mapped: tensor<10x100xf32>) + outs(%output: tensor<10xf32>) + dimensions = [1] + return %res : tensor<10xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> + transform.structured.pack_greedily %generic + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> +}