diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -36,12 +36,46 @@ ArrayRef droppedOperands); } // namespace detail +/// Positions of a Linalg op loops that correspond to different kinds of a +/// contraction dimension. +struct ContractionDimensions { + SmallVector<unsigned, 2> batch; + SmallVector<unsigned, 2> m; + SmallVector<unsigned, 2> n; + SmallVector<unsigned, 2> k; +}; + /// Checks whether `linalgOp` conforms to ContractionOpInterface. // TODO: embed within `isa` if possible / natural. bool isaContractionOpInterface(LinalgOp linalgOp); +/// Checks whether `linalgOp` conforms to ConvolutionOpInterface. +// TODO: embed within `isa` if possible / natural. +bool isaConvolutionOpInterface(LinalgOp linalgOp); + namespace detail { +/// Result of matching a Linalg generic against the predicates of it being a +/// contraction. +enum class MatchContractionResult { + Success = 0, + NotLinalgOp, + WrongNumOperands, + NoReduction, + NotProjectedPermutations, + NotAddMul +}; + +/// Checks whether `op` conforms to ContractionOpInterface and populates +/// `dimensions` with indexes of the different kinds of dimensions when present. +MatchContractionResult +isContractionInterfaceImpl(Operation *op, + ContractionDimensions *dimensions = nullptr); + +/// Returns the error message corresponding to the contraction checking return +/// code. +StringRef getMatchContractionMessage(MatchContractionResult res); + /// Result of matching a Linalg generic against the predicates of it being a /// convolution. 
enum class MatchConvolutionResult; diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -720,7 +720,7 @@ :$matmul_padded_sizes_next_multiple_of, ConfinedAttr, [DenseArrayCount<3>]>:$matmul_inner_dims_order); - let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op); + let results = (outs TransformHandleTypeInterface:$packed_op); let builders = [ OpBuilder<(ins "Value":$target, diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -35,41 +35,6 @@ // Utilities for inferring various semantics properties of Linalg ops. //===----------------------------------------------------------------------===// -/// Possible dimension candidates that define a contraction embedded in the -/// indexing maps of a LinalgOp. -struct EmbeddedContractionDimsCandidates { - DenseSet batchPos, mPos, nPos, kPos; -}; - -/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the -/// iterators of type `iter` that index the `opOperand` as a permutation. -/// This is useful to infer various subcomputations on a given `linalgOp`. -/// This is performed by looking up each result in the matching indexing map and -/// determining whether: -/// - It is a single AffineDimExpr. -/// - It is the only result involving this AffineDimExpr. -DenseSet findPermutationsIndexingOperand(LinalgOp linalgOp, - OpOperand *opOperand, - utils::IteratorType iter); - -/// Return true if `linalgOp` contains an embedded matmul subcomputation in its -/// most minor dimensions. 
-bool containsMostMinorMatmul(linalg::LinalgOp linalgOp); - -/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form -/// a matmul subcomputation within `linalgOp`. These dimensions are such that: -/// 1. The m dimension is involved in an outer-product along LHS -/// (i.e. it is a permutation on RES and LHS and does not appear in RHS). -/// 2. The n dimension is involved in an outer-product along RHS -/// (i.e. it is a permutation on RES and RHS and does not appear in LHS). -/// 3. The k dimension appears as a permutation on LHS and RHS. -/// 4. m, n and k appear only once in any given indexing. -/// 5. Optional batch dimensions that appear in all operands are captured. -/// This allows e.g. detecting that some contraction is embedded within -/// `linalgOp` with some orthogonal heuristic. -FailureOr<EmbeddedContractionDimsCandidates> -inferContractionDims(linalg::LinalgOp linalgOp); - //===----------------------------------------------------------------------===// // General utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -17,7 +17,10 @@ #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" +#include <algorithm> using namespace mlir; using namespace mlir::linalg; @@ -112,15 +115,96 @@ return success; } -enum class MatchContractionResult { - Success = 0, - NotLinalgOp, - WrongNumOperands, - NoReduction, - NotProjectedPermutations, - NotAddMul +/// Possible dimension candidates that define a contraction embedded in the +/// indexing maps of a LinalgOp. 
+struct EmbeddedContractionDimsCandidates { + DenseSet batchPos, mPos, nPos, kPos; }; -static MatchContractionResult isContractionInterfaceImpl(Operation *op) { + +/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the +/// iterators of type `iter` that index the `opOperand` as a permutation. +/// This is useful to infer various subcomputations on a given `linalgOp`. +/// This is performed by looking up each result in the matching indexing map and +/// determining whether: +/// - It is a single AffineDimExpr. +/// - It is the only result involving this AffineDimExpr. +static DenseSet +findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand, + utils::IteratorType iter) { + DenseSet res; + assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + for (AffineExpr e : indexingMap.getResults()) { + if (auto d = e.dyn_cast()) { + if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter && + llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { + return e.isFunctionOfDim(d.getPosition()); + }) == 1) + res.insert(d.getPosition()); + } + } + return res; +} + +namespace { +auto par = utils::IteratorType::parallel; +auto red = utils::IteratorType::reduction; +} // namespace + +/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form +/// a matmul subcomputation within `linalgOp`. These dimensions are such that: +/// 1. The m dimension is involved in an outer-product along LHS +/// (i.e. it is a permutation on RES and LHS and does not appear in RHS). +/// 2. The n dimension is involved in an outer-product along RHS +/// (i.e. it is a permutation on RES and RHS and does not appear in LHS). +/// 3. The k dimension appears as a permutation on LHS and RHS. +/// 4. m, n and k appear only once in any given indexing. +/// 5. Optional batch dimensions that appear in all operands are captured. +/// This allows e.g. 
detecting that some contraction is embedded within +/// `linalgOp` with some orthogonal heuristic. +static FailureOr<EmbeddedContractionDimsCandidates> +inferContractionDims(LinalgOp linalgOp) { + if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) + return failure(); + + DenseSet<int64_t> a = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), par); + DenseSet<int64_t> b = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), par); + DenseSet<int64_t> c = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInitOperand(0), par); + + // A & C - B are the iterators involved in an outer-product along A (the LHS). + DenseSet<int64_t> ac = a; + llvm::set_intersect(ac, c); + llvm::set_subtract(ac, b); + // B & C - A are the iterators involved in an outer-product along B (the RHS). + DenseSet<int64_t> bc = b; + llvm::set_intersect(bc, c); + llvm::set_subtract(bc, a); + // A & B & C are the "batch" dimensions. + DenseSet<int64_t> batches = a; + llvm::set_intersect(batches, b); + llvm::set_intersect(batches, c); + + // A & B red are the reduction dimensions. + DenseSet<int64_t> ra = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), red); + DenseSet<int64_t> rb = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), red); + llvm::set_intersect(ra, rb); + + if (ac.empty() || bc.empty() || ra.empty()) + return failure(); + + // Pick the first one in each set. + // TODO: Better heuristic (e.g. pick dims based on packing-based metric). 
+ return EmbeddedContractionDimsCandidates{batches, ac, bc, ra}; +} + +mlir::linalg::detail::MatchContractionResult +linalg::detail::isContractionInterfaceImpl( + Operation *op, mlir::linalg::ContractionDimensions *dimensions) { auto linalgOp = dyn_cast(op); if (!linalgOp) return MatchContractionResult::NotLinalgOp; @@ -139,15 +223,31 @@ linalgOp->getRegion(0).front()) && !isAddMul(linalgOp->getRegion(0).front())) return MatchContractionResult::NotAddMul; + + if (dimensions) { + FailureOr<EmbeddedContractionDimsCandidates> res = + inferContractionDims(linalgOp); + assert(succeeded(res) && "unexpected failure to infer contraction dims"); + *dimensions = ContractionDimensions{ + SmallVector<unsigned, 2>(res->batchPos.begin(), res->batchPos.end()), + SmallVector<unsigned, 2>(res->mPos.begin(), res->mPos.end()), + SmallVector<unsigned, 2>(res->nPos.begin(), res->nPos.end()), + SmallVector<unsigned, 2>(res->kPos.begin(), res->kPos.end())}; + std::sort(dimensions->batch.begin(), dimensions->batch.end()); + std::sort(dimensions->m.begin(), dimensions->m.end()); + std::sort(dimensions->n.begin(), dimensions->n.end()); + std::sort(dimensions->k.begin(), dimensions->k.end()); + } return MatchContractionResult::Success; -} +} bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) { if (!linalgOp) return false; Operation *op = linalgOp.getOperation(); return isa(op) || - (mlir::linalg::detail::isContractionInterfaceImpl(op) == + mlir::linalg::detail::MatchContractionResult::Success); } /// Verify that a LinalgOp `op` is a contraction. 
@@ -454,6 +554,11 @@ llvm_unreachable("unhandled MatchConvolutionResult case"); } +bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) { + return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) == + linalg::detail::MatchConvolutionResult::Success; +} + LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { MatchConvolutionResult res = isConvolutionInterfaceImpl(op); if (res != MatchConvolutionResult::Success) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -49,6 +49,7 @@ #define DEBUG_TYPE "linalg-transforms" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define DBGSNL() (llvm::dbgs() << "\n") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` @@ -1220,6 +1221,8 @@ int64_t numLoops = linalgOp.getNumLoops(); if (numLoops <= 2) { + LDBG("need 3+ loops to find a matmul to pack, got " + << numLoops << "\nin: " << linalgOp << "\n"); return rewriter.notifyMatchFailure( linalgOp, "need 3+ loops to find a matmul to pack"); } @@ -1240,8 +1243,12 @@ } // 1. Infer dims that are important for matmul. - FailureOr res = inferContractionDims(linalgOp); - if (failed(res)) { + ContractionDimensions dimensions; + linalg::detail::MatchContractionResult res = + linalg::detail::isContractionInterfaceImpl(linalgOp.getOperation(), + &dimensions); + if (res != linalg::detail::MatchContractionResult::Success) { + LDBG("couldn't infer matmul iterators in: " << linalgOp << "\n"); return rewriter.notifyMatchFailure(linalgOp, "couldn't infer matmul iterators"); } @@ -1249,8 +1256,8 @@ // 2. 
Normalize linalgOp to an kmn-matmul-like with [red, par, par] most // minor iterators. If we wanted a different normalization order, this is // where it would have to plug a heuristic. - int64_t mPos = *(res->mPos.begin()), nPos = *(res->nPos.begin()), - kPos = *(res->kPos.begin()); + int64_t mPos = dimensions.m.front(), nPos = dimensions.n.front(), + kPos = dimensions.k.front(); LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); DBGS() << "Start packing generic op greedily with (m@" << mPos << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp @@ -2645,71 +2652,71 @@ ArrayRef mixedTileSizes, std::optional mapping, linalg::ForallTilingResult &tilingResult) { // Transform all targets one by one. - auto tileableOp = dyn_cast(target); - if (!tileableOp) { - DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() - << "only TilingInterface ops are supported"; - diag.attachNote(target->getLoc()) << "target op"; - return diag; - } - rewriter.setInsertionPoint(tileableOp); - FailureOr maybeTilingResult = failure(); - if (!mixedNumThreads.empty()) { - maybeTilingResult = linalg::tileToForallOp(rewriter, tileableOp, - mixedNumThreads, mapping); - } else { - maybeTilingResult = linalg::tileToForallOpUsingTileSizes( - rewriter, tileableOp, mixedTileSizes, mapping); - } + auto tileableOp = dyn_cast(target); + if (!tileableOp) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "only TilingInterface ops are supported"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + rewriter.setInsertionPoint(tileableOp); + FailureOr maybeTilingResult = failure(); + if (!mixedNumThreads.empty()) { + maybeTilingResult = + linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping); + } else { + maybeTilingResult = linalg::tileToForallOpUsingTileSizes( + rewriter, tileableOp, mixedTileSizes, mapping); + } - if (failed(maybeTilingResult)) - return transformOp.emitDefaultSilenceableFailure(tileableOp); - 
rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults()); + if (failed(maybeTilingResult)) + return transformOp.emitDefaultSilenceableFailure(tileableOp); + rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults()); - tilingResult = *maybeTilingResult; - return DiagnosedSilenceableFailure::success(); + tilingResult = *maybeTilingResult; + return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::TileToForallOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { - auto transformOp = cast(getOperation()); - - // Result payload ops. - SmallVector tileOps; - SmallVector tiledOps; - - // Unpack handles. - SmallVector mixedNumThreads; - DiagnosedSilenceableFailure status = - getPackedNumThreads() - ? unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getPackedNumThreads()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getMixedNumThreads()); - if (!status.succeeded()) - return status; - SmallVector mixedTileSizes; - status = getPackedTileSizes() - ? unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getPackedTileSizes()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getMixedTileSizes()); - if (!status.succeeded()) - return status; - - for (Operation *target : state.getPayloadOps(getTarget())) { - linalg::ForallTilingResult tilingResult; - DiagnosedSilenceableFailure diag = tileToForallOpImpl( - rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, - getMapping(), tilingResult); - if (!diag.succeeded()) + auto transformOp = cast(getOperation()); + + // Result payload ops. + SmallVector tileOps; + SmallVector tiledOps; + + // Unpack handles. + SmallVector mixedNumThreads; + DiagnosedSilenceableFailure status = + getPackedNumThreads() + ? 
unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getPackedNumThreads()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getMixedNumThreads()); + if (!status.succeeded()) + return status; + SmallVector mixedTileSizes; + status = getPackedTileSizes() + ? unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getPackedTileSizes()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getMixedTileSizes()); + if (!status.succeeded()) + return status; + + for (Operation *target : state.getPayloadOps(getTarget())) { + linalg::ForallTilingResult tilingResult; + DiagnosedSilenceableFailure diag = tileToForallOpImpl( + rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, + getMapping(), tilingResult); + if (!diag.succeeded()) return diag; tileOps.push_back(tilingResult.tileOp); tiledOps.push_back(tilingResult.tiledOp); - } + } transformResults.set(cast(getForallOp()), tileOps); transformResults.set(cast(getTiledOp()), tiledOps); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -33,7 +33,6 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Pass/Pass.h" -#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -140,86 +139,6 @@ } } -//===----------------------------------------------------------------------===// -// Utilities for inferring various semantics properties of Linalg ops. 
-//===----------------------------------------------------------------------===// - -DenseSet mlir::linalg::findPermutationsIndexingOperand( - LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) { - DenseSet res; - assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); - for (AffineExpr e : indexingMap.getResults()) { - if (auto d = e.dyn_cast()) { - if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter && - llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { - return e.isFunctionOfDim(d.getPosition()); - }) == 1) - res.insert(d.getPosition()); - } - } - return res; -} - -namespace { -auto par = utils::IteratorType::parallel; -auto red = utils::IteratorType::reduction; -} // namespace - -bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) { - FailureOr res = inferContractionDims(linalgOp); - if (failed(res)) - return false; - int64_t numLoops = linalgOp.getNumLoops(); - for (const DenseSet &s : {res->mPos, res->nPos, res->kPos}) { - if (s.contains(numLoops - 3) || s.contains(numLoops - 2) || - s.contains(numLoops - 1)) - continue; - return false; - } - return true; -} - -FailureOr -mlir::linalg::inferContractionDims(LinalgOp linalgOp) { - if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) - return failure(); - - DenseSet a = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(0), par); - DenseSet b = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(1), par); - DenseSet c = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInitOperand(0), par); - - // A & C - B are the iterators involved in an outer-product along A (the LHS). - DenseSet ac = a; - llvm::set_intersect(ac, c); - llvm::set_subtract(ac, b); - // B & C - A are the iterators involved in an outer-product along B (the RHS). 
- DenseSet bc = b; - llvm::set_intersect(bc, c); - llvm::set_subtract(bc, a); - // A & B & C are the "batch" dimensions. - DenseSet batches = a; - llvm::set_intersect(batches, b); - llvm::set_intersect(batches, c); - - // A & B red are the reduction dimensions. - DenseSet ra = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(0), red); - DenseSet rb = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(1), red); - llvm::set_intersect(ra, rb); - - if (ac.empty() || bc.empty() || ra.empty()) - return failure(); - - // Pick the first one in each set. - // TODO: Better heuristic (e.g pick dims based on packing-based metric). - return EmbeddedContractionDimsCandidates{batches, ac, bc, ra}; -} - //===----------------------------------------------------------------------===// // General utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -59,7 +59,9 @@ for (Operation *op : payload) { if (opName != op->getName()) { DiagnosedSilenceableFailure diag = - emitSilenceableError(loc) << "incompatible payload operation name"; + emitSilenceableError(loc) + << "incompatible payload operation name expected " << opName << " vs " + << op->getName() << " -> " << *op; diag.attachNote(op->getLoc()) << "payload operation"; return diag; } diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir --- a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir +++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir @@ -1,170 +1,5 @@ -// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s - -!A_mk = tensor<1023x255xf32> -!B_kn = tensor<255x127xf32> -!C_mn = tensor<1023x127xf32> - -// 
Normalized dims are: ( k, m, n)(kk, mm, nn) -// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> -// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> -// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)> - -// CHECK-LABEL: @matmul_mk_kn_mn( -func.func @matmul_mk_kn_mn(%A : !A_mk, %B : !B_kn, %C : !C_mn) -> !C_mn { - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] - // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} - // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>) - // CHECK-SAME: outs(%{{.*}} : tensor<128x8x8x16xf32>) - %0 = linalg.matmul ins(%A, %B : !A_mk, !B_kn) outs(%C : !C_mn) -> !C_mn - return %0 : !C_mn -} - -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op - : (!transform.any_op) -> !transform.op<"linalg.matmul"> - transform.structured.pack_greedily %matmul - matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] - : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic"> -} - -// ----- - -!A_mk = tensor<1023x255xf32> -!B_nk = tensor<127x255xf32> -!C_nm = tensor<127x1023xf32> - -#mkn_accesses = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (n, k)>, - affine_map<(m, n, k) -> (n, m)> -] -#mkn_trait = { - indexing_maps = #mkn_accesses, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// Normalized dims are: ( k, m, n)(kk, mm, nn) -// CHECK-DAG: #[[$km_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> -// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)> -// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> - -// CHECK-LABEL: @matmul_mk_nk_nm( -func.func @matmul_mk_nk_nm(%A : 
!A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm { - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] - // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} - // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>) - // CHECK-SAME: outs(%{{.*}} : tensor<8x128x8x16xf32>) - %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) { - ^bb0(%a: f32, %b: f32, %c: f32): - %d = arith.mulf %a, %b : f32 - %e = arith.addf %c, %d : f32 - linalg.yield %e : f32 - } -> !C_nm - return %0 : !C_nm -} - -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.op<"linalg.generic"> - transform.structured.pack_greedily %generic - matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] - : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> -} - -// ----- - -!A_mk = tensor<1023x255xf32> -!B_nk = tensor<127x255xf32> -!C_nm = tensor<127x1023xf32> - -#mkn_accesses = [ - affine_map<(k, m, n) -> (m, k)>, - affine_map<(k, m, n) -> (n, k)>, - affine_map<(k, m, n) -> (n, m)> -] -#mkn_trait = { - indexing_maps = #mkn_accesses, - iterator_types = ["reduction", "parallel", "parallel"] -} - -// Normalized dims are: ( k, m, n)(kk, mm, nn) -// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> -// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)> -// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> - -// CHECK-LABEL: @matmul_mk_nk_nm_transposed( -func.func @matmul_mk_nk_nm_transposed(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm { - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] - // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", 
"parallel", "parallel"]} - // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>) - // CHECK-SAME: outs(%{{.*}} : tensor<8x128x8x16xf32>) - %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) { - ^bb0(%a: f32, %b: f32, %c: f32): - %d = arith.mulf %a, %b : f32 - %e = arith.addf %c, %d : f32 - linalg.yield %e : f32 - } -> !C_nm - return %0 : !C_nm -} - -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.op<"linalg.generic"> - transform.structured.pack_greedily %generic - matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] - : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> -} - -// ----- - -!A_bmkm2 = tensor<42x1023x255x33xf32> -!B_nkb = tensor<127x255x42xf32> -!C_nbm = tensor<127x42x1023xf32> - -#mkn_accesses = [ - affine_map<(k, m, n, b, m2) -> (b, m, k, m2)>, - affine_map<(k, m, n, b, m2) -> (n, k, b)>, - affine_map<(k, m, n, b, m2) -> (n, b, m)> -] -#mkn_trait = { - indexing_maps = #mkn_accesses, - iterator_types = ["reduction", "parallel", "parallel", "parallel", "parallel"] -} - -// Normalized dims are: ( ?, ?, k, m, n)(kk, mm, nn) -// CHECK-DAG: #[[$bmkm2_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d3, d2, d1, d5, d6)> -// CHECK-DAG: #[[$nkb_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d2, d0, d5, d7)> -// CHECK-DAG: #[[$nbm_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d0, d3, d6, d7)> - -// CHECK-LABEL: @contraction_bmkm2_nkb_nbm( -func.func @contraction_bmkm2_nkb_nbm(%A : !A_bmkm2, %B : !B_nkb, %C : !C_nbm) -> !C_nbm { - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#[[$bmkm2_kkmm]], #[[$nkb_kknn]], #[[$nbm_mmnn]]] - // CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} - // CHECK-SAME: ins(%{{.*}} : 
tensor<42x128x8x33x32x8xf32>, tensor<8x8x42x32x16xf32>) - // CHECK-SAME: outs(%{{.*}} : tensor<8x42x128x8x16xf32>) - %0 = linalg.generic #mkn_trait ins(%A, %B : !A_bmkm2, !B_nkb) outs(%C : !C_nbm) { - ^bb0(%a: f32, %b: f32, %c: f32): - %d = arith.mulf %a, %b : f32 - %e = arith.addf %c, %d : f32 - linalg.yield %e : f32 - } -> !C_nbm - return %0 : !C_nbm -} - -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.op<"linalg.generic"> - transform.structured.pack_greedily %generic - matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] - : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> -} - -// ----- +// /home/ntv/github/llvm-project/build/bin/mlir-opt /home/ntv/github/llvm-project/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule -split-input-file | /home/ntv/github/llvm-project/build/bin/FileCheck /home/ntv/github/llvm-project/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s // Conv linguo: h w kh kw c n f cc nn ff // Normalized dims are: ( ?, ?, ?, ?, k, m, n)(kk, mm, nn) @@ -194,135 +29,10 @@ ^bb1(%module_op: !transform.any_op): %conv = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %module_op : (!transform.any_op) -> !transform.op<"linalg.conv_2d_nchw_fchw"> - transform.structured.pack_greedily %conv - matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] - : (!transform.op<"linalg.conv_2d_nchw_fchw">) -> !transform.op<"linalg.generic"> -} - -// ----- - -// These should fail to pack for now as they don't contain a contraction. 
-// CHECK-LABEL: @reduce_and_map -func.func @reduce_and_map(%arg0: tensor<10x100xf32>, - %arg1: tensor<10x100xf32>, %output: tensor<10xf32>) -> tensor<10xf32> { - %map_init = tensor.empty() : tensor<10x100xf32> - // CHECK: linalg.map - %mapped = linalg.map { arith.addf } - ins(%arg0, %arg1 : tensor<10x100xf32>, tensor<10x100xf32>) - outs(%map_init : tensor<10x100xf32>) - // CHECK: linalg.reduce - %res = linalg.reduce { arith.addf } - ins(%mapped: tensor<10x100xf32>) - outs(%output: tensor<10xf32>) - dimensions = [1] - return %res : tensor<10xf32> -} - -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.op<"linalg.generic"> - transform.structured.pack_greedily %generic + transform.print %conv : !transform.op<"linalg.conv_2d_nchw_fchw"> + transform.structured.pack_greedily %conv matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] - : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> -} - -// ----- - -!A_mk = tensor<1023x255xf32> -!B_nk = tensor<127x255xf32> -!C_nm = tensor<127x1023xf32> - -#mkn_accesses = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (n, k)>, - affine_map<(m, n, k) -> (n, m)> -] -#mkn_trait = { - indexing_maps = #mkn_accesses, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// Normalized dims are: ( k, m, n)(kk, mm, nn) -// CHECK-DAG: #[[$km_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> -// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)> -// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> - -// CHECK-LABEL: @matmul_mk_nk_nm( -func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm { - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] - // CHECK-SAME: ["reduction", "parallel", 
"parallel", "reduction", "parallel", "parallel"]} - // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<1x8x32x130xf32>) - // CHECK-SAME: outs(%{{.*}} : tensor<1x128x8x130xf32>) - %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) { - ^bb0(%a: f32, %b: f32, %c: f32): - %d = arith.mulf %a, %b : f32 - %e = arith.addf %c, %d : f32 - linalg.yield %e : f32 - } -> !C_nm - return %0 : !C_nm + : (!transform.op<"linalg.conv_2d_nchw_fchw">) -> !transform.any_op } -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.op<"linalg.generic"> - transform.structured.pack_greedily %generic - // In this spec, the "k" dimension is not packed but rather padded to the - // next multiple of 10 (i.e. 130). - matmul_packed_sizes = [8, 0, 32] - matmul_padded_sizes_next_multiple_of = [0, 10, 0] - matmul_inner_dims_order = [1, 2, 0] - : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> -} - - -// ----- - -!A_mk = tensor<1023x255xf32> -!B_nk = tensor<127x255xf32> -!C_nm = tensor<127x1023xf32> - -#mkn_accesses = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (n, k)>, - affine_map<(m, n, k) -> (n, m)> -] -#mkn_trait = { - indexing_maps = #mkn_accesses, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// Normalized dims are: ( k, m, n)(kk, mm) -// CHECK-DAG: #[[$km_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d3)> -// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d0, d3, d4)> -// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1, d4)> - -// CHECK-LABEL: @matmul_mk_nk_nm( -func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm { - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] - // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", 
"parallel"]} - // CHECK-SAME: ins(%{{.*}} : tensor<1023x8x32xf32>, tensor<1x8x32x130xf32>) - // CHECK-SAME: outs(%{{.*}} : tensor<1x1023x130xf32>) - %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) { - ^bb0(%a: f32, %b: f32, %c: f32): - %d = arith.mulf %a, %b : f32 - %e = arith.addf %c, %d : f32 - linalg.yield %e : f32 - } -> !C_nm - return %0 : !C_nm -} - -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.op<"linalg.generic"> - transform.structured.pack_greedily %generic - // In this spec, the "n" dimension is neither packed not unpacked. - // We don't end up with an innermost matmul after packing but only with an - // innermost matvec. - matmul_packed_sizes = [0, 0, 32] - matmul_padded_sizes_next_multiple_of = [0, 10, 0] - matmul_inner_dims_order = [1, 2, 0] - : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> -}