Patch: record strides and dilations in linalg::detail ConvolutionDimensions.

What the hunks below do:
  * LinalgInterfaces.h: adds two fields, `strides` and `dilations`, to the
    dimensions struct filled in by the convolution matcher.
  * LinalgInterfaces.cpp, access-expression walker: adds the map
    `strideAndDilationMapping` (dim position -> AffineExpr).
    - One walker path stores the constant 1 for the dim it records.
    - Another path stores `mulExpr` for the dim it records.
  * LinalgInterfaces.cpp, matcher:
    - For every dim classified as `outputImage`, pushes that dim's mapped
      expression into `strideExprs`.
    - For every dim classified as `filterLoop`, pushes that dim's mapped
      expression into `dilationExprs`.
    - Converts both lists to integers with the new static helper
      `getConstantsFromExprList`.
    - If the op carries a "strides" or "dilations" attribute, that attribute's
      values are used instead of the walker-derived ones.

Review notes on this copy of the patch:
  NOTE(review): every hunk has been collapsed onto one or two physical lines.
  The patch cannot be applied as-is; restore the original line breaks first.
  The hunk headers must still match afterwards. For example,
  "@@ -343,8 +360,12 @@" implies an added blank line that is not visible here.
  NOTE(review): template arguments appear to have been stripped. Examples:
  `SmallVector strides;`, `llvm::SmallDenseMap strideAndDilationMapping;`,
  `AffineExprVisitor {`, `e.dyn_cast()`, `outputExpr.cast()`,
  `op->getAttrOfType("strides")`, `getValues()`. As written, the added C++
  will not compile. Restore the arguments from the original commit.
  NOTE(review): `getConstantsFromExprList` asserts that each expression is an
  AffineConstantExpr, then calls `constantExpr.getValue()` unconditionally.
  The definition of `mulExpr` is outside these hunks. If the walker can record
  a non-constant (e.g. symbolic) multiplier, matching an op without native
  "strides"/"dilations" attributes will trip the assert. In release builds the
  same case calls `getValue()` on a null expression. Confirm that only
  constants are recorded, or handle the non-constant case explicitly.
  NOTE(review): `strideAndDilationMapping[outputDim]` and
  `strideAndDilationMapping[filterDim]` use operator[]. That call
  default-inserts a null AffineExpr when the dim was never recorded.
  Presumably the surrounding branch conditions guarantee the dim was recorded;
  verify, or use a lookup that fails loudly.
  NOTE(review): `getConstantsFromExprList` takes its vector by value. A
  `const &` parameter avoids a copy.
  NOTE(review): when native "strides"/"dilations" attributes are used, no
  check is visible that their length matches the number of outputImage /
  filterLoop dims.

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -55,6 +55,8 @@ SmallVector filterLoop; SmallVector inputChannel; SmallVector depth; + SmallVector strides; + SmallVector dilations; }; /// Checks whether `op` conforms to ConvolutionOpInterface and populates diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -203,6 +203,7 @@ : public AffineExprVisitor { llvm::SmallDenseSet convolvedDims; llvm::SmallDenseSet unConvolvedDims; + llvm::SmallDenseMap strideAndDilationMapping; LogicalResult visitDimExpr(AffineDimExpr dimExpr) { unsigned position = dimExpr.getPosition(); @@ -230,6 +231,9 @@ unsigned dim = dimExpr.getPosition(); if (convolvedDims.count(dim) || unConvolvedDims.count(dim)) return failure(); + // Stride/dilation for this dim is implicitly 1. + strideAndDilationMapping[dim] = + getAffineConstantExpr(1, expr.getContext()); convolvedDims.insert(dim); return success(); } @@ -251,6 +255,7 @@ unsigned dim = dimExpr.getPosition(); if (convolvedDims.count(dim) || unConvolvedDims.count(dim)) return failure(); + strideAndDilationMapping[dim] = mulExpr; convolvedDims.insert(dim); return success(); } @@ -268,6 +273,17 @@ return preservedDims; } +static SmallVector +getConstantsFromExprList(SmallVector exprs) { + SmallVector vals; + for (auto e : exprs) { + auto constantExpr = e.dyn_cast(); + assert(constantExpr && "Found non-constant stride/dilation"); + vals.push_back(constantExpr.getValue()); + } + return vals; +} + namespace mlir::linalg::detail { enum class MatchConvolutionResult { Success = 0, @@ -325,6 +341,7 @@ // - Depth multiplier : unconvolved in input, present in output, present in // filter.
llvm::SmallDenseSet allLoopDims; + llvm::SmallVector strideExprs; for (auto outputExpr : indexingMaps.back().getResults()) { unsigned outputDim = outputExpr.cast().getPosition(); if (inputExprWalker.unConvolvedDims.count(outputDim) && @@ -343,8 +360,12 @@ if (iteratorTypes[outputDim] != utils::IteratorType::parallel) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); - if (dimensions) + if (dimensions) { + strideExprs.push_back( + inputExprWalker.strideAndDilationMapping[outputDim]); dimensions->outputImage.push_back(outputDim); + } + continue; } if (!inputExprWalker.convolvedDims.count(outputDim) && @@ -370,6 +391,7 @@ } return MatchConvolutionResult::NonConvolutionLoop; } + llvm::SmallVector dilationExprs; for (auto filterExpr : indexingMaps[1].getResults()) { unsigned filterDim = filterExpr.cast().getPosition(); if (outputDims.count(filterDim) && @@ -389,8 +411,11 @@ if (allLoopDims.count(filterDim)) return MatchConvolutionResult::NonConvolutionLoop; allLoopDims.insert(filterDim); - if (dimensions) + if (dimensions) { + dilationExprs.push_back( + inputExprWalker.strideAndDilationMapping[filterDim]); dimensions->filterLoop.push_back(filterDim); + } continue; } if (inputExprWalker.unConvolvedDims.count(filterDim) && @@ -426,6 +451,17 @@ dimensions->inputChannel.size() + dimensions->depth.size() == linalgOp.getNumLoops() && "expected all loops to be classified"); + + // Use the op carried strides/dilations attribute if present. + auto nativeStrides = op->getAttrOfType("strides"); + dimensions->strides = + !nativeStrides ? getConstantsFromExprList(strideExprs) + : llvm::to_vector<2>(nativeStrides.getValues()); + auto nativeDilations = op->getAttrOfType("dilations"); + dimensions->dilations = + !nativeDilations + ? getConstantsFromExprList(dilationExprs) + : llvm::to_vector<2>(nativeDilations.getValues()); } return MatchConvolutionResult::Success;