diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -42,6 +42,31 @@ namespace detail {
 
+/// Result of matching a Linalg generic against the predicates of it being a
+/// convolution.
+enum class MatchConvolutionResult;
+
+/// Positions of a Linalg op's loops that correspond to the different kinds of
+/// convolution dimensions.
+struct ConvolutionDimensions {
+  SmallVector<unsigned, 2> batch;
+  SmallVector<unsigned, 2> outputImage;
+  SmallVector<unsigned, 2> outputChannel;
+  SmallVector<unsigned, 2> filterLoop;
+  SmallVector<unsigned, 2> inputChannel;
+  SmallVector<unsigned, 2> depth;
+};
+
+/// Checks whether `op` conforms to ConvolutionOpInterface and populates
+/// `dimensions` with the indices of the different kinds of dimensions, when
+/// present.
+MatchConvolutionResult
+isConvolutionInterfaceImpl(Operation *op,
+                           ConvolutionDimensions *dimensions = nullptr);
+
+/// Returns the error message corresponding to the convolution checking return
+/// code.
+StringRef getMatchConvolutionMessage(MatchConvolutionResult res);
+
 /// Verify that `op` conforms to ContractionOpInterface.
 LogicalResult verifyContractionInterface(Operation *op);
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -268,6 +268,7 @@
   return preservedDims;
 }
 
+namespace mlir::linalg::detail {
 enum class MatchConvolutionResult {
   Success = 0,
   NotLinalgOp,
@@ -278,8 +279,11 @@
   OutputDimsNotParallel,
   NonOutputDimNotReduction
 };
+} // namespace mlir::linalg::detail
 
-static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
+mlir::linalg::detail::MatchConvolutionResult
+mlir::linalg::detail::isConvolutionInterfaceImpl(
+    Operation *op, ConvolutionDimensions *dimensions) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
   if (!linalgOp)
     return MatchConvolutionResult::NotLinalgOp;
@@ -307,7 +311,7 @@
   llvm::SmallDenseSet<int64_t> outputDims =
       getPreservedDims(indexingMaps.back());
   llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
-  // Make sure all loops are charecterized as one of:
+  // Make sure all loops are characterized as one of:
   // - Batch loop : present in output, as non-convolved in input, not present in
   //   filter.
   // - Output image dimension : present in output, convolved dims in input, not
@@ -329,6 +333,8 @@
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->batch.push_back(outputDim);
       continue;
     }
     if (inputExprWalker.convolvedDims.count(outputDim) &&
@@ -337,6 +343,8 @@
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->outputImage.push_back(outputDim);
       continue;
     }
     if (!inputExprWalker.convolvedDims.count(outputDim) &&
@@ -346,6 +354,8 @@
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->outputChannel.push_back(outputDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -354,6 +364,8 @@
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->depth.push_back(outputDim);
       continue;
     }
     return MatchConvolutionResult::NonConvolutionLoop;
@@ -363,7 +375,10 @@
     if (outputDims.count(filterDim) &&
         !inputExprWalker.unConvolvedDims.count(filterDim) &&
         !inputExprWalker.convolvedDims.count(filterDim)) {
-      // Output channel dimension. THis is already seen, continue;
+      // Output channel dimension. This is already seen, continue.
+      assert((!dimensions ||
+              llvm::is_contained(dimensions->outputChannel, filterDim)) &&
+             "expected output channel to have been found from output dims");
       continue;
     }
     if (inputExprWalker.convolvedDims.count(filterDim) &&
@@ -374,6 +389,8 @@
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
       allLoopDims.insert(filterDim);
+      if (dimensions)
+        dimensions->filterLoop.push_back(filterDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
@@ -384,11 +401,16 @@
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
       allLoopDims.insert(filterDim);
+      if (dimensions)
+        dimensions->inputChannel.push_back(filterDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
         outputDims.count(filterDim)) {
       // Depthwise loop. Already seen.
+      assert(
+          (!dimensions || llvm::is_contained(dimensions->depth, filterDim)) &&
+          "expected depthwise dimension to have been found from output dims");
       continue;
     }
     return MatchConvolutionResult::NonConvolutionLoop;
@@ -397,32 +419,45 @@
   if (allLoopDims.size() != linalgOp.getNumLoops())
     return MatchConvolutionResult::NonConvolutionLoop;
 
+  if (dimensions) {
+    assert(dimensions->batch.size() + dimensions->outputImage.size() +
+                   dimensions->outputChannel.size() +
+                   dimensions->filterLoop.size() +
+                   dimensions->inputChannel.size() + dimensions->depth.size() ==
+               linalgOp.getNumLoops() &&
+           "expected all loops to be classified");
+  }
+
   return MatchConvolutionResult::Success;
 }
 
-LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
-  auto res = isConvolutionInterfaceImpl(op);
-  if (res == MatchConvolutionResult::NotLinalgOp)
-    return op->emitError("expected a LinalgOp");
-  if (res == MatchConvolutionResult::WrongNumOperands)
-    return op->emitError("expected op with 2 inputs and 1 output");
-  if (res == MatchConvolutionResult::WrongInputIndexingMap)
-    return op->emitError("unexpected input index map for convolutions");
-  if (res == MatchConvolutionResult::NotProjectedPermutations) {
-    return op->emitError(
-        "expected output/filter indexing maps to be projected permutations");
-  }
-  if (res == MatchConvolutionResult::NonConvolutionLoop) {
-    return op->emitError("unexpected loop dimension for convolution op");
-  }
-  if (res == MatchConvolutionResult::OutputDimsNotParallel) {
-    return op->emitError(
-        "expected all iterators used to access outputs to be parallel");
-  }
-  if (res == MatchConvolutionResult::NonOutputDimNotReduction) {
-    return op->emitError(
-        "expected all iterators not used to access outputs to be reduction");
+StringRef
+mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
+  switch (res) {
+  case MatchConvolutionResult::NotLinalgOp:
+    return "expected a LinalgOp";
+  case MatchConvolutionResult::WrongNumOperands:
+    return "expected op with 2 inputs and 1 output";
+  case MatchConvolutionResult::WrongInputIndexingMap:
+    return "unexpected input index map for convolutions";
+  case MatchConvolutionResult::NotProjectedPermutations:
+    return "expected output/filter indexing maps to be projected permutations";
+  case MatchConvolutionResult::NonConvolutionLoop:
+    return "unexpected loop dimension for convolution op";
+  case MatchConvolutionResult::OutputDimsNotParallel:
+    return "expected all iterators used to access outputs to be parallel";
+  case MatchConvolutionResult::NonOutputDimNotReduction:
+    return "expected all iterators not used to access outputs to be reduction";
+  case MatchConvolutionResult::Success:
+    return "";
   }
+  llvm_unreachable("unhandled MatchConvolutionResult case");
+}
+
+LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
+  MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
+  if (res != MatchConvolutionResult::Success)
+    return op->emitError(getMatchConvolutionMessage(res));
   return success();
 }
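
Example (not part of the patch): a minimal sketch of how an out-of-tree caller
might use the newly exposed hooks. The wrapper name `classifyConvolutionLoops`
is hypothetical; everything else comes from the declarations above. Note that
the header only forward-declares MatchConvolutionResult, so such a caller
cannot name MatchConvolutionResult::Success and instead relies on
getMatchConvolutionMessage() returning the empty string for the success code.

  #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

  using namespace mlir;

  // Hypothetical wrapper: classify the loops of `op` as convolution
  // dimensions, or emit the matcher's failure message on `op`.
  static FailureOr<linalg::detail::ConvolutionDimensions>
  classifyConvolutionLoops(Operation *op) {
    linalg::detail::ConvolutionDimensions dims;
    linalg::detail::MatchConvolutionResult res =
        linalg::detail::isConvolutionInterfaceImpl(op, &dims);
    // Empty message <=> success; the enumerators are not visible through the
    // forward declaration in the header.
    StringRef msg = linalg::detail::getMatchConvolutionMessage(res);
    if (!msg.empty()) {
      op->emitError(msg);
      return failure();
    }
    // On success, every loop position of the op appears in exactly one of the
    // six vectors (batch, outputImage, outputChannel, filterLoop,
    // inputChannel, depth), as asserted at the end of the matcher.
    return dims;
  }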