diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -23,6 +23,29 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
+
+/// Get the iterator indices that index into the output window dimensions for
+/// `op` that implements the `ConvolutionOpInterface`. Returns `{}` if the `op`
+/// does not implement the `ConvolutionOpInterface`.
+SmallVector<unsigned> getConvolutionOutputWindowIteratorIndices(Operation *op);
+
+/// Get the iterator indices that index into the filter window dimensions for
+/// `op` that implements the `ConvolutionOpInterface`. Returns `{}` if the `op`
+/// does not implement the `ConvolutionOpInterface`.
+SmallVector<unsigned> getConvolutionFilterWindowIteratorIndices(Operation *op);
+
+/// Get the iterator index that indexes the input reduction dimension for
+/// `op` that implements the `ConvolutionOpInterface`. Returns llvm::None if
+/// the `op` does not implement the `ConvolutionOpInterface`, or if no input
+/// reduction dimensions exist.
+Optional<unsigned> getConvolutionInputReductionIteratorIndex(Operation *op);
+
+/// Get the iterator index of the filter parallel dimension (a dimension that
+/// is unconvolved in the input and present in both the filter and the output)
+/// for `op`. Returns llvm::None if the `op` does not implement the
+/// `ConvolutionOpInterface`, or has no such dimension.
+Optional<unsigned> getConvolutionFilterParallelIteratorIndex(Operation *op);
+
 namespace linalg {
 
 class LinalgOp;
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -100,11 +100,11 @@
     ```
     4. The filter and the output have projected permutation maps.
     5. Each of the loops can be qualified as one of,
-       - Loop over batch dimension,
-       - Loop over output image dimensions,
-       - Loop over output channel dimensions,
-       - Loop over convolved filter dimensions,
-       - Loop over input channel dimension.
+       - Loop over output parallel dimensions (e.g., batch, output channel),
+       - Loop over output window dimensions (e.g., output height/width),
+       - Loop over input reduction dimension (e.g., input channel),
+       - Loop over filter parallel dimension (e.g., depthwise channel),
+       - Loop over convolved filter window dimensions (e.g., kernel h/w).
   }];
   let cppNamespace = "::mlir::linalg";
   let verify = [{ return detail::verifyConvolutionInterface($_op); }];
@@ -129,6 +129,82 @@
         return $_op.getOperation()->getOperand(1);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the output operand.",
+      /*retTy=*/"Value",
+      /*methodName=*/"output",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getOperation()->getOperand(2);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the `strides` attribute, or null if absent.",
+      /*retTy=*/"DenseIntElementsAttr",
+      /*methodName=*/"strides",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return $_op->template getAttrOfType<DenseIntElementsAttr>("strides");
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the `dilations` attribute, or null if absent.",
+      /*retTy=*/"DenseIntElementsAttr",
+      /*methodName=*/"dilations",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return $_op->template getAttrOfType<DenseIntElementsAttr>("dilations");
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indices of the iterators that index into the output window
+        dimensions of the convolution operation.
+      }],
+      /*retTy=*/"SmallVector<unsigned>",
+      /*methodName=*/"getOutputWindowIteratorIndices",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return mlir::getConvolutionOutputWindowIteratorIndices($_op);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indices of the iterators that index into the filter window
+        dimensions of the convolution operation.
+      }],
+      /*retTy=*/"SmallVector<unsigned>",
+      /*methodName=*/"getFilterWindowIteratorIndices",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return mlir::getConvolutionFilterWindowIteratorIndices($_op);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the index of the iterator that indexes the input reduction
+        dimension of the convolution operation, if any.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getInputReductionIteratorIndex",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return mlir::getConvolutionInputReductionIteratorIndex($_op);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the index of the iterator that indexes the filter parallel
+        dimension of the convolution operation, if any.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getFilterParallelIteratorIndex",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return mlir::getConvolutionFilterParallelIteratorIndex($_op);
+      }]
+    >,
   ];
 }
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -406,6 +406,101 @@
   }
   return success();
 }
+
+/// Returns true if `op` is a LinalgOp that implements the
+/// `ConvolutionOpInterface`, i.e. if the iterator queries below apply to it.
+static bool isConvolutionLinalgOp(Operation *op) {
+  return isa<linalg::LinalgOp>(op) && isa<linalg::ConvolutionOpInterface>(op);
+}
+
+/// Classifies the dims used by the input (operand 0) indexing map `inputMap`
+/// of a convolution op into convolved and unconvolved dims. A walk failure is
+/// treated as unreachable: ops implementing the `ConvolutionOpInterface` are
+/// expected to have been verified, so their input maps are walkable.
+static ConvAccessExprWalker walkConvInputMap(AffineMap inputMap) {
+  ConvAccessExprWalker inputExprWalker;
+  if (llvm::any_of(inputMap.getResults(), [&](AffineExpr expr) {
+        return failed(inputExprWalker.visit(expr));
+      }))
+    llvm_unreachable("illegal input indexing map expression");
+  return inputExprWalker;
+}
+
+SmallVector<unsigned>
+mlir::getConvolutionOutputWindowIteratorIndices(Operation *op) {
+  if (!isConvolutionLinalgOp(op))
+    return {};
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  auto indexingMaps = linalgOp.getIndexingMaps();
+  ConvAccessExprWalker inputExprWalker = walkConvInputMap(indexingMaps[0]);
+  // The output window loops are dims that are convolved in the input, and
+  // present in the output.
+  SmallVector<unsigned> outputWindowIteratorIndices;
+  for (AffineExpr expr : indexingMaps.back().getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    if (inputExprWalker.convolvedDims.count(dim))
+      outputWindowIteratorIndices.push_back(dim);
+  }
+  return outputWindowIteratorIndices;
+}
+
+SmallVector<unsigned>
+mlir::getConvolutionFilterWindowIteratorIndices(Operation *op) {
+  if (!isConvolutionLinalgOp(op))
+    return {};
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  auto indexingMaps = linalgOp.getIndexingMaps();
+  ConvAccessExprWalker inputExprWalker = walkConvInputMap(indexingMaps[0]);
+  // The filter window loops are dims that are convolved in the input, and
+  // present in the filter.
+  SmallVector<unsigned> filterWindowIteratorIndices;
+  for (AffineExpr expr : indexingMaps[1].getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    if (inputExprWalker.convolvedDims.count(dim))
+      filterWindowIteratorIndices.push_back(dim);
+  }
+  return filterWindowIteratorIndices;
+}
+
+Optional<unsigned>
+mlir::getConvolutionInputReductionIteratorIndex(Operation *op) {
+  if (!isConvolutionLinalgOp(op))
+    return llvm::None;
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  auto indexingMaps = linalgOp.getIndexingMaps();
+  ConvAccessExprWalker inputExprWalker = walkConvInputMap(indexingMaps[0]);
+  // The input reduction dims are unconvolved in the input, present in the
+  // filter, and not present in the output. Only the first one, in filter map
+  // order, is returned.
+  llvm::SmallDenseSet<unsigned> outputDims =
+      getPreservedDims(indexingMaps.back());
+  for (AffineExpr expr : indexingMaps[1].getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    if (inputExprWalker.unConvolvedDims.count(dim) && !outputDims.count(dim))
+      return dim;
+  }
+  return llvm::None;
+}
+
+Optional<unsigned>
+mlir::getConvolutionFilterParallelIteratorIndex(Operation *op) {
+  if (!isConvolutionLinalgOp(op))
+    return llvm::None;
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  auto indexingMaps = linalgOp.getIndexingMaps();
+  ConvAccessExprWalker inputExprWalker = walkConvInputMap(indexingMaps[0]);
+  // The filter parallel dims are unconvolved in the input, and present in both
+  // the filter and the output. Only the first one, in output map order, is
+  // returned.
+  llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
+  for (AffineExpr expr : indexingMaps.back().getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    if (inputExprWalker.unConvolvedDims.count(dim) && filterDims.count(dim))
+      return dim;
+  }
+  return llvm::None;
+}
+
 //===----------------------------------------------------------------------===//
 // StructuredOpInterface implementation
 //===----------------------------------------------------------------------===//