diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -23,6 +23,40 @@ #include "mlir/Interfaces/ViewLikeInterface.h" namespace mlir { + +/// Get the loop dimension that iterates over the batch dimension for `op` that +/// implements the `ConvolutionOpInterface`. Returns llvm::None if the `op` does +/// not implement the `ConvolutionOpInterface`, or if no batch dimensions exist. +Optional getConvolutionBatchLoopDim(Operation *op); + +/// Get the loop dimensions that iterate over the output image for `op` that +/// implements the `ConvolutionOpInterface`. Returns `{}` if the `op` does not +/// implement the `ConvolutionOpInterface`. +SmallVector getConvolutionOutputImageLoopDims(Operation *op); + +/// Get the loop dimension that iterates over the output channel dimensions for +/// `op` that implements the `ConvolutionOpInterface`. Returns llvm::None if +/// the `op` does not implement the `ConvolutionOpInterface`, or if no output +/// channel dimensions exist. +Optional getConvolutionOutputChannelLoopDim(Operation *op); + +/// Get the loop dimensions that iterate over the filter loops for `op` that +/// implements the `ConvolutionOpInterface`. Returns `{}` if the `op` does not +/// implement the `ConvolutionOpInterface`. +SmallVector getConvolutionFilterLoopDims(Operation *op); + +/// Get the loop dimension that iterates over the input channel dimensions for +/// `op` that implements the `ConvolutionOpInterface`. Returns llvm::None if +/// the `op` does not implement the `ConvolutionOpInterface`, or if no input +/// channel dimensions exist. +Optional getConvolutionInputChannelLoopDim(Operation *op); + +/// Get the loop dimension that iterates over the depthwise dimension for `op` +/// that implements the `ConvolutionOpInterface`. 
Returns llvm::None if the +/// `op` does not implement the `ConvolutionOpInterface`, or is not a depthwise +/// convolution. +Optional getConvolutionDepthwiseLoopDim(Operation *op); + namespace linalg { class LinalgOp; @@ -44,6 +78,9 @@ /// Verify that `op` conforms to ContractionOpInterface. LogicalResult verifyContractionInterface(Operation *op); +/// Verify that `op` conforms to the ConvolutionOpInterface. +LogicalResult verifyConvolutionInterface(Operation *op); + /// Verify that `op` conforms to the invariants of StructuredOpInterface LogicalResult verifyStructuredOpInterface(Operation *op); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -87,6 +87,101 @@ ]; } +def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> { + let description = [{ + A convolution is defined in general terms: + 1. Has an `image` and a `filter` operand. + 2. Has one `output` operand. + 3. The indexing maps of the input have expressions that satisfy + AffineExpr ::== AffineDimExpr | ConvolvedExpr + ConvolvedExpr ::== MulExpr (`+` MulExpr)+ + MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))? + 4. The filter and the output have projected permutation maps. + 5. Each of the loops can be qualified as one of, + - Loop over batch dimension, + - Loop over output image dimensions, + - Loop over output channel dimensions, + - Loop over convolved filter dimensions, + - Loop over input channel dimension. 
+ }]; + let cppNamespace = "::mlir::linalg"; + let verify = [{ return detail::verifyConvolutionInterface($_op); }]; + let methods = [ + InterfaceMethod< + /*desc=*/"Return the image operand.", + /*retTy=*/"Value", + /*methodName=*/"image", + /*args=*/(ins), + /*methodBody=*/[{ + return $_op.getOperation()->getOperand(0); + }] + >, + InterfaceMethod< + /*desc=*/"Return the filter operand.", + /*retTy=*/"Value", + /*methodName=*/"filter", + /*args=*/(ins), + /*methodBody=*/[{ + return $_op.getOperation()->getOperand(1); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over batch dimensions of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getBatchLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionBatchLoopDim($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loops over output image dimensions of the convolution operation.", + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputImageLoopDims", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionOutputImageLoopDims($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over output channel dimensions of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getOutputChannelLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionOutputChannelLoopDim($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over filter dimensions of the convolution operation.", + /*retTy=*/"SmallVector", + /*methodName=*/"getFilterLoopDims", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionFilterLoopDims($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over input channel dimensions of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getInputChannelLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionInputChannelLoopDim($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the depthwise loop of the convolution operation.", + 
/*retTy=*/"Optional", + /*methodName=*/"getDepthwiseLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionDepthwiseLoopDim($_op); + }] + >, + ]; +} + // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface. def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let cppNamespace = "::mlir::linalg"; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -636,6 +636,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -695,6 +697,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -756,6 +760,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -820,6 +826,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -898,6 +906,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
+ implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -985,6 +995,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. This includes the zero point offsets common to quantized operations. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1103,6 +1115,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1185,6 +1199,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1272,6 +1288,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. Multiplier is set to 1 which is a special case for most dpethwise convolutions. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1350,6 +1368,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1460,6 +1480,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
+ implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1542,6 +1564,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1656,6 +1680,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1724,6 +1750,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1792,6 +1820,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1860,6 +1890,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1928,6 +1960,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -2002,6 +2036,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. 
+ implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -2076,6 +2112,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -25,8 +25,8 @@ // first operands. These may be optionally followed by non-view operands // depending on the specific Linalg op. class LinalgStructuredBase_Op props> - : Op { + : Op { code structuredOpsBaseDecls = [{ // Return whether the op accesses the iteration indices. bool hasIndexSemantics() { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -89,7 +89,7 @@ return success; } -enum MatchContractionResult { +enum class MatchContractionResult { Success = 0, NotLinalgOp, WrongNumOperands, @@ -152,6 +152,388 @@ return success(); } +//===----------------------------------------------------------------------===// +// ConvolutionOpInterface implementation +//===----------------------------------------------------------------------===// + +/// Of the given two expressions returns one that is of type T (`lhs` gets +/// preference over `rhs`) +template +static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) { + return lhs.isa() ? lhs.cast() + : (rhs.isa() ? 
rhs.cast() : nullptr); +} + +namespace { +/// Walk the indexing expressions for input of a convolution operation to verify +/// it is of the right form, either +/// - AffineDimExpr +/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))? +/// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)* +/// +/// classifies the AffineDimExpr as convolved dimensions or unconvolved +/// dimensions and verifies each dimension occurs only once. +struct ConvAccessExprWalker + : public AffineExprVisitor { + llvm::SmallDenseSet convolvedDims; + llvm::SmallDenseSet unConvolvedDims; + + LogicalResult visitDimExpr(AffineDimExpr dimExpr) { + unsigned position = dimExpr.getPosition(); + if (unConvolvedDims.count(position) || convolvedDims.count(position)) { + return failure(); + } + unConvolvedDims.insert(position); + return success(); + } + + LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); } + + LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); } + + LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) { + // In pre-order visit, top level op has to be an add op. + if (binaryExpr.getKind() != AffineExprKind::Add) + return failure(); + + if (succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) && + succeeded(isDimExprOrMulExpr(binaryExpr.getRHS()))) { + return success(); + } + return failure(); + } + + LogicalResult isDimExprOrMulExpr(AffineExpr expr) { + if (auto dimExpr = expr.dyn_cast()) { + unsigned dim = dimExpr.getPosition(); + if (convolvedDims.count(dim) || unConvolvedDims.count(dim)) + return failure(); + convolvedDims.insert(dim); + return success(); + } + if (auto symbolMulExpr = expr.dyn_cast()) { + if (symbolMulExpr.getKind() != AffineExprKind::Mul) + return failure(); + auto lhsExpr = symbolMulExpr.getLHS(); + auto rhsExpr = symbolMulExpr.getRHS(); + // Check for symbol expression. 
+ AffineExpr mulExpr = + getAffineExprOfType(lhsExpr, rhsExpr); + // If there was no symbol expr, check for constant expression. + if (!mulExpr) { + mulExpr = getAffineExprOfType(lhsExpr, rhsExpr); + } + auto dimExpr = getAffineExprOfType(lhsExpr, rhsExpr); + if (!mulExpr || !dimExpr) + return failure(); + unsigned dim = dimExpr.getPosition(); + if (convolvedDims.count(dim) || unConvolvedDims.count(dim)) + return failure(); + convolvedDims.insert(dim); + return success(); + } + return failure(); + } +}; +} // namespace + +static llvm::SmallDenseSet getPreservedDims(AffineMap map) { + assert(map.isProjectedPermutation() && + "expected map to have projected permutations"); + llvm::SmallDenseSet preservedDims; + for (auto expr : map.getResults()) + preservedDims.insert(expr.cast().getPosition()); + return preservedDims; +} + +enum class MatchConvolutionResult { + Success = 0, + NotLinalgOp, + WrongNumOperands, + WrongInputIndexingMap, + NotProjectedPermutations, + NonConvolutionLoop, + OutputDimsNotParallel, + NonOutputDimNotReduction +}; + +static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { + auto linalgOp = dyn_cast(op); + if (!linalgOp) + return MatchConvolutionResult::NotLinalgOp; + if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1) + return MatchConvolutionResult::WrongNumOperands; + + auto indexingMaps = linalgOp.getIndexingMaps(); + + // Check the input indexing map has the right form. + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + return MatchConvolutionResult::WrongInputIndexingMap; + } + + // Filter and output maps must be projected permutation. 
+ if (!indexingMaps[1].isProjectedPermutation() || + !indexingMaps.back().isProjectedPermutation()) + return MatchConvolutionResult::NotProjectedPermutations; + + auto iteratorTypesRange = + linalgOp.iterator_types().getAsValueRange(); + + llvm::SmallDenseSet outputDims = + getPreservedDims(indexingMaps.back()); + llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + // Make sure all loops are characterized. + llvm::SmallDenseSet allLoopDims; + // Batch loop : present in output, as non-convolved in input, not present in + // filter. Output image dimension : present in output, convolved dims in + // input, not present in filter. Output channel dimension : present in output, + // not present in input, present in filter. Filter loop dimension : present in + // filter, convolved in input, not present in output. Input channel dimension + // : unconvolved in input, not present in output, present in filter. Depthwise + // dimension : unconvolved in input, present in output, present in filter. + for (auto outputExpr : indexingMaps.back().getResults()) { + unsigned outputDim = outputExpr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(outputDim) && + !filterDims.count(outputDim)) { + // Batch dimension. + if (*std::next(iteratorTypesRange.begin(), outputDim) != + getParallelIteratorTypeName()) + return MatchConvolutionResult::OutputDimsNotParallel; + allLoopDims.insert(outputDim); + continue; + } + if (inputExprWalker.convolvedDims.count(outputDim) && + !filterDims.count(outputDim)) { + // Output image Loop dimension. + if (*std::next(iteratorTypesRange.begin(), outputDim) != + getParallelIteratorTypeName()) + return MatchConvolutionResult::OutputDimsNotParallel; + allLoopDims.insert(outputDim); + continue; + } + if (!inputExprWalker.convolvedDims.count(outputDim) && + !inputExprWalker.unConvolvedDims.count(outputDim) && + filterDims.count(outputDim)) { + // Output channel dimension. 
+ if (*std::next(iteratorTypesRange.begin(), outputDim) != + getParallelIteratorTypeName()) + return MatchConvolutionResult::OutputDimsNotParallel; + allLoopDims.insert(outputDim); + continue; + } + if (inputExprWalker.unConvolvedDims.count(outputDim) && + filterDims.count(outputDim)) { + // Depthwise loop. + if (*std::next(iteratorTypesRange.begin(), outputDim) != + getParallelIteratorTypeName()) + return MatchConvolutionResult::OutputDimsNotParallel; + allLoopDims.insert(outputDim); + continue; + } + return MatchConvolutionResult::NonConvolutionLoop; + } + for (auto filterExpr : indexingMaps[1].getResults()) { + unsigned filterDim = filterExpr.cast().getPosition(); + if (outputDims.count(filterDim) && + !inputExprWalker.unConvolvedDims.count(filterDim) && + !inputExprWalker.convolvedDims.count(filterDim)) { + // Output channel dimension. This is already seen; continue. + continue; + } + if (inputExprWalker.convolvedDims.count(filterDim) && + !outputDims.count(filterDim)) { + // Filter loop dimension. + if (*std::next(iteratorTypesRange.begin(), filterDim) != + getReductionIteratorTypeName()) + return MatchConvolutionResult::NonOutputDimNotReduction; + if (allLoopDims.count(filterDim)) + return MatchConvolutionResult::NonConvolutionLoop; + allLoopDims.insert(filterDim); + continue; + } + if (inputExprWalker.unConvolvedDims.count(filterDim) && + !outputDims.count(filterDim)) { + // Input channel dimension. + if (*std::next(iteratorTypesRange.begin(), filterDim) != + getReductionIteratorTypeName()) + return MatchConvolutionResult::NonOutputDimNotReduction; + if (allLoopDims.count(filterDim)) + return MatchConvolutionResult::NonConvolutionLoop; + allLoopDims.insert(filterDim); + continue; + } + if (inputExprWalker.unConvolvedDims.count(filterDim) && + outputDims.count(filterDim)) { + // Depthwise loop. Already seen. + continue; + } + return MatchConvolutionResult::NonConvolutionLoop; + } + // All loops must be covered now. 
+ if (allLoopDims.size() != linalgOp.getNumLoops()) + return MatchConvolutionResult::NonConvolutionLoop; + + return MatchConvolutionResult::Success; +} + +LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { + auto res = isConvolutionInterfaceImpl(op); + if (res == MatchConvolutionResult::NotLinalgOp) + return op->emitError("expected a LinalgOp"); + if (res == MatchConvolutionResult::WrongNumOperands) + return op->emitError("expected op with 2 inputs and 1 output"); + if (res == MatchConvolutionResult::WrongInputIndexingMap) + return op->emitError("unexpected input index map for convolutions"); + if (res == MatchConvolutionResult::NotProjectedPermutations) { + return op->emitError( + "expected output/filter indexing maps to be projected permutations"); + } + if (res == MatchConvolutionResult::NonConvolutionLoop) { + return op->emitError("unexpected dimension for convolution op"); + } + if (res == MatchConvolutionResult::OutputDimsNotParallel) { + return op->emitError( + "expected all iterators used to access outputs to be parallel"); + } + if (res == MatchConvolutionResult::NonOutputDimNotReduction) { + return op->emitError( + "expected all iterators not used to access outputs to be reduction"); + } + return success(); +} + +Optional mlir::getConvolutionBatchLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The batch dimensions are part of unconvolved dimensions in the input, and + // not present in the filter. 
+ llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(dim) && !filterDims.count(dim)) + return dim; + } + return llvm::None; +} + +SmallVector mlir::getConvolutionOutputImageLoopDims(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The output image loops are dims that are convolved in the input, and + // present in the output. + SmallVector outputImageLoopDims; + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.convolvedDims.count(dim)) + outputImageLoopDims.push_back(dim); + } + return outputImageLoopDims; +} + +SmallVector mlir::getConvolutionFilterLoopDims(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The filter loops are dims that are convolved in the input, and present in +// the filter. 
+ SmallVector filterLoopDims; + for (AffineExpr expr : indexingMaps[1].getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.convolvedDims.count(dim)) + filterLoopDims.push_back(dim); + } + return filterLoopDims; +} + +Optional mlir::getConvolutionOutputChannelLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The output channel dimensions are not part of convolved or unconvolved + // dimensions in the input, and present in the filter. + llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (!inputExprWalker.convolvedDims.count(dim) && + !inputExprWalker.unConvolvedDims.count(dim) && filterDims.count(dim)) + return dim; + } + return llvm::None; +} + +Optional mlir::getConvolutionInputChannelLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The input channel dimensions are part of unconvolved dimensions in the + // input, and not present in the output. 
+ llvm::SmallDenseSet outputDims = + getPreservedDims(indexingMaps.back()); + for (AffineExpr expr : indexingMaps[1].getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(dim) && !outputDims.count(dim)) + return dim; + } + return llvm::None; +} + +Optional mlir::getConvolutionDepthwiseLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The depthwise dimensions are part of unconvolved dimensions in the + // input, and present in both the filter and the output. + llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(dim) && filterDims.count(dim)) + return dim; + } + return llvm::None; +} + //===----------------------------------------------------------------------===// // StructuredOpInterface implementation //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -484,7 +484,7 @@ ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") - +ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface") class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- 
a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -154,6 +154,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.ow, D.kw) O[D.ow] += cast( U, I[D.ow + D.kw]) * cast(U, K[D.kw]) @@ -168,6 +169,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.oh, D.ow, D.kh, D.kw) O[D.oh, D.ow] += cast( U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw]) @@ -182,6 +184,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) O[D.od, D.oh, D.ow] += cast( U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kd, D.kh, D.kw]) @@ -198,6 +201,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.ow, D.f, D.kw, D.c) O[D.n, D.ow, D.f] += cast( U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c @@ -219,6 +223,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.f] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c @@ -243,6 +248,7 @@ them to the same data type as the accumulator/output. This includes the zero point offsets common to quantized operations. 
""" + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.f] += (cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c @@ -264,6 +270,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.f, D.oh, D.ow] += cast( U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW @@ -282,6 +289,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.f] += cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c @@ -300,6 +308,7 @@ them to the same data type as the accumulator/output. Multiplier is set to 1 which is a special case for most dpethwise convolutions. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -319,6 +328,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic] += ( (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -337,6 +347,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -356,6 +367,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += ( (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -375,6 +387,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -392,6 +405,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -409,6 +423,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)( cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -426,6 +441,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. 
""" + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -445,6 +461,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] += cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, @@ -464,6 +481,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)( cast( @@ -484,6 +502,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)( cast( diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -160,9 +160,9 @@ func @generic_empty_region(%arg0: memref) { %f0 = constant 0.0: f32 - // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}} + // expected-error @+1 {{op expected 1 region with 1 block}} linalg.generic { - indexing_maps = [ affine_map<() -> (0)> ], + indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ], iterator_types = []} ins(%arg0 : memref) outs(%arg0 : memref) { @@ -275,8 +275,8 @@ // expected-error @+2 {{op expects regions to end with 'linalg.yield', found 'std.addf'}} // expected-note @+1 {{in custom textual format, the absence of terminator implies 'linalg.yield'}} linalg.generic { - indexing_maps = [ affine_map<(i) -> (i)> ], - 
iterator_types = ["parallel"]} + indexing_maps = [ affine_map<(i, j) -> (i, j)> ], + iterator_types = ["parallel", "parallel"]} outs(%arg0 : memref) { ^bb(%0: i4) : %1 = std.addf %0, %0: i4