diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -23,6 +23,40 @@ #include "mlir/Interfaces/ViewLikeInterface.h" namespace mlir { + +/// Get the loop dimension that iterates over the batch dimension for `op` that +/// implements the `ConvolutionOpInterface`. Returns llvm::None if the `op` does +/// not implement the `ConvolutionOpInterface`, or if no batch dimensions exist. +Optional getConvolutionBatchLoopDim(Operation *op); + +/// Get the loop dimensions that iterate over the output image for `op` that +/// implements the `ConvolutionOpInterface`. Returns `{}` if the `op` does not +/// implement the `ConvolutionOpInterface`. +SmallVector getConvolutionOutputImageLoopDims(Operation *op); + +/// Get the loop dimension that iterates over the output channel dimensions for +/// `op` that implements the `ConvolutionOpInterface`. Returns llvm::None if +/// the `op` does not implement the `ConvolutionOpInterface`, or if no output +/// channel dimensions exist. +Optional getConvolutionOutputChannelLoopDim(Operation *op); + +/// Get the loop dimensions that iterate over the filter loops for `op` that +/// implements the `ConvolutionOpInterface`. Returns `{}` if the `op` does not +/// implement the `ConvolutionOpInterface`. +SmallVector getConvolutionFilterLoopDims(Operation *op); + +/// Get the loop dimension that iterates over the input channel dimensions for +/// `op` that implements the `ConvolutionOpInterface`. Returns llvm::None if +/// the `op` does not implement the `ConvolutionOpInterface`, or if no input +/// channel dimensions exist. +Optional getConvolutionInputChannelLoopDim(Operation *op); + +/// Get the loop dimension that iterates over the depthwise dimension for `op` +/// that implements the `ConvolutionOpInterface`. Returns llvm::None if the +/// `op` does not implement the `ConvolutionOpInterface`, or is not a depthwise +/// convolution. +Optional getConvolutionDepthwiseLoopDim(Operation *op); + namespace linalg { class LinalgOp; @@ -44,6 +78,9 @@ /// Verify that `op` conforms to ContractionOpInterface. LogicalResult verifyContractionInterface(Operation *op); +/// Verify that `op` conforms to the ConvolutionOpInterface. +LogicalResult verifyConvolutionInterface(Operation *op); + /// Verify that `op` conforms to the invariants of StructuredOpInterface LogicalResult verifyStructuredOpInterface(Operation *op); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -87,6 +87,88 @@ ]; } +def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> { + let description = [{ + TODO + }]; + let cppNamespace = "::mlir::linalg"; + let verify = [{ return detail::verifyConvolutionInterface($_op); }]; + let methods = [ + InterfaceMethod< + /*desc=*/"Return the image operand.", + /*retTy=*/"Value", + /*methodName=*/"image", + /*args=*/(ins), + /*methodBody=*/[{ + return $_op.getOperation()->getOperand(0); + }] + >, + InterfaceMethod< + /*desc=*/"Return the filter operand.", + /*retTy=*/"Value", + /*methodName=*/"filter", + /*args=*/(ins), + /*methodBody=*/[{ + return $_op.getOperation()->getOperand(1); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over batch dimensions of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getBatchLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionBatchLoopDim($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loops over output image dimensions of the convolution operation.", + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputImageLoopDims", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionOutputImageLoopDims($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over output channel dimensions of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getOutputChannelLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionOutputChannelLoopDim($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over filter dimensions of the convolution operation.", + /*retTy=*/"SmallVector", + /*methodName=*/"getFilterLoopDims", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionFilterLoopDims($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the loop over input channel dimensions of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getInputChannelLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionInputChannelLoopDim($_op); + }] + >, + InterfaceMethod< + /*desc=*/"Return the depthwise loop of the convolution operation.", + /*retTy=*/"Optional", + /*methodName=*/"getDepthwiseLoopDim", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::getConvolutionDepthwiseLoopDim($_op); + }] + >, + ]; +} + // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface. def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let cppNamespace = "::mlir::linalg"; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -636,6 +636,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -695,6 +697,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -756,6 +760,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -822,6 +828,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -901,6 +909,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -988,6 +998,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. This includes the zero point offsets common to quantized operations. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1106,6 +1118,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1188,6 +1202,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1274,6 +1290,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. Multiplier is set to 1 which is a special case for most dpethwise convolutions. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1355,6 +1373,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1468,6 +1488,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1550,6 +1572,8 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1664,6 +1688,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1735,6 +1761,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1806,6 +1834,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1877,6 +1907,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1948,6 +1980,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -2022,6 +2056,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -2096,6 +2132,8 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -89,7 +89,7 @@ return success; } -enum MatchContractionResult { +enum class MatchContractionResult { Success = 0, NotLinalgOp, WrongNumOperands, @@ -152,6 +152,388 @@ return success(); } +//===----------------------------------------------------------------------===// +// ConvolutionOpInterface implementation +//===----------------------------------------------------------------------===// + +/// Of the given two expressions returns one that is of type T (`lhs` gets +/// preference over `rhs`) +template +static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) { + return lhs.isa() ? lhs.cast() + : (rhs.isa() ? rhs.cast() : nullptr); +} + +namespace { +/// Walk the indexing expressions for input of a convolution operation to verify +/// its of the right form, either +/// - AffineDimExpr +/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))? +/// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)* +/// +/// classifies the AffineDimExpr as convolved dimensions or unconvolved +/// dimensions and verifies each dimension occurs only once. +struct ConvAccessExprWalker + : public AffineExprVisitor { + llvm::SmallDenseSet convolvedDims; + llvm::SmallDenseSet unConvolvedDims; + + LogicalResult visitDimExpr(AffineDimExpr dimExpr) { + unsigned position = dimExpr.getPosition(); + if (unConvolvedDims.count(position) || convolvedDims.count(position)) { + return failure(); + } + unConvolvedDims.insert(position); + return success(); + } + + LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); } + + LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); } + + LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) { + // In pre-order visit, top level op has to be an add op. + if (binaryExpr.getKind() != AffineExprKind::Add) + return failure(); + + if (succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) && + succeeded(isDimExprOrMulExpr(binaryExpr.getRHS()))) { + return success(); + } + return failure(); + } + + LogicalResult isDimExprOrMulExpr(AffineExpr expr) { + if (auto dimExpr = expr.dyn_cast()) { + unsigned dim = dimExpr.getPosition(); + if (convolvedDims.count(dim) || unConvolvedDims.count(dim)) + return failure(); + convolvedDims.insert(dim); + return success(); + } + if (auto symbolMulExpr = expr.dyn_cast()) { + if (symbolMulExpr.getKind() != AffineExprKind::Mul) + return failure(); + auto lhsExpr = symbolMulExpr.getLHS(); + auto rhsExpr = symbolMulExpr.getRHS(); + // Check for symbol expression. + AffineExpr mulExpr = + getAffineExprOfType(lhsExpr, rhsExpr); + // If there was no symbol expr, check for constant expression. + if (!mulExpr) { + mulExpr = getAffineExprOfType(lhsExpr, rhsExpr); + } + auto dimExpr = getAffineExprOfType(lhsExpr, rhsExpr); + if (!mulExpr || !dimExpr) + return failure(); + unsigned dim = dimExpr.getPosition(); + if (convolvedDims.count(dim) || unConvolvedDims.count(dim)) + return failure(); + convolvedDims.insert(dim); + return success(); + } + return failure(); + } +}; +} // namespace + +static llvm::SmallDenseSet getPreservedDims(AffineMap map) { + assert(map.isProjectedPermutation() && + "expected map to have projected permutations"); + llvm::SmallDenseSet preservedDims; + for (auto expr : map.getResults()) + preservedDims.insert(expr.cast().getPosition()); + return preservedDims; +} + +enum class MatchConvolutionResult { + Success = 0, + NotLinalgOp, + WrongNumOperands, + WrongInputIndexingMap, + NotProjectedPermutations, + NonConvolutionLoop, + OutputDimsNotParallel, + NonOutputDimNotReduction +}; + +static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { + auto linalgOp = dyn_cast(op); + if (!linalgOp) + return MatchConvolutionResult::NotLinalgOp; + if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1) + return MatchConvolutionResult::WrongNumOperands; + + auto indexingMaps = linalgOp.getIndexingMaps(); + + // Check the input indexing map has the right form. + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + return MatchConvolutionResult::WrongInputIndexingMap; + } + + // Filter and output maps must be projected permutation. + if (!indexingMaps[1].isProjectedPermutation() || + !indexingMaps.back().isProjectedPermutation()) + return MatchConvolutionResult::NotProjectedPermutations; + + auto iteratorTypesRange = + linalgOp.iterator_types().getAsValueRange(); + + llvm::SmallDenseSet outputDims = + getPreservedDims(indexingMaps.back()); + llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + // Make sure all loops are charecterized. + llvm::SmallDenseSet allLoopDims; + // Batch loop : present in output, as non-convolved in input, not present in + // filter. Output image dimension : present in output, convolved dims in + // input, not present in filter. Output channel dimension : present in output, + // not present in input, present in filter. Filter loop dimension : present in + // filter, convolved in input, not present in output. Input channel dimension + // : unconvolved in input, not present in output, present in filter. depthwise + // dimension : unconvolved in input, present in output, present in filter. + for (auto outputExpr : indexingMaps.back().getResults()) { + unsigned outputDim = outputExpr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(outputDim) && + !filterDims.count(outputDim)) { + // Batch dimension. + if (*std::next(iteratorTypesRange.begin(), outputDim) != + getParallelIteratorTypeName()) + return MatchConvolutionResult::OutputDimsNotParallel; + allLoopDims.insert(outputDim); + continue; + } + if (inputExprWalker.convolvedDims.count(outputDim) && + !filterDims.count(outputDim)) { + // Output image Loop dimension. + if (*std::next(iteratorTypesRange.begin(), outputDim) != + getParallelIteratorTypeName()) + return MatchConvolutionResult::OutputDimsNotParallel; + allLoopDims.insert(outputDim); + continue; + } + if (!inputExprWalker.convolvedDims.count(outputDim) && + !inputExprWalker.unConvolvedDims.count(outputDim) && + filterDims.count(outputDim)) { + // Output channel dimension. + if (*std::next(iteratorTypesRange.begin(), outputDim) != + getParallelIteratorTypeName()) + return MatchConvolutionResult::OutputDimsNotParallel; + allLoopDims.insert(outputDim); + continue; + } + if (inputExprWalker.unConvolvedDims.count(outputDim) && + filterDims.count(outputDim)) { + // Depthwise loop. + if (*std::next(iteratorTypesRange.begin(), outputDim) != + getParallelIteratorTypeName()) + return MatchConvolutionResult::OutputDimsNotParallel; + allLoopDims.insert(outputDim); + continue; + } + return MatchConvolutionResult::NonConvolutionLoop; + } + for (auto filterExpr : indexingMaps[1].getResults()) { + unsigned filterDim = filterExpr.cast().getPosition(); + if (outputDims.count(filterDim) && + !inputExprWalker.unConvolvedDims.count(filterDim) && + !inputExprWalker.convolvedDims.count(filterDim)) { + // Output channel dimension. THis is already seen, continue; + continue; + } + if (inputExprWalker.convolvedDims.count(filterDim) && + !outputDims.count(filterDim)) { + // Filter loop dimension. + if (*std::next(iteratorTypesRange.begin(), filterDim) != + getReductionIteratorTypeName()) + return MatchConvolutionResult::NonOutputDimNotReduction; + if (allLoopDims.count(filterDim)) + return MatchConvolutionResult::NonConvolutionLoop; + allLoopDims.insert(filterDim); + continue; + } + if (inputExprWalker.unConvolvedDims.count(filterDim) && + !outputDims.count(filterDim)) { + // Input channel dimension. + if (*std::next(iteratorTypesRange.begin(), filterDim) != + getReductionIteratorTypeName()) + return MatchConvolutionResult::NonOutputDimNotReduction; + if (allLoopDims.count(filterDim)) + return MatchConvolutionResult::NonConvolutionLoop; + allLoopDims.insert(filterDim); + continue; + } + if (inputExprWalker.unConvolvedDims.count(filterDim) && + outputDims.count(filterDim)) { + // Depthwise loop. Already seen. + continue; + } + return MatchConvolutionResult::NonConvolutionLoop; + } + // All loops must be covered now. + if (allLoopDims.size() != linalgOp.getNumLoops()) + return MatchConvolutionResult::NonConvolutionLoop; + + return MatchConvolutionResult::Success; +} + +LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { + auto res = isConvolutionInterfaceImpl(op); + if (res == MatchConvolutionResult::NotLinalgOp) + return op->emitError("expected a LinalgOp"); + if (res == MatchConvolutionResult::WrongNumOperands) + return op->emitError("expected op with 2 inputs and 1 output"); + if (res == MatchConvolutionResult::WrongInputIndexingMap) + return op->emitError("unexpected input index map for convolutions"); + if (res == MatchConvolutionResult::NotProjectedPermutations) { + return op->emitError( + "expected output/filter indexing maps to be projected permutations"); + } + if (res == MatchConvolutionResult::NonConvolutionLoop) { + return op->emitError("unexpected dimension for convolution op"); + } + if (res == MatchConvolutionResult::OutputDimsNotParallel) { + return op->emitError( + "expected all iterators used to access outputs to be parallel"); + } + if (res == MatchConvolutionResult::NonOutputDimNotReduction) { + return op->emitError( + "expected all iterators not used to access outputs to be reduction"); + } + return success(); +} + +Optional mlir::getConvolutionBatchLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The batch dimensions are part of unconvoled dimensions in the input, and + // not present in the filter. + llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(dim) && !filterDims.count(dim)) + return dim; + } + return llvm::None; +} + +SmallVector mlir::getConvolutionOutputImageLoopDims(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The output image loops are dims that are convolved in the input, and + // present in the output. + SmallVector outputImageLoopDims; + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.convolvedDims.count(dim)) + outputImageLoopDims.push_back(dim); + } + return outputImageLoopDims; +} + +SmallVector mlir::getConvolutionFilterLoopDims(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The filter loops are dims that are convolved in the input, and present in + // the output. + SmallVector filterLoopDims; + for (AffineExpr expr : indexingMaps[1].getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.convolvedDims.count(dim)) + filterLoopDims.push_back(dim); + } + return filterLoopDims; +} + +Optional mlir::getConvolutionOutputChannelLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The output channel dimensions are not part of convolved or unconvolved + // dimensions in the input, and present in the filter. + llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (!inputExprWalker.convolvedDims.count(dim) && + !inputExprWalker.unConvolvedDims.count(dim) && filterDims.count(dim)) + return dim; + } + return llvm::None; +} + +Optional mlir::getConvolutionInputChannelLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The input channel dimensions are part of unconvolved dimensions in the + // input, and not present in the output. + llvm::SmallDenseSet outputDims = + getPreservedDims(indexingMaps.back()); + for (AffineExpr expr : indexingMaps[1].getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(dim) && !outputDims.count(dim)) + return dim; + } + return llvm::None; +} + +Optional mlir::getConvolutionDepthwiseLoopDim(Operation *op) { + auto linalgOp = cast(op); + auto indexingMaps = linalgOp.getIndexingMaps(); + ConvAccessExprWalker inputExprWalker; + if (llvm::any_of(indexingMaps[0].getResults(), + [&inputExprWalker](AffineExpr expr) { + return failed(inputExprWalker.visit(expr)); + })) { + llvm_unreachable("illegal input indexing map expression"); + } + // The input channel dimensions are part of unconvolved dimensions in the + // input, and not present in the output. + llvm::SmallDenseSet filterDims = getPreservedDims(indexingMaps[1]); + for (AffineExpr expr : indexingMaps.back().getResults()) { + unsigned dim = expr.cast().getPosition(); + if (inputExprWalker.unConvolvedDims.count(dim) && filterDims.count(dim)) + return dim; + } + return llvm::None; +} + //===----------------------------------------------------------------------===// // StructuredOpInterface implementation //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -484,7 +484,7 @@ ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") - +ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface") class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -154,6 +154,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.ow, D.kw) O[D.ow] += cast( U, I[D.ow + D.kw]) * cast(U, K[D.kw]) @@ -168,6 +169,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.oh, D.ow, D.kh, D.kw) O[D.oh, D.ow] += cast( U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw]) @@ -182,6 +184,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) O[D.od, D.oh, D.ow] += cast( U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kd, D.kh, D.kw]) @@ -198,6 +201,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.ow, D.f, D.kw, D.c) O[D.n, D.ow, D.f] += cast( U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c @@ -219,6 +223,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.f] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c @@ -243,6 +248,7 @@ them to the same data type as the accumulator/output. This includes the zero point offsets common to quantized operations. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.f] += (cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c @@ -264,6 +270,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.f, D.oh, D.ow] += cast( U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW @@ -281,6 +288,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.f] += cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c @@ -299,6 +307,7 @@ them to the same data type as the accumulator/output. Multiplier is set to 1 which is a special case for most dpethwise convolutions. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -318,6 +327,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic] += ( (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -336,6 +346,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -355,6 +366,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += ( (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -374,6 +386,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -391,6 +404,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -408,6 +422,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)( cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -425,6 +440,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -443,6 +459,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] += cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, @@ -461,6 +478,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)( cast( @@ -480,6 +498,7 @@ Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)( cast( diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -82,45 +82,48 @@ return } -// ----- - -func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{missing indexing map required attribute 'strides'}} - linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>} - ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) - outs(%output: memref<1x56x56x96xf32>) - return -} - -// ----- - -func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{missing indexing map required attribute 'dilations'}} - linalg.depthwise_conv2D_nhw {strides = dense<1> : vector<2xi64>} - ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) - outs(%output: memref<1x56x56x96xf32>) - return -} - -// ----- - -func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}} - linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>} - ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) - outs(%output: memref<1x56x56x96xf32>) - return -} - -// ----- - -func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}} - linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> } - ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) - outs(%output: memref<1x56x56x96xf32>) - return -} +// // ----- + +// DO-NOT-SUBMIT: These tests fail since the op is not verified before +// the interface, which causes a segfault in the interface verifier. + +// func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { +// // eXpected-error @+1 {{missing indexing map required attribute 'strides'}} +// linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>} +// ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) +// outs(%output: memref<1x56x56x96xf32>) +// return +// } + +// // ----- + +// func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { +// // eXpected-error @+1 {{missing indexing map required attribute 'dilations'}} +// linalg.depthwise_conv2D_nhw {strides = dense<1> : vector<2xi64>} +// ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) +// outs(%output: memref<1x56x56x96xf32>) +// return +// } + +// // ----- + +// func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { +// // eXpected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}} +// linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>} +// ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) +// outs(%output: memref<1x56x56x96xf32>) +// return +// } + +// // ----- + +// func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { +// // eXpected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}} +// linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> } +// ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) +// outs(%output: memref<1x56x56x96xf32>) +// return +// } // -----