diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -23,6 +23,40 @@ #include "mlir/Interfaces/ViewLikeInterface.h" namespace mlir { + +/// Get the loop dimension that iterates over the batch dimension for `op`. +/// `op` must implement the `ConvolutionOpInterface`; this is not verified. +/// Returns llvm::None if none exists, else the first one in the output map. +Optional getConvolutionBatchLoopDim(Operation *op); + +/// Get the loop dimensions that iterate over the output image for `op`. +/// `op` must implement the `ConvolutionOpInterface`; this is not verified. +/// Dimensions are returned in the order they appear in the output map. +SmallVector getConvolutionOutputImageLoopDims(Operation *op); + +/// Get the loop dimension that iterates over the output channel dimension for +/// `op`. `op` must implement the `ConvolutionOpInterface`; this is not +/// verified. Returns llvm::None if no output channel dimension exists. +/// If several dimensions qualify, the first one in the output map is returned. +Optional getConvolutionOutputChannelLoopDim(Operation *op); + +/// Get the filter loop dimensions of `op`: dims convolved in the input and +/// present in the filter. `op` must implement the `ConvolutionOpInterface`; +/// this is not verified. Dims are returned in filter-map order. +SmallVector getConvolutionFilterLoopDims(Operation *op); + +/// Get the loop dimension that iterates over the input channel dimension for +/// `op`. `op` must implement the `ConvolutionOpInterface`; this is not +/// verified. Returns llvm::None if no input channel dimension exists. +/// If several dimensions qualify, the first one in the filter map is returned. +Optional getConvolutionInputChannelLoopDim(Operation *op); + +/// Get the loop dimension that iterates over the depthwise dimension for `op`. +/// `op` must implement the `ConvolutionOpInterface`; this is not verified. 
+/// Returns llvm::None if `op` is not a depthwise convolution, i.e. no dim is +/// unconvolved in the input and present in both the filter and the output. +Optional getConvolutionDepthwiseLoopDim(Operation *op); + namespace linalg { class LinalgOp; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -129,6 +129,60 @@ return $_op.getOperation()->getOperand(1); }] >, + InterfaceMethod< + /*desc=*/"Return the loop over batch dimensions of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getBatchLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionBatchLoopDim($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loops over output image dimensions of the convolution operation.", + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputImageLoopDims", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionOutputImageLoopDims($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over output channel dimensions of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getOutputChannelLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionOutputChannelLoopDim($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over filter dimensions of the convolution operation.", + /*retTy=*/"SmallVector", + /*methodName=*/"getFilterLoopDims", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionFilterLoopDims($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over input channel dimensions of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getInputChannelLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionInputChannelLoopDim($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the depthwise loop of the convolution operation.", 
+ /*retTy=*/"Optional", + /*methodName=*/"getDepthwiseLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionDepthwiseLoopDim($_op); + }] + >, ]; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -406,6 +406,135 @@ } return success(); } + +Optional mlir::getConvolutionBatchLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The batch dimensions are part of unconvolved dimensions in the input, and + // not present in the filter. + llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(dim) && !filterDims.count(dim)) + return dim; + } + return llvm::None; +} + +SmallVector mlir::getConvolutionOutputImageLoopDims(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The output image loops are dims that are convolved in the input, and + // present in the output. 
+ SmallVector outputImageLoopDims; + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.convolvedDims.count(dim)) + outputImageLoopDims.push_back(dim); + } + return outputImageLoopDims; +} + +SmallVector mlir::getConvolutionFilterLoopDims(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The filter loops are dims that are convolved in the input, and present in + // the filter. + SmallVector filterLoopDims; + for (AffineExpr expr : indexingMaps[1].getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.convolvedDims.count(dim)) + filterLoopDims.push_back(dim); + } + return filterLoopDims; +} + +Optional mlir::getConvolutionOutputChannelLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The output channel dimensions are not part of convolved or unconvolved + // dimensions in the input, and present in both the filter and the output. 
+ llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (!inputExprWalker.convolvedDims.count(dim) && + !inputExprWalker.unConvolvedDims.count(dim) && filterDims.count(dim)) + return dim; + } + return llvm::None; +} + +Optional mlir::getConvolutionInputChannelLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The input channel dimensions are part of unconvolved dimensions in the + // input, present in the filter, and not present in the output. + llvm::SmallDenseSet outputDims = + getPreservedDims(indexingMaps.back()); + for (AffineExpr expr : indexingMaps[1].getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(dim) && !outputDims.count(dim)) + return dim; + } + return llvm::None; +} + +Optional mlir::getConvolutionDepthwiseLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The depthwise dimensions are part of unconvolved dimensions in the + // input, and present in both the filter and the output. 
+ llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(dim) && filterDims.count(dim)) + return dim; + } + return llvm::None; +} + //===----------------------------------------------------------------------===// // StructuredOpInterface implementation //===----------------------------------------------------------------------===//