diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -36,12 +36,60 @@ ArrayRef droppedOperands); } // namespace detail +/// Positions of a Linalg op loops that correspond to different kinds of a +/// contraction dimension. +struct ContractionDimensions { + SmallVector batch; + SmallVector m; + SmallVector n; + SmallVector k; +}; + +/// Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates +/// that form a matmul subcomputation within `linalgOp`. +/// These dimensions are such that: +/// 1. The m dimension is involved in an outer-product along LHS +/// (i.e. it is a permutation on RES and LHS and does not appear in RHS). +/// 2. The n dimension is involved in an outer-product along RHS +/// (i.e. it is a permutation on RES and RHS and does not appear in LHS). +/// 3. The k dimension appears as a permutation on LHS and RHS. +/// 4. m, n and k appear only once in any given indexing. +/// 5. Optional batch dimensions that appear in all operands are captured. +/// This allows e.g. detecting that some contraction is embedded within +/// `linalgOp` with some orthogonal heuristic. +/// When multiple dimension occurrences exist that match `batch`, `m`, `n`, or +/// `k`, indices are returned in sorted order. +/// Returns a failure if any of `m`, `n` or `k` is empty. +FailureOr inferContractionDims(LinalgOp linalgOp); + /// Checks whether `linalgOp` conforms to ContractionOpInterface. // TODO: embed within `isa` if possible / natural. bool isaContractionOpInterface(LinalgOp linalgOp); +/// Checks whether `linalgOp` conforms to ConvolutionOpInterface. +// TODO: embed within `isa` if possible / natural. 
+bool isaConvolutionOpInterface(LinalgOp linalgOp); + namespace detail { +/// Result of matching a Linalg generic against the predicates of it being a +/// contraction. +enum class MatchContractionResult; + +/// Checks whether `op` conforms to ContractionOpInterface and populates +/// `dimensions` with indexes of the different kinds of dimensions when +/// present. +// TODO: Extract a standalone `inferConvolutionDims` that can also detect +// whether a conv pattern exists within a bigger linalg op (see +// inferContractionDims). +MatchContractionResult +isContractionInterfaceImpl(Operation *op, + ContractionDimensions *dimensions = nullptr); + +/// Returns the error message corresponding to the contraction checking return +/// code. +StringRef getMatchContractionMessage(MatchContractionResult res); + /// Result of matching a Linalg generic against the predicates of it being a /// convolution. enum class MatchConvolutionResult; @@ -58,7 +106,8 @@ }; /// Checks whether `op` conforms to ConvolutionOpInterface and populates -/// `dimensions` with indexes of the different kinds of dimensions when present. +/// `dimensions` with indexes of the different kinds of dimensions when +/// present. 
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions = nullptr); diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -722,7 +722,7 @@ :$matmul_padded_sizes_next_multiple_of, ConfinedAttr, [DenseArrayCount<3>]>:$matmul_inner_dims_order); - let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op); + let results = (outs TransformHandleTypeInterface:$packed_op); let builders = [ OpBuilder<(ins "Value":$target, diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -35,41 +35,6 @@ // Utilities for inferring various semantics properties of Linalg ops. //===----------------------------------------------------------------------===// -/// Possible dimension candidates that define a contraction embedded in the -/// indexing maps of a LinalgOp. -struct EmbeddedContractionDimsCandidates { - DenseSet batchPos, mPos, nPos, kPos; -}; - -/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the -/// iterators of type `iter` that index the `opOperand` as a permutation. -/// This is useful to infer various subcomputations on a given `linalgOp`. -/// This is performed by looking up each result in the matching indexing map and -/// determining whether: -/// - It is a single AffineDimExpr. -/// - It is the only result involving this AffineDimExpr. -DenseSet findPermutationsIndexingOperand(LinalgOp linalgOp, - OpOperand *opOperand, - utils::IteratorType iter); - -/// Return true if `linalgOp` contains an embedded matmul subcomputation in its -/// most minor dimensions. 
-bool containsMostMinorMatmul(linalg::LinalgOp linalgOp); - -/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form -/// a matmul subcomputation within `linalgOp`. These dimensions are such that: -/// 1. The m dimension is involved in an outer-product along LHS -/// (i.e. it is a permutation on RES and LHS and does not appear in RHS). -/// 2. The n dimension is involved in an outer-product along RHS -/// (i.e. it is a permutation on RES and RHS and does not appear in LHS). -/// 3. The k dimension appears as a permutation on LHS and RHS. -/// 4. m, n and k appear only once in any given indexing. -/// 5. Optional batch dimensions that appear in all operands are captured. -/// This allows e.g. detecting that some contraction is embedded within -/// `linalgOp` with some orthogonal heuristic. -FailureOr -inferContractionDims(linalg::LinalgOp linalgOp); - //===----------------------------------------------------------------------===// // General utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -17,7 +17,10 @@ #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" +#include using namespace mlir; using namespace mlir::linalg; @@ -112,6 +115,96 @@ return success; } +/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the +/// iterators of type `iter` that index the `opOperand` as a permutation. +/// This is useful to infer various subcomputations on a given `linalgOp`. 
+/// This is performed by looking up each result in the matching indexing map and +/// determining whether: +/// - It is a single AffineDimExpr. +/// - It is the only result involving this AffineDimExpr. +static DenseSet +findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand, + utils::IteratorType iter) { + DenseSet res; + assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + for (AffineExpr e : indexingMap.getResults()) { + if (auto d = e.dyn_cast()) { + if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter && + llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { + return e.isFunctionOfDim(d.getPosition()); + }) == 1) + res.insert(d.getPosition()); + } + } + return res; +} + +namespace { +auto par = utils::IteratorType::parallel; +auto red = utils::IteratorType::reduction; +} // namespace + +/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form +/// a matmul subcomputation within `linalgOp`. These dimensions are such that: +/// 1. The m dimension is involved in an outer-product along LHS +/// (i.e. it is a permutation on RES and LHS and does not appear in RHS). +/// 2. The n dimension is involved in an outer-product along RHS +/// (i.e. it is a permutation on RES and RHS and does not appear in LHS). +/// 3. The k dimension appears as a permutation on LHS and RHS. +/// 4. m, n and k appear only once in any given indexing. +/// 5. Optional batch dimensions that appear in all operands are captured. +/// This allows e.g. detecting that some contraction is embedded within +/// `linalgOp` with some orthogonal heuristic. 
+FailureOr +mlir::linalg::inferContractionDims(LinalgOp linalgOp) { + if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) + return failure(); + + DenseSet a = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), par); + DenseSet b = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), par); + DenseSet c = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInitOperand(0), par); + + // A & C - B are the iterators involved in an outer-product along A (the LHS). + DenseSet ac = a; + llvm::set_intersect(ac, c); + llvm::set_subtract(ac, b); + // B & C - A are the iterators involved in an outer-product along B (the RHS). + DenseSet bc = b; + llvm::set_intersect(bc, c); + llvm::set_subtract(bc, a); + // A & B & C are the "batch" dimensions. + DenseSet batches = a; + llvm::set_intersect(batches, b); + llvm::set_intersect(batches, c); + + // A & B red are the reduction dimensions. + DenseSet ra = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), red); + DenseSet rb = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), red); + llvm::set_intersect(ra, rb); + + if (ac.empty() || bc.empty() || ra.empty()) + return failure(); + + // Return each set in sorted order. 
+ ContractionDimensions dimensions{ + SmallVector(batches.begin(), batches.end()), + SmallVector(ac.begin(), ac.end()), + SmallVector(bc.begin(), bc.end()), + SmallVector(ra.begin(), ra.end())}; + std::sort(dimensions.batch.begin(), dimensions.batch.end()); + std::sort(dimensions.m.begin(), dimensions.m.end()); + std::sort(dimensions.n.begin(), dimensions.n.end()); + std::sort(dimensions.k.begin(), dimensions.k.end()); + return dimensions; +} + +namespace mlir::linalg::detail { enum class MatchContractionResult { Success = 0, NotLinalgOp, @@ -120,7 +213,11 @@ NotProjectedPermutations, NotAddMul }; -static MatchContractionResult isContractionInterfaceImpl(Operation *op) { +} // namespace mlir::linalg::detail + +mlir::linalg::detail::MatchContractionResult +mlir::linalg::detail::isContractionInterfaceImpl( + Operation *op, mlir::linalg::ContractionDimensions *dimensions) { auto linalgOp = dyn_cast(op); if (!linalgOp) return MatchContractionResult::NotLinalgOp; @@ -139,15 +236,41 @@ linalgOp->getRegion(0).front()) && !isAddMul(linalgOp->getRegion(0).front())) return MatchContractionResult::NotAddMul; + + if (dimensions) { + FailureOr res = inferContractionDims(linalgOp); + assert(succeeded(res) && "unexpected failure to infer contraction dims"); + *dimensions = *res; + } return MatchContractionResult::Success; } +StringRef +mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) { + switch (res) { + case MatchContractionResult::NotLinalgOp: + return "expected a LinalgOp"; + case MatchContractionResult::WrongNumOperands: + return "expected op with 2 inputs and 1 output"; + case MatchContractionResult::NoReduction: + return "expected at least 1 reduction"; + case MatchContractionResult::NotProjectedPermutations: + return "expected indexing maps to be projected permutations"; + case MatchContractionResult::NotAddMul: + return "expected add/mul op in the body"; + case MatchContractionResult::Success: + return ""; + } + llvm_unreachable("unhandled 
MatchContractionResult case"); +} + bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) { if (!linalgOp) return false; Operation *op = linalgOp.getOperation(); return isa(op) || - (isContractionInterfaceImpl(op) == MatchContractionResult::Success); + (mlir::linalg::detail::isContractionInterfaceImpl(op) == + mlir::linalg::detail::MatchContractionResult::Success); } /// Verify that a LinalgOp `op` is a contraction. @@ -165,16 +288,8 @@ /// constant operations that do not involve the reduction dimension(s). LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) { auto res = isContractionInterfaceImpl(op); - if (res == MatchContractionResult::NotLinalgOp) - return op->emitError("expected a LinalgOp"); - if (res == MatchContractionResult::WrongNumOperands) - return op->emitError("expected op with 2 inputs and 1 outputs"); - if (res == MatchContractionResult::NoReduction) - return op->emitError("expected at least a reduction loop"); - if (res == MatchContractionResult::NotProjectedPermutations) - return op->emitError("expected all indexings to be projected permutations"); - if (res == MatchContractionResult::NotAddMul) - return op->emitError("(add, mul) operations not found"); + if (res != MatchContractionResult::Success) + return op->emitError(getMatchContractionMessage(res)); return success(); } @@ -454,6 +569,11 @@ llvm_unreachable("unhandled MatchConvolutionResult case"); } +bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) { + return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) == + linalg::detail::MatchConvolutionResult::Success; +} + LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { MatchConvolutionResult res = isConvolutionInterfaceImpl(op); if (res != MatchConvolutionResult::Success) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- 
a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -49,6 +49,7 @@ #define DEBUG_TYPE "linalg-transforms" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define DBGSNL() (llvm::dbgs() << "\n") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` @@ -1227,6 +1228,8 @@ int64_t numLoops = linalgOp.getNumLoops(); if (numLoops <= 2) { + LDBG("need 3+ loops to find a matmul to pack, got " + << numLoops << "\nin: " << linalgOp << "\n"); return rewriter.notifyMatchFailure( linalgOp, "need 3+ loops to find a matmul to pack"); } @@ -1247,17 +1250,21 @@ } // 1. Infer dims that are important for matmul. - FailureOr res = inferContractionDims(linalgOp); - if (failed(res)) { + FailureOr maybeDimensions = + inferContractionDims(linalgOp); + if (failed(maybeDimensions)) { + LDBG("couldn't infer matmul iterators in: " << linalgOp << "\n"); return rewriter.notifyMatchFailure(linalgOp, "couldn't infer matmul iterators"); } // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most - // minor iterators. If we wanted a different normalization order, this is - // where it would have to plug a heuristic. - int64_t mPos = *(res->mPos.begin()), nPos = *(res->nPos.begin()), - kPos = *(res->kPos.begin()); + // minor iterators. In cases with multiple options for m, n, k bias towards + // the most minor embedding. + // If we wanted a different normalization order, this is where it would have + // to plug a heuristic. 
+ int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(), + kPos = maybeDimensions->k.back(); LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); DBGS() << "Start packing generic op greedily with (m@" << mPos << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp @@ -2655,71 +2662,71 @@ ArrayRef mixedTileSizes, std::optional mapping, linalg::ForallTilingResult &tilingResult) { // Transform all targets one by one. - auto tileableOp = dyn_cast(target); - if (!tileableOp) { - DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() - << "only TilingInterface ops are supported"; - diag.attachNote(target->getLoc()) << "target op"; - return diag; - } - rewriter.setInsertionPoint(tileableOp); - FailureOr maybeTilingResult = failure(); - if (!mixedNumThreads.empty()) { - maybeTilingResult = linalg::tileToForallOp(rewriter, tileableOp, - mixedNumThreads, mapping); - } else { - maybeTilingResult = linalg::tileToForallOpUsingTileSizes( - rewriter, tileableOp, mixedTileSizes, mapping); - } + auto tileableOp = dyn_cast(target); + if (!tileableOp) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "only TilingInterface ops are supported"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + rewriter.setInsertionPoint(tileableOp); + FailureOr maybeTilingResult = failure(); + if (!mixedNumThreads.empty()) { + maybeTilingResult = + linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping); + } else { + maybeTilingResult = linalg::tileToForallOpUsingTileSizes( + rewriter, tileableOp, mixedTileSizes, mapping); + } - if (failed(maybeTilingResult)) - return transformOp.emitDefaultSilenceableFailure(tileableOp); - rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults()); + if (failed(maybeTilingResult)) + return transformOp.emitDefaultSilenceableFailure(tileableOp); + rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults()); - tilingResult = *maybeTilingResult; - return 
DiagnosedSilenceableFailure::success(); + tilingResult = *maybeTilingResult; + return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::TileToForallOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { - auto transformOp = cast(getOperation()); - - // Result payload ops. - SmallVector tileOps; - SmallVector tiledOps; - - // Unpack handles. - SmallVector mixedNumThreads; - DiagnosedSilenceableFailure status = - getPackedNumThreads() - ? unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getPackedNumThreads()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getMixedNumThreads()); - if (!status.succeeded()) - return status; - SmallVector mixedTileSizes; - status = getPackedTileSizes() - ? unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getPackedTileSizes()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getMixedTileSizes()); - if (!status.succeeded()) - return status; - - for (Operation *target : state.getPayloadOps(getTarget())) { - linalg::ForallTilingResult tilingResult; - DiagnosedSilenceableFailure diag = tileToForallOpImpl( - rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, - getMapping(), tilingResult); - if (!diag.succeeded()) + auto transformOp = cast(getOperation()); + + // Result payload ops. + SmallVector tileOps; + SmallVector tiledOps; + + // Unpack handles. + SmallVector mixedNumThreads; + DiagnosedSilenceableFailure status = + getPackedNumThreads() + ? unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getPackedNumThreads()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getMixedNumThreads()); + if (!status.succeeded()) + return status; + SmallVector mixedTileSizes; + status = getPackedTileSizes() + ? 
unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getPackedTileSizes()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getMixedTileSizes()); + if (!status.succeeded()) + return status; + + for (Operation *target : state.getPayloadOps(getTarget())) { + linalg::ForallTilingResult tilingResult; + DiagnosedSilenceableFailure diag = tileToForallOpImpl( + rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, + getMapping(), tilingResult); + if (!diag.succeeded()) return diag; tileOps.push_back(tilingResult.tileOp); tiledOps.push_back(tilingResult.tiledOp); - } + } transformResults.set(cast(getForallOp()), tileOps); transformResults.set(cast(getTiledOp()), tiledOps); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -33,7 +33,6 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Pass/Pass.h" -#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -140,86 +139,6 @@ } } -//===----------------------------------------------------------------------===// -// Utilities for inferring various semantics properties of Linalg ops. 
-//===----------------------------------------------------------------------===// - -DenseSet mlir::linalg::findPermutationsIndexingOperand( - LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) { - DenseSet res; - assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); - for (AffineExpr e : indexingMap.getResults()) { - if (auto d = e.dyn_cast()) { - if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter && - llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { - return e.isFunctionOfDim(d.getPosition()); - }) == 1) - res.insert(d.getPosition()); - } - } - return res; -} - -namespace { -auto par = utils::IteratorType::parallel; -auto red = utils::IteratorType::reduction; -} // namespace - -bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) { - FailureOr res = inferContractionDims(linalgOp); - if (failed(res)) - return false; - int64_t numLoops = linalgOp.getNumLoops(); - for (const DenseSet &s : {res->mPos, res->nPos, res->kPos}) { - if (s.contains(numLoops - 3) || s.contains(numLoops - 2) || - s.contains(numLoops - 1)) - continue; - return false; - } - return true; -} - -FailureOr -mlir::linalg::inferContractionDims(LinalgOp linalgOp) { - if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) - return failure(); - - DenseSet a = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(0), par); - DenseSet b = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(1), par); - DenseSet c = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInitOperand(0), par); - - // A & C - B are the iterators involved in an outer-product along A (the LHS). - DenseSet ac = a; - llvm::set_intersect(ac, c); - llvm::set_subtract(ac, b); - // B & C - A are the iterators involved in an outer-product along B (the RHS). 
- DenseSet bc = b; - llvm::set_intersect(bc, c); - llvm::set_subtract(bc, a); - // A & B & C are the "batch" dimensions. - DenseSet batches = a; - llvm::set_intersect(batches, b); - llvm::set_intersect(batches, c); - - // A & B red are the reduction dimensions. - DenseSet ra = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(0), red); - DenseSet rb = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(1), red); - llvm::set_intersect(ra, rb); - - if (ac.empty() || bc.empty() || ra.empty()) - return failure(); - - // Pick the first one in each set. - // TODO: Better heuristic (e.g pick dims based on packing-based metric). - return EmbeddedContractionDimsCandidates{batches, ac, bc, ra}; -} - //===----------------------------------------------------------------------===// // General utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -59,7 +59,9 @@ for (Operation *op : payload) { if (opName != op->getName()) { DiagnosedSilenceableFailure diag = - emitSilenceableError(loc) << "incompatible payload operation name"; + emitSilenceableError(loc) + << "incompatible payload operation name expected " << opName << " vs " + << op->getName() << " -> " << *op; diag.attachNote(op->getLoc()) << "payload operation"; return diag; } diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir --- a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir +++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir @@ -326,3 +326,25 @@ matmul_inner_dims_order = [1, 2, 0] : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> } + +// ----- + +!A = tensor<1023x255xf32> +!X = tensor<255xf32> +!Y = 
tensor<1023xf32> + +// CHECK-LABEL: @matvec_fail( +func.func @matvec_fail(%A : !A, %x : !X, %y : !Y) -> !Y { + // CHECK: linalg.matvec + %0 = linalg.matvec ins(%A, %x : !A, !X) outs(%y : !Y) -> !Y + return %0 : !Y +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + %matmul = transform.structured.match ops{["linalg.matvec"]} in %module_op + : (!transform.any_op) -> !transform.op<"linalg.matvec"> + transform.structured.pack_greedily %matmul + matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.matvec">) -> !transform.any_op +}