diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1667,4 +1667,49 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp
+//===----------------------------------------------------------------------===//
+
+def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
+    "structured.convert_conv2d_to_img2col",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait]> {
+  let description = [{
+    Applies im2col to every supported convolution nested in the given target:
+    linalg.conv_2d_nhwc_hwcf, linalg.conv_2d_nchw_fchw and
+    linalg.depthwise_conv_2d_nhwc_hwc. Convolutions rejected by the patterns'
+    preconditions are left unchanged. This is a wrapper around
+    populateConvertConv2DToImg2ColPatterns.
+
+    Return modes:
+    =============
+    Returns a definite failure if target is not isolated from above.
+    Returns a silenceable failure if the pattern application failed (i.e. the
+    greedy rewrite did not converge).
+    On success, the `transformed` handle points to the target op itself.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target), [{
+      // The result handle has the same type as the target handle.
+      build($_builder, $_state, target.getType(), target);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -878,6 +878,12 @@
 void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);
 
+/// Populates patterns that rewrite linalg.conv_2d_nhwc_hwcf,
+/// linalg.conv_2d_nchw_fchw and linalg.depthwise_conv_2d_nhwc_hwc into an
+/// img2col packing linalg.generic followed by a matrix multiplication
+/// (linalg.matmul, a batched linalg.generic or linalg.batch_matvec).
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3032,6 +3032,29 @@
   return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp.
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // The greedy pattern driver rewrites regions that are isolated from above.
+  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+    auto diag = this->emitOpError("requires isolated-from-above targets");
+    diag.attachNote(target->getLoc()) << "non-isolated target";
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+  RewritePatternSet patterns(getContext());
+  populateConvertConv2DToImg2ColPatterns(patterns);
+  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
+    return emitSilenceableFailure(target->getLoc())
+           << "failed to apply img2col";
+  }
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
Bufferize.cpp ConstantFold.cpp ConvertToDestinationStyle.cpp + ConvertConv2DToImg2Col.cpp DataLayoutPropagation.cpp DecomposeLinalgOps.cpp Detensorize.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -0,0 +1,539 @@ +//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace linalg { +static bool hasAllOneValues(DenseIntElementsAttr attr) { + return llvm::all_of( + attr, [](APInt element) { return element.getSExtValue() == 1; }); +} + +static Value createAdd(Location loc, Value x, Value y, bool isInt, + OpBuilder &builder) { + if (isInt) + return builder.create(loc, x, y); + return builder.create(loc, x, y); +} + +static Value createMul(Location loc, Value x, Value y, bool isInt, + OpBuilder &builder) { + if (isInt) + return builder.create(loc, x, y); + return builder.create(loc, x, y); +} + +namespace { + +// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) +// and linalg.matmul. 
+// +// A convolution operaton can be written as a matrix-matrix multiplication by +// unfolding the cross corrolation between input and filter and explicitly copy +// overlapped sliding window inputs. +// +// Consider 2D input X with single channel input and output and 2x2 filter W: +// [x(0, 0) , x(0, 1) , ..., x(0, n) ] +// [x(1, 0) , x(1, 1) , ..., x(1, n) ] +// [. , . ,. , . ] [w(0, 0), w(0, 1)] +// [. , . , . , . ] (conv) [w(1, 0), w(1, 1)] +// [. , . , ., . ] +// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)] +// +// The packed input data (img2col) is a matrix with |rows| = output spatial +// size, |columns| = filter spatial size. To compute the output Y(i, j) we need +// to calculate the dot product between filter window at input X(x, y)) and the +// filter which will look like the following where r.h.s is the img2col matrix +// and l.h.s is the flattned filter: +// +// clang-format off +// [x(0, 0), x(0, 1), x(1, 0), x(1, 1)] +// [x(0, 1), x(1, 1), x(0, 2), x(1, 2)] (matmul) [w(0, 0), w(0, 1), w(1, 0), w(1, 1)] +// [x(0, 1), x(1, 1), x(0, 2), x(1, 2)] +// [ . , . , . , . ] +// clang-format on +// +// In general for 2D case with (N, H, W, C) input and (Kh, Kw, C, D) filter +// and output (N, Ho, Wo, D) the convolutin is the following matrix-matrix +// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in +// the N input. For the case where N > 1 its a batched matrxi-matrix +// multplication. +class ConvertConv2DNhwcHwcf final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, + PatternRewriter &rewriter) const override { + auto inputType = convOp.getInputs()[0].getType().cast(); + auto filterType = convOp.getInputs()[1].getType().cast(); + auto outputType = convOp.getOutputs()[0].getType().cast(); + + if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) { + return failure(); + } + + // TODO: Support dilation. 
+ if (!hasAllOneValues(convOp.getDilations())) + return failure(); + + Value input = convOp.getInputs()[0]; + Value filter = convOp.getInputs()[1]; + Value output = convOp.getOutputs()[0]; + + auto filterShape = filterType.getShape(); + auto outputShape = outputType.getShape(); + + const int n = outputShape[0]; + const int oh = outputShape[1]; + const int ow = outputShape[2]; + const int oc = outputShape[3]; + const int fh = filterShape[0]; + const int fw = filterShape[1]; + const int ic = filterShape[2]; + + auto loc = convOp.getLoc(); + + SmallVector colTensorShape = {n, oh, ow, fh, fw, ic}; + + Value colTensor = rewriter.create( + loc, colTensorShape, inputType.getElementType()); + + AffineExpr nDim, ohDim, owDim, khDim, kwDim, icDim; + bindDims(getContext(), nDim, ohDim, owDim, khDim, kwDim, icDim); + + auto shSym = rewriter.getAffineConstantExpr( + convOp.getStrides().getValues()[0]); + auto swSym = rewriter.getAffineConstantExpr( + convOp.getStrides().getValues()[1]); + + SmallVector inputExprs = {nDim, ohDim * shSym + khDim, + owDim * swSym + kwDim, icDim}; + + auto nloops = colTensorShape.size(); + + auto parallel = utils::IteratorType::parallel; + auto reduction = utils::IteratorType::reduction; + SmallVector img2colIterators(nloops, parallel); + + SmallVector img2colIndexingMaps = { + AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())}; + + auto img2ColTensor = rewriter.create( + loc, colTensor.getType(), + /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps, + img2colIterators, + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(nestedLoc, args[0]); + }); + + SmallVector img2ColTensorReassocIndices; + SmallVector outputReassocIndices; + RankedTensorType reshapedImg2ColTensorType, reshapedOutputType; + if (n == 1) { + img2ColTensorReassocIndices = {{0, 1, 2}, {3, 4, 5}}; + outputReassocIndices = {{0, 1, 2}, {3}}; + + 
reshapedImg2ColTensorType = RankedTensorType::get( + {oh * ow, fh * fw * ic}, inputType.getElementType()); + reshapedOutputType = + RankedTensorType::get({oh * ow, oc}, outputType.getElementType()); + } else { + img2ColTensorReassocIndices = {{0}, {1, 2}, {3, 4, 5}}; + outputReassocIndices = {{0}, {1, 2}, {3}}; + + reshapedImg2ColTensorType = RankedTensorType::get( + {n, oh * ow, fh * fw * ic}, inputType.getElementType()); + reshapedOutputType = + RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType()); + } + + SmallVector filterReassocIndices = {{0, 1, 2}, {3}}; + auto reshapedFilterType = + RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType()); + + Value reshapedImg2ColTensor = rewriter.create( + loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0), + img2ColTensorReassocIndices); + + Value reshapedFilter = rewriter.create( + loc, reshapedFilterType, filter, filterReassocIndices); + + Value reshapedOutput = rewriter.create( + loc, reshapedOutputType, output, outputReassocIndices); + + Value result; + if (n == 1) { + auto matmulOp = rewriter.create( + loc, reshapedOutputType, + ArrayRef{reshapedImg2ColTensor, reshapedFilter}, + ArrayRef{reshapedOutput}); + result = matmulOp.getResults().front(); + } else { + // For cases where batch is not 1, we need to keep the batch dimension + // separate. Because the filter does not share the same batch dimension, + // the batch dimension is only used in indexing the input and output. Thus + // we cannot use existing linalg named ops like linalg.batch_matmul. + // i.e. 
(B x) M x K * K x N = (B x) M x N + AffineExpr bDim, mDim, nDim, kDim; + bindDims(getContext(), bDim, mDim, nDim, kDim); + auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, getContext()); + auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, getContext()); + auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, getContext()); + SmallVector genericIterators = {parallel, parallel, + parallel, reduction}; + bool isInt = outputType.getElementType().isa(); + auto genericOp = rewriter.create( + loc, reshapedOutputType, + /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilter}, + /*outputs=*/ValueRange{reshapedOutput}, + ArrayRef{lhsMap, rhsMap, resultMap}, genericIterators, + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + Value mul = createMul(loc, args[0], args[1], isInt, nestedBuilder); + Value add = createAdd(loc, mul, args[2], isInt, nestedBuilder); + nestedBuilder.create(nestedLoc, add); + }); + result = genericOp.getResults().front(); + } + + auto reshapedResult = rewriter.create( + loc, outputType, result, outputReassocIndices); + + rewriter.replaceOp(convOp, ArrayRef{reshapedResult}); + + return success(); + } +}; + +// Similar to the conv pattern above except there is no reduction among the +// input channles so each convolution can be a matrix-vector product and +// by transposing both input filter so channles are outer most the computation +// is a batched matrix-vector product. +class ConvertDepthwiseConv2DNhwcHwc final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp, + PatternRewriter &rewriter) const override { + auto inputType = convOp.getInputs()[0].getType().cast(); + auto filterType = convOp.getInputs()[1].getType().cast(); + auto outputType = convOp.getOutputs()[0].getType().cast(); + + if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) { + return failure(); + } + + // TODO: Support dilation. 
+ if (!hasAllOneValues(convOp.getDilations())) + return failure(); + + auto loc = convOp.getLoc(); + + auto transposeOperand = [&](Value operand, ArrayRef indices) { + auto operandTensorType = operand.getType().cast(); + auto nloops = indices.size(); + auto inputShape = operandTensorType.getShape(); + + SmallVector exprs = llvm::to_vector<4>( + llvm::map_range(indices, [&](int64_t index) -> AffineExpr { + return rewriter.getAffineDimExpr(index); + })); + + SmallVector targetShape = llvm::to_vector<4>( + llvm::map_range(indices, [&](int64_t index) -> int64_t { + return inputShape[index]; + })); + + Value outputTensor = rewriter.create( + loc, targetShape, operandTensorType.getElementType()); + + SmallVector loopAttributeTypes( + nloops, utils::IteratorType::parallel); + + SmallVector indexingMaps = { + inversePermutation( + AffineMap::get(nloops, 0, exprs, rewriter.getContext())), + AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())}; + + auto transposedOp = rewriter.create( + loc, outputTensor.getType(), + /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps, + loopAttributeTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(nestedLoc, args[0]); + }); + + return transposedOp.getResult(0); + }; + + Value input = convOp.getInputs()[0]; + Value filter = convOp.getInputs()[1]; + Value output = convOp.getOutputs()[0]; + + // Transpose input, filter so channels are outermost + auto inputT = transposeOperand(input, {0, 3, 1, 2}); + auto filterT = transposeOperand(filter, {2, 0, 1}); + auto filterTShape = filterT.getType().cast().getShape(); + auto outputShape = outputType.getShape(); + + const int n = outputShape[0]; + const int oh = outputShape[1]; + const int ow = outputShape[2]; + const int c = outputShape[3]; + const int fh = filterTShape[1]; + const int fw = filterTShape[2]; + + SmallVector colTensorShape = {n, c, oh, ow, fh, fw}; + Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 
2}); + + AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim; + bindDims(getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim); + + auto shSym = rewriter.getAffineConstantExpr( + convOp.getStrides().getValues()[0]); + auto swSym = rewriter.getAffineConstantExpr( + convOp.getStrides().getValues()[1]); + + SmallVector inputExprs = {nDim, cDim, ohDim * shSym + khDim, + owDim * swSym + kwDim}; + + auto nloops = colTensorShape.size(); + + SmallVector loopAttributeTypes( + nloops, utils::IteratorType::parallel); + + SmallVector indexingMaps = { + AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())}; + + Value colTensor = rewriter.create( + loc, colTensorShape, inputType.getElementType()); + + auto img2ColTensor = rewriter.create( + loc, colTensor.getType(), + /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps, + loopAttributeTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(nestedLoc, args[0]); + }); + + SmallVector img2ColTensorReassocIndices = { + {0, 1}, {2, 3}, {4, 5}}; + SmallVector filterReassociationIndice = {{0}, {1, 2}}; + SmallVector outputReassociationIndice = {{0, 1}, + {2, 3}}; + + auto reshapedImg2ColTensorType = RankedTensorType::get( + {n * c, oh * ow, fh * fw}, inputType.getElementType()); + auto reshapedFilterTensorType = + RankedTensorType::get({c, fh * fw}, filterType.getElementType()); + auto reshapedOutputTensorType = + RankedTensorType::get({n * c, oh * ow}, outputType.getElementType()); + + Value reshapedImg2ColTensor = rewriter.create( + loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0), + img2ColTensorReassocIndices); + Value reshapedFilterTensor = rewriter.create( + loc, reshapedFilterTensorType, filterT, filterReassociationIndice); + Value reshapedoutputTensor = rewriter.create( + loc, reshapedOutputTensorType, transposedOutputTensor, + outputReassociationIndice); + + auto batchMatVecResult = 
rewriter.create( + loc, TypeRange{reshapedoutputTensor.getType()}, + ValueRange{reshapedImg2ColTensor, reshapedFilterTensor}, + ValueRange{reshapedoutputTensor}); + + SmallVector batchMatVecReassociationIndice = {{0, 1}, + {2, 3}}; + + Value batchMatVecResultReshaped = rewriter.create( + loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0), + batchMatVecReassociationIndice); + + auto transposedResult = + transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1}); + + rewriter.replaceOp(convOp, ArrayRef{transposedResult}); + return success(); + } +}; + +// For nchw, because the channels are to the left of the image shape dimensions, +// the position of the contraction dimension in the resulting matmul is +// reversed. This swaps the LHS and RHS of the matmul when compared with nhwc +// (i.e. (D, C x Kh x Kw) * (C x Kh x Kw, Ho x Wo)) +class ConvertConv2DNchwFchw final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, + PatternRewriter &rewriter) const override { + auto inputType = convOp.getInputs()[0].getType().cast(); + auto filterType = convOp.getInputs()[1].getType().cast(); + auto outputType = convOp.getOutputs()[0].getType().cast(); + + if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) { + return failure(); + } + + // TODO: Support dilation. 
+ if (!hasAllOneValues(convOp.getDilations())) + return failure(); + + Value input = convOp.getInputs()[0]; + Value filter = convOp.getInputs()[1]; + Value output = convOp.getOutputs()[0]; + + auto filterShape = filterType.getShape(); + auto outputShape = outputType.getShape(); + + const int n = outputShape[0]; + const int oc = outputShape[1]; + const int oh = outputShape[2]; + const int ow = outputShape[3]; + const int ic = filterShape[1]; + const int fh = filterShape[2]; + const int fw = filterShape[3]; + + auto loc = convOp.getLoc(); + + SmallVector colTensorShape = {n, ic, fh, fw, oh, ow}; + + Value colTensor = rewriter.create( + loc, colTensorShape, inputType.getElementType()); + + AffineExpr nDim, icDim, khDim, kwDim, ohDim, owDim; + bindDims(getContext(), nDim, icDim, khDim, kwDim, ohDim, owDim); + + auto shSym = rewriter.getAffineConstantExpr( + convOp.getStrides().getValues()[0]); + auto swSym = rewriter.getAffineConstantExpr( + convOp.getStrides().getValues()[1]); + + SmallVector inputExprs = {nDim, icDim, ohDim * shSym + khDim, + owDim * swSym + kwDim}; + + auto nloops = colTensorShape.size(); + + auto parallel = utils::IteratorType::parallel; + auto reduction = utils::IteratorType::reduction; + SmallVector img2colIterators(nloops, parallel); + + SmallVector img2colIndexingMaps = { + AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()), + AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())}; + + auto img2ColTensor = rewriter.create( + loc, colTensor.getType(), + /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps, + img2colIterators, + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(nestedLoc, args[0]); + }); + + SmallVector filterReassocIndices = {{0}, {1, 2, 3}}; + auto reshapedFilterType = + RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType()); + Value reshapedFilter = rewriter.create( + loc, reshapedFilterType, filter, filterReassocIndices); + + SmallVector 
img2ColTensorReassocIndices; + SmallVector outputReassocIndices; + RankedTensorType reshapedImg2ColTensorType, reshapedOutputType; + if (n == 1) { + img2ColTensorReassocIndices = {{0, 1, 2, 3}, {4, 5}}; + outputReassocIndices = {{0, 1}, {2, 3}}; + + reshapedImg2ColTensorType = RankedTensorType::get( + {fh * fw * ic, oh * ow}, inputType.getElementType()); + reshapedOutputType = + RankedTensorType::get({oc, oh * ow}, outputType.getElementType()); + } else { + img2ColTensorReassocIndices = {{0}, {1, 2, 3}, {4, 5}}; + outputReassocIndices = {{0}, {1}, {2, 3}}; + + reshapedImg2ColTensorType = RankedTensorType::get( + {n, fh * fw * ic, oh * ow}, inputType.getElementType()); + reshapedOutputType = + RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType()); + } + + Value reshapedImg2ColTensor = rewriter.create( + loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0), + img2ColTensorReassocIndices); + + Value reshapedOutput = rewriter.create( + loc, reshapedOutputType, output, outputReassocIndices); + + Value result; + if (n == 1) { + auto matmulOp = rewriter.create( + loc, reshapedOutputType, + ArrayRef{reshapedFilter, reshapedImg2ColTensor}, + ArrayRef{reshapedOutput}); + result = matmulOp.getResults().front(); + } else { + // For cases where batch is not 1, we need to keep the batch dimension + // separate. Because the filter does not share the same batch dimension, + // the batch dimension is only used in indexing the input and output. Thus + // we cannot use existing linalg named ops like linalg.batch_matmul. + // i.e. 
M x K * (B x) K x N = (B x) M x N + AffineExpr bDim, mDim, nDim, kDim; + bindDims(getContext(), bDim, mDim, nDim, kDim); + auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, getContext()); + auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, getContext()); + auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, getContext()); + SmallVector genericIterators = {parallel, parallel, + parallel, reduction}; + bool isInt = outputType.getElementType().isa(); + auto genericOp = rewriter.create( + loc, reshapedOutputType, + /*inputs=*/ValueRange{reshapedFilter, reshapedImg2ColTensor}, + /*outputs=*/ValueRange{reshapedOutput}, + ArrayRef{lhsMap, rhsMap, resultMap}, genericIterators, + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + Value mul = createMul(loc, args[0], args[1], isInt, nestedBuilder); + Value add = createAdd(loc, mul, args[2], isInt, nestedBuilder); + nestedBuilder.create(nestedLoc, add); + }); + result = genericOp.getResults().front(); + } + + auto reshapedResult = rewriter.create( + loc, outputType, result, outputReassocIndices); + + rewriter.replaceOp(convOp, ArrayRef{reshapedResult}); + + return success(); + } +}; +} // end anonymous namespace + +void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.insert(context); +} +} // end namespace linalg +} // end namespace mlir