diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -67,6 +67,45 @@
 // TODO: embed within `isa` if possible / natural.
 bool isaContractionOpInterface(LinalgOp linalgOp);
 
+/// Positions of a Linalg op loops that correspond to different kinds of a
+/// convolution dimension.
+struct ConvolutionDimensions {
+  SmallVector<unsigned, 2> batch;
+  SmallVector<unsigned, 2> outputImage;
+  SmallVector<unsigned, 2> outputChannel;
+  SmallVector<unsigned, 2> filterLoop;
+  SmallVector<unsigned, 2> inputChannel;
+  SmallVector<unsigned, 2> depth;
+  SmallVector<int64_t, 2> strides;
+  SmallVector<int64_t, 2> dilations;
+};
+
+/// Find at least 1 parallel (output_image) and reduction (filter_loop)
+/// dimension candidates that form a convolution subcomputation within
+/// `linalgOp`. The LHS is assumed to be the convolution input while the
+/// RHS is assumed as the filter.
+/// These dimensions are such that:
+///   1. Optional batch dimensions that appear in the input and filter.
+///   2. The output_image dimension is involved in a cross-correlation along LHS
+///      (i.e. it is a permutation on RES and LHS and has an associated
+///      filter_loop in RHS).
+///   3. Optional output_channel dimension is involved in an outer-product along
+///      RHS (i.e. it is a permutation on RES and RHS and does not appear in
+///      LHS).
+///   4. Optional input_channel dimension appears as a permutation on LHS and
+///      RHS.
+///   5. The filter_loop dimension appears as a permutation on the RHS and
+///      represents the shape of the kernel cross-correlated along a
+///      corresponding output_image dim.
+///   6. The input_channel dimension appears as a permutation on LHS and RHS.
+///   7. All dimensions appear only once in any given indexing map.
+/// This allows e.g. detecting that some convolution is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+/// When multiple dimension occurrences exist that match any classification,
+/// indices are returned in sorted order.
+/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
+FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
+
 /// Checks whether `linalgOp` conforms to ConvolutionOpInterface.
 // TODO: embed within `isa` if possible / natural.
 bool isaConvolutionOpInterface(LinalgOp linalgOp);
@@ -100,9 +139,6 @@
 /// Checks whether `op` conforms to ContractionOpInterface and populates
 /// `dimensions` with indexes of the different kinds of dimensions when
 /// present.
-// TODO: Extract a standalone `inferConvolutionDims` that can also detect
-// whether a conv pattern exists within a bigger linalg op (see
-// inferContractionDims).
 MatchContractionResult
 isContractionInterfaceImpl(Operation *op,
                            ContractionDimensions *dimensions = nullptr);
@@ -115,17 +151,6 @@
 /// convolution.
 enum class MatchConvolutionResult;
 
-/// Positions of a Linalg op loops that correspond to different kinds of a
-/// convolution dimension.
-struct ConvolutionDimensions {
-  SmallVector<unsigned, 2> batch;
-  SmallVector<unsigned, 2> outputImage;
-  SmallVector<unsigned, 2> outputChannel;
-  SmallVector<unsigned, 2> filterLoop;
-  SmallVector<unsigned, 2> inputChannel;
-  SmallVector<unsigned, 2> depth;
-};
-
 /// Checks whether `op` conforms to ConvolutionOpInterface and populates
 /// `dimensions` with indexes of the different kinds of dimensions when
 /// present.
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -170,6 +170,57 @@
   let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
 }
 
+def MatchStructuredClassifyConvolutionDimsOp
+    : Op<Transform_Dialect, "match.structured.classify_convolution_dims",
+         [SingleOpMatcher,
+          MatchOpInterface,
+          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary =
+      "Checks if an operation has convolution-like dimensions and returns them";
+  let description = !strconcat([{
+    Checks if the structured payload op has convolution-like dimensions as
+    follows:
+
+      C(batch, depth, oi, oc) += A(batch, depth, oi, ic) * B(fl, depth, ic, oc)
+
+    That is:
+
+      - 'batch' are parallel dimensions used in the input and result;
+      - 'output_image' ('oi') are parallel dimensions used in the input and
+        result;
+      - 'output_channel' ('oc') are parallel dimensions used in the filter
+        and result;
+      - 'filter_loop' ('fl') are reduction dimensions representing the
+        dimensions of the sliding window;
+      - 'input_channel' ('ic') are reduction dimensions present only in the
+        input and filter.
+      - 'depth' are parallel dimensions present in the input, filter, and
+        output.
+
+    Additionally this will match stride and dilation information for the
+    convolution:
+      - 'strides' are the static strides per convolution window dimension;
+      - 'dilations' are the static dilations per convolution window dimension.
+
+    Note that this doesn't check the operation in the body.
+
+  }], StructuredPredicate.extraDescription, [{
+
+    #### Return modes
+
+    Succeeds if the operation has the convolution-like dimensions, produces a
+    silenceable failure otherwise.
+  }]);
+
+  let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+  let results = (outs TransformParamTypeInterface:$batch,
+                      TransformParamTypeInterface:$output_image,
+                      TransformParamTypeInterface:$output_channel,
+                      TransformParamTypeInterface:$filter_loop,
+                      TransformParamTypeInterface:$input_channel,
+                      TransformParamTypeInterface:$depth,
+                      TransformParamTypeInterface:$strides,
+                      TransformParamTypeInterface:$dilations);
+  let assemblyFormat =
+      "$operand_handle attr-dict `:` functional-type(operands, results)";
+  let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+}
+
 class StructuredDimDescription<string kind> {
   string description = !strconcat([{
     The following }], kind ,[{ specifications are supported:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -161,10 +161,10 @@
 /// determining whether:
 ///   - It is a single AffineDimExpr.
 ///   - It is the only result involving this AffineDimExpr.
-static DenseSet +static llvm::SmallDenseSet findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) { - DenseSet res; + llvm::SmallDenseSet res; assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); for (AffineExpr e : indexingMap.getResults()) { @@ -200,30 +200,30 @@ if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) return failure(); - DenseSet a = findPermutationsIndexingOperand( + llvm::SmallDenseSet a = findPermutationsIndexingOperand( linalgOp, linalgOp.getDpsInputOperand(0), par); - DenseSet b = findPermutationsIndexingOperand( + llvm::SmallDenseSet b = findPermutationsIndexingOperand( linalgOp, linalgOp.getDpsInputOperand(1), par); - DenseSet c = findPermutationsIndexingOperand( + llvm::SmallDenseSet c = findPermutationsIndexingOperand( linalgOp, linalgOp.getDpsInitOperand(0), par); // A & C - B are the iterators involved in an outer-product along A (the LHS). - DenseSet ac = a; + llvm::SmallDenseSet ac = a; llvm::set_intersect(ac, c); llvm::set_subtract(ac, b); // B & C - A are the iterators involved in an outer-product along B (the RHS). - DenseSet bc = b; + llvm::SmallDenseSet bc = b; llvm::set_intersect(bc, c); llvm::set_subtract(bc, a); // A & B & C are the "batch" dimensions. - DenseSet batches = a; + llvm::SmallDenseSet batches = a; llvm::set_intersect(batches, b); llvm::set_intersect(batches, c); // A & B red are the reduction dimensions. 
- DenseSet ra = findPermutationsIndexingOperand( + llvm::SmallDenseSet ra = findPermutationsIndexingOperand( linalgOp, linalgOp.getDpsInputOperand(0), red); - DenseSet rb = findPermutationsIndexingOperand( + llvm::SmallDenseSet rb = findPermutationsIndexingOperand( linalgOp, linalgOp.getDpsInputOperand(1), red); llvm::set_intersect(ra, rb); @@ -236,10 +236,10 @@ SmallVector(ac.begin(), ac.end()), SmallVector(bc.begin(), bc.end()), SmallVector(ra.begin(), ra.end())}; - std::sort(dimensions.batch.begin(), dimensions.batch.end()); - std::sort(dimensions.m.begin(), dimensions.m.end()); - std::sort(dimensions.n.begin(), dimensions.n.end()); - std::sort(dimensions.k.begin(), dimensions.k.end()); + llvm::sort(dimensions.batch.begin(), dimensions.batch.end()); + llvm::sort(dimensions.m.begin(), dimensions.m.end()); + llvm::sort(dimensions.n.begin(), dimensions.n.end()); + llvm::sort(dimensions.k.begin(), dimensions.k.end()); return dimensions; } @@ -359,8 +359,37 @@ /// dimensions and verifies each dimension occurs only once. struct ConvAccessExprWalker : public AffineExprVisitor { - llvm::SmallDenseSet convolvedDims; - llvm::SmallDenseSet unConvolvedDims; + // Stores dimensions used in expressions of the above form. + llvm::SmallDenseSet convolvedDims; + // Stores the dual mapping between LHS and RHS of convolution exprs. + llvm::SmallDenseMap convolvedDimMapping; + // Stores single use dimensions used by an AffineDimExpr. + llvm::SmallDenseSet unConvolvedDims; + // Stores a mapping from convolved dims to their coefficient. + llvm::SmallDenseMap strideAndDilationMapping; + + // Removes dims with multiple uses in the source input map from dimension + // sets tracked by this walker. 
+ void clearMultiUseDims(AffineMap map) { + for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) { + if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) { + return e.isFunctionOfDim(dimPos); + }) > 1) { + convolvedDims.erase(dimPos); + unConvolvedDims.erase(dimPos); + // If a duplicate dim is marked as convolved, the pair of the duplicate + // dim must be removed from the map as well. + if (convolvedDimMapping.contains(dimPos)) { + int64_t pairedDim = convolvedDimMapping[dimPos]; + convolvedDims.erase(pairedDim); + unConvolvedDims.erase(pairedDim); + strideAndDilationMapping.erase(pairedDim); + convolvedDimMapping.erase(dimPos); + convolvedDimMapping.erase(pairedDim); + } + } + } + } LogicalResult visitDimExpr(AffineDimExpr dimExpr) { unsigned position = dimExpr.getPosition(); @@ -379,17 +408,25 @@ // In pre-order visit, top level op has to be an add op. if (binaryExpr.getKind() != AffineExprKind::Add) return failure(); - return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) && - succeeded(isDimExprOrMulExpr(binaryExpr.getRHS()))); + auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS()); + auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS()); + if (failed(lhsDimPos) || failed(rhsDimPos)) + return failure(); + convolvedDimMapping[*lhsDimPos] = *rhsDimPos; + convolvedDimMapping[*rhsDimPos] = *lhsDimPos; + return success(); } - LogicalResult isDimExprOrMulExpr(AffineExpr expr) { + FailureOr getDimExprOrMulExprDimPos(AffineExpr expr) { if (auto dimExpr = expr.dyn_cast()) { - unsigned dim = dimExpr.getPosition(); + int64_t dim = dimExpr.getPosition(); if (convolvedDims.count(dim) || unConvolvedDims.count(dim)) return failure(); + // Stride/dilation for this dim is implicitly 1. 
+ strideAndDilationMapping[dim] = + getAffineConstantExpr(1, expr.getContext()); convolvedDims.insert(dim); - return success(); + return dim; } if (auto symbolMulExpr = expr.dyn_cast()) { if (symbolMulExpr.getKind() != AffineExprKind::Mul) @@ -406,26 +443,170 @@ auto dimExpr = getAffineExprOfType(lhsExpr, rhsExpr); if (!mulExpr || !dimExpr) return failure(); - unsigned dim = dimExpr.getPosition(); + int64_t dim = dimExpr.getPosition(); if (convolvedDims.count(dim) || unConvolvedDims.count(dim)) return failure(); + strideAndDilationMapping[dim] = mulExpr; convolvedDims.insert(dim); - return success(); + return dim; } return failure(); } }; } // namespace -static llvm::SmallDenseSet getPreservedDims(AffineMap map) { +static llvm::SmallDenseSet getPreservedDims(AffineMap map) { assert(map.isProjectedPermutation() && "expected map to have projected permutations"); - llvm::SmallDenseSet preservedDims; + llvm::SmallDenseSet preservedDims; for (auto expr : map.getResults()) preservedDims.insert(expr.cast().getPosition()); return preservedDims; } +static SmallVector +getConstantsFromExprList(SmallVector exprs) { + SmallVector vals; + for (auto e : exprs) { + auto constantExpr = e.dyn_cast(); + assert(constantExpr && "Found non-constant stride/dilation"); + vals.push_back(constantExpr.getValue()); + } + return vals; +} + +/// Classifies dimensions in the `linalgOp` used by a convolution +/// subcomputation, as captured by `inputExprWalker`. If +/// `allowEmptyConvolvedDims` is not set this this will fail if there is not +/// at least convolved dimension pair (output image + filter loop). Convolution +/// dimensions are specified in sorted order, and strides match the order of +/// the filter loop dimensions, while the dilations match the order of the +/// output image dimensions. 
+static FailureOr +inferConvolutionDimsImpl(LinalgOp linalgOp, + ConvAccessExprWalker &inputExprWalker, + bool allowEmptyConvolvedDims) { + llvm::SmallDenseSet filterDims = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), par); + llvm::SmallDenseSet outputDims = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInitOperand(0), par); + + // unConvolvedDims & outputDims - filterDims are the batch iterators. + llvm::SmallDenseSet batch = inputExprWalker.unConvolvedDims; + llvm::set_intersect(batch, outputDims); + llvm::set_subtract(batch, filterDims); + + // convolvedDims & outputDims are the output image iterators. + llvm::SmallDenseSet oi = inputExprWalker.convolvedDims; + llvm::set_intersect(oi, outputDims); + + // filterDims & outputDims - unConvolvedDims are the output channel iterators. + llvm::SmallDenseSet oc = filterDims; + llvm::set_intersect(oc, outputDims); + llvm::set_subtract(oc, inputExprWalker.unConvolvedDims); + + // filterDims & outputDims & unConvolvedDims are the depth iterators. + llvm::SmallDenseSet depth = filterDims; + llvm::set_intersect(depth, outputDims); + llvm::set_intersect(depth, inputExprWalker.unConvolvedDims); + + llvm::SmallDenseSet filterReducedDims = + findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1), + red); + + // convolvedDims & filterReducedDims are the filter loop iterators. + llvm::SmallDenseSet fl = inputExprWalker.convolvedDims; + llvm::set_intersect(fl, filterReducedDims); + + // unConvolvedDims & filterReducedDims are the input channel iterators. + llvm::SmallDenseSet ic = inputExprWalker.unConvolvedDims; + llvm::set_intersect(ic, filterReducedDims); + + if (oi.empty() && !allowEmptyConvolvedDims) + return failure(); + + // Return each set in sorted order. 
+ ConvolutionDimensions dimensions{ + SmallVector(batch.begin(), batch.end()), + SmallVector(oi.begin(), oi.end()), + SmallVector(oc.begin(), oc.end()), + SmallVector(fl.begin(), fl.end()), + SmallVector(ic.begin(), ic.end()), + SmallVector(depth.begin(), depth.end()), + /*strides=*/SmallVector{}, + /*dilations=*/SmallVector{}}; + llvm::sort(dimensions.batch.begin(), dimensions.batch.end()); + llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end()); + llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end()); + llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end()); + llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end()); + llvm::sort(dimensions.depth.begin(), dimensions.depth.end()); + + // Use the op carried strides/dilations attribute if present. + auto nativeStrides = linalgOp->getAttrOfType("strides"); + if (!nativeStrides) { + SmallVector strideExprs; + for (unsigned oiDim : dimensions.outputImage) + strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]); + dimensions.strides = getConstantsFromExprList(strideExprs); + } else { + dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues()); + } + auto nativeDilations = + linalgOp->getAttrOfType("dilations"); + if (!nativeDilations) { + SmallVector dilationExprs; + for (unsigned flDim : dimensions.filterLoop) + dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]); + dimensions.dilations = getConstantsFromExprList(dilationExprs); + } else { + dimensions.dilations = + llvm::to_vector<2>(nativeDilations.getValues()); + } + return dimensions; +} + +/// Find at least 1 parallel (output_image) and reduction (filter_loop) +/// dimension candidates that form a convolution subcomputation within +/// `linalgOp`. The LHS is assumed to be the convolution input while the +/// RHS is assumed as the filter. +/// These dimensions are such that: +/// 1. Optional batch dimensions that appear in the input and filter. 
+/// 2. The output_image dimension is involved in a cross-correlation along LHS +/// (i.e. it is a permutation on RES and LHS and has an associated +/// filter_loop in RHS). +/// 3. Optional output_channel dimension is involved in an outer-product along +/// RHS (i.e. it is a permutation on RES and RHS and does not appear in +/// LHS). +/// 4. Optional input_channel dimension appears as a permutation on LHS and +/// RHS. +/// 5. The filter_loop dimension appears as a permutation on the RHS and +/// represents the shape of the kernel cross-correlated along a +/// corresponding output_image dim. +/// 6. The input_channel dimension appears as a permutation on LHS and RHS. +/// 7. All dimensions appear only once in any given indexing map. +/// This allows e.g. detecting that some convolution is embedded within +/// `linalgOp` with some orthogonal heuristic. +/// When multiple dimension occurrences exist that match any classification +/// indices are returned in sorted order. +/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty. +FailureOr +mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) { + if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) + return failure(); + + auto indexingMaps = linalgOp.getIndexingMapsArray(); + + // Check the input indexing map has the right form. 
+ ConvAccessExprWalker inputExprWalker; + for (AffineExpr expr : indexingMaps[0].getResults()) + (void)inputExprWalker.visit(expr); + inputExprWalker.clearMultiUseDims(indexingMaps[0]); + + return inferConvolutionDimsImpl(linalgOp, inputExprWalker, + /*allowEmptyConvolvedDims=*/false); +} + namespace mlir::linalg::detail { enum class MatchConvolutionResult { Success = 0, @@ -466,9 +647,9 @@ auto iteratorTypes = linalgOp.getIteratorTypesArray(); - llvm::SmallDenseSet outputDims = + llvm::SmallDenseSet outputDims = getPreservedDims(indexingMaps.back()); - llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); // Make sure all loops are characterized as one of: // - Batch loop : present in output, as non-convolved in input, not present in // filter. @@ -482,17 +663,15 @@ // present in filter. // - Depth multiplier : unconvolved in input, present in output, present in // filter. - llvm::SmallDenseSet allLoopDims; + llvm::SmallDenseSet allLoopDims; for (auto outputExpr : indexingMaps.back().getResults()) { - unsigned outputDim = outputExpr.cast().getPosition(); + int64_t outputDim = outputExpr.cast().getPosition(); if (inputExprWalker.unConvolvedDims.count(outputDim) && !filterDims.count(outputDim)) { // Batch dimension. 
if (iteratorTypes[outputDim] != utils::IteratorType::parallel) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); - if (dimensions) - dimensions->batch.push_back(outputDim); continue; } if (inputExprWalker.convolvedDims.count(outputDim) && @@ -501,8 +680,6 @@ if (iteratorTypes[outputDim] != utils::IteratorType::parallel) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); - if (dimensions) - dimensions->outputImage.push_back(outputDim); continue; } if (!inputExprWalker.convolvedDims.count(outputDim) && @@ -512,8 +689,6 @@ if (iteratorTypes[outputDim] != utils::IteratorType::parallel) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); - if (dimensions) - dimensions->outputChannel.push_back(outputDim); continue; } if (inputExprWalker.unConvolvedDims.count(outputDim) && @@ -522,21 +697,16 @@ if (iteratorTypes[outputDim] != utils::IteratorType::parallel) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); - if (dimensions) - dimensions->depth.push_back(outputDim); continue; } return MatchConvolutionResult::NonConvolutionLoop; } for (auto filterExpr : indexingMaps[1].getResults()) { - unsigned filterDim = filterExpr.cast().getPosition(); + int64_t filterDim = filterExpr.cast().getPosition(); if (outputDims.count(filterDim) && !inputExprWalker.unConvolvedDims.count(filterDim) && !inputExprWalker.convolvedDims.count(filterDim)) { // Output channel dimension. 
This is already seen, continue; - assert((!dimensions || - llvm::is_contained(dimensions->outputChannel, filterDim)) && - "expected output channel to have been found from output dims"); continue; } if (inputExprWalker.convolvedDims.count(filterDim) && @@ -547,8 +717,6 @@ if (allLoopDims.count(filterDim)) return MatchConvolutionResult::NonConvolutionLoop; allLoopDims.insert(filterDim); - if (dimensions) - dimensions->filterLoop.push_back(filterDim); continue; } if (inputExprWalker.unConvolvedDims.count(filterDim) && @@ -559,16 +727,11 @@ if (allLoopDims.count(filterDim)) return MatchConvolutionResult::NonConvolutionLoop; allLoopDims.insert(filterDim); - if (dimensions) - dimensions->inputChannel.push_back(filterDim); continue; } if (inputExprWalker.unConvolvedDims.count(filterDim) && outputDims.count(filterDim)) { // Depthwise loop. Already seen. - assert( - (!dimensions || llvm::is_contained(dimensions->depth, filterDim)) && - "expected depthwise dimension to have been found from output dims"); continue; } return MatchConvolutionResult::NonConvolutionLoop; @@ -578,12 +741,11 @@ return MatchConvolutionResult::NonConvolutionLoop; if (dimensions) { - assert(dimensions->batch.size() + dimensions->outputImage.size() + - dimensions->outputChannel.size() + - dimensions->filterLoop.size() + - dimensions->inputChannel.size() + dimensions->depth.size() == - linalgOp.getNumLoops() && - "expected all loops to be classified"); + FailureOr res = + inferConvolutionDimsImpl(linalgOp, inputExprWalker, + /*allowEmptyConvolvedDims=*/true); + assert(succeeded(res) && "unexpected failure to infer convolution dims"); + *dimensions = *res; } return MatchConvolutionResult::Success; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -259,6 +259,53 @@ return 
DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// MatchStructuredClassifyConvolutionDimsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation( + Operation *current, transform::TransformResults &results, + transform::TransformState &state) { + FailureOr convolutionDims = + linalg::inferConvolutionDims(cast(current)); + if (failed(convolutionDims)) + return emitSilenceableError() << "could not infer convolution dimensions"; + + MLIRContext *context = current->getContext(); + Builder builder(context); + auto makeI64Attrs = [&](ArrayRef values) { + return llvm::to_vector( + llvm::map_range(values, [&](unsigned value) -> Attribute { + return builder.getI64IntegerAttr(value); + })); + }; + results.setParams(getBatch().cast(), + makeI64Attrs(convolutionDims->batch)); + results.setParams(getOutputImage().cast(), + makeI64Attrs(convolutionDims->outputImage)); + results.setParams(getOutputChannel().cast(), + makeI64Attrs(convolutionDims->outputChannel)); + results.setParams(getFilterLoop().cast(), + makeI64Attrs(convolutionDims->filterLoop)); + results.setParams(getInputChannel().cast(), + makeI64Attrs(convolutionDims->inputChannel)); + results.setParams(getDepth().cast(), + makeI64Attrs(convolutionDims->depth)); + + auto makeI64AttrsFromI64 = [&](ArrayRef values) { + return llvm::to_vector( + llvm::map_range(values, [&](int64_t value) -> Attribute { + return builder.getI64IntegerAttr(value); + })); + }; + results.setParams(getStrides().cast(), + makeI64AttrsFromI64(convolutionDims->strides)); + results.setParams(getDilations().cast(), + makeI64AttrsFromI64(convolutionDims->dilations)); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Utilities for structured match predicates. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -934,3 +934,97 @@ return %result : tensor<40x10x50x15xf32> } } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_convolution(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param) { + %1:8 = transform.match.structured %arg0 : (!transform.any_op) -> (!transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param) { + ^bb0(%struct: !transform.any_op): + transform.match.structured.body %struct { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op + %0:8 = transform.match.structured.classify_convolution_dims %struct + : (!transform.any_op) -> (!transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param) + transform.match.structured.yield %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 + : !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param + } + transform.yield %arg0, %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7 : !transform.any_op, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param + } + + transform.named_sequence @print_convolution( + %op: !transform.any_op {transform.readonly}, + %batch: !transform.param {transform.readonly}, + %oi: !transform.param {transform.readonly}, + %oc: 
!transform.param {transform.readonly}, + %fl: !transform.param {transform.readonly}, + %ic: !transform.param {transform.readonly}, + %depth: !transform.param {transform.readonly}, + %strides: !transform.param {transform.readonly}, + %dilations: !transform.param {transform.readonly}) { + transform.test_print_remark_at_operand %op, "convolution" : !transform.any_op + transform.test_print_param %batch, "batch dims" at %op : !transform.param, !transform.any_op + transform.test_print_param %oi, "output image dims" at %op : !transform.param, !transform.any_op + transform.test_print_param %oc, "output channel dims" at %op : !transform.param, !transform.any_op + transform.test_print_param %fl, "filter loop dims" at %op : !transform.param, !transform.any_op + transform.test_print_param %ic, "input channel dims" at %op : !transform.param, !transform.any_op + transform.test_print_param %depth, "depth dims" at %op : !transform.param, !transform.any_op + transform.test_print_param %strides, "strides" at %op : !transform.param, !transform.any_op + transform.test_print_param %dilations, "dilations" at %op : !transform.param, !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + %3 = transform.foreach_match in %arg0 @match_convolution -> @print_convolution : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +module attributes { transform.target_tag = "start_here" } { + func.func @convolution_simple(%input: tensor<10x20x30xf32>, %filter: tensor<3x30x15xf32>) -> tensor<10x18x15xf64> { + %cst = arith.constant 0.0 : f64 + %empty = tensor.empty() : tensor<10x18x15xf64> + %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x18x15xf64>) -> tensor<10x18x15xf64> + // expected-remark @below {{convolution}} + // expected-remark @below {{batch dims 0}} + // expected-remark @below {{output image dims 1}} + // expected-remark @below {{output channel dims 
2}} + // expected-remark @below {{filter loop dims 3}} + // expected-remark @below {{input channel dims 4}} + // expected-remark @below {{depth dims}} + // expected-remark @below {{strides 1}} + // expected-remark @below {{dilations 1}} + %result = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>, + strides = dense<1> : tensor<1xi64>} + ins(%input, %filter: tensor<10x20x30xf32>, tensor<3x30x15xf32>) outs(%fill: tensor<10x18x15xf64>) -> tensor<10x18x15xf64> + return %result : tensor<10x18x15xf64> + } + + func.func @convolution_multi_channel(%input: tensor<2x34x68x16xf32>, %filter: tensor<8x2x3x5x16x16xf32>) -> tensor<8x32x32x16xf32> { + %cst = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<8x32x32x16xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8x32x32x16xf32>) -> tensor<8x32x32x16xf32> + // expected-remark @below {{convolution}} + // expected-remark @below {{batch dims}} + // expected-remark @below {{output image dims 1 : i64, 2 : i64}} + // expected-remark @below {{output channel dims 0 : i64, 3 : i64}} + // expected-remark @below {{filter loop dims 5 : i64, 6 : i64}} + // expected-remark @below {{input channel dims 4 : i64, 7 : i64}} + // expected-remark @below {{depth dims}} + // expected-remark @below {{strides 1 : i64, 2 : i64}} + // expected-remark @below {{dilations 1 : i64, 1 : i64}} + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1 + d5, 2 * d2 + d6, d7)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d4, d5, d6, d7, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} + ins(%input, %filter : tensor<2x34x68x16xf32>, tensor<8x2x3x5x16x16xf32>) outs(%fill : tensor<8x32x32x16xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %mul = arith.mulf %in, %in_0 : f32 + %add = arith.addf %mul, %out : f32 + linalg.yield 
%add : f32 + } -> tensor<8x32x32x16xf32> + return %result : tensor<8x32x32x16xf32> + } +}