diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1667,4 +1667,41 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp
+//===----------------------------------------------------------------------===//
+
+def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
+    "structured.convert_conv2d_to_img2col",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Applies im2col rewriting to every supported operation in the given target.
+    This is a wrapper around populateConvertConv2DToImg2ColPatterns.
+
+    Return modes:
+    =============
+    Returns a definite failure if the target is not isolated from above.
+    Returns a silenceable failure if the pattern application failed.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -878,6 +878,10 @@
 void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);
 
+/// Populates patterns to transform linalg.conv_2d_xxx operations into
+/// linalg.generic (for img2col packing) and linalg.matmul.
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3032,6 +3032,28 @@
   return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp.
+//===----------------------------------------------------------------------===//
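+// Usage sketch, mirroring the accompanying test in
+// mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir:
+//
+//   transform.sequence failures(propagate) {
+//   ^bb1(%arg1: !pdl.operation):
+//     %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]}
+//       in %arg1 : (!pdl.operation) -> !pdl.operation
+//     %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+//     %2 = transform.structured.convert_conv2d_to_img2col %1
+//   }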
+
+DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+    auto diag = this->emitOpError("requires isolated-from-above targets");
+    diag.attachNote(target->getLoc()) << "non-isolated target";
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+  RewritePatternSet patterns(getContext());
+  populateConvertConv2DToImg2ColPatterns(patterns);
+  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
+    return emitSilenceableFailure(target->getLoc())
+           << "failed to apply img2col";
+  }
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   Bufferize.cpp
   ConstantFold.cpp
+  ConvertConv2DToImg2Col.cpp
   ConvertToDestinationStyle.cpp
   DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -0,0 +1,539 @@
+//===- ConvertConv2DToImg2Col.cpp - im2col implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace linalg {
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+  return llvm::all_of(
+      attr, [](APInt element) { return element.getSExtValue() == 1; });
+}
+
+static Value createAdd(Location loc, Value x, Value y, bool isInt,
+                       OpBuilder &builder) {
+  if (isInt)
+    return builder.create<arith::AddIOp>(loc, x, y);
+  return builder.create<arith::AddFOp>(loc, x, y);
+}
+
+static Value createMul(Location loc, Value x, Value y, bool isInt,
+                       OpBuilder &builder) {
+  if (isInt)
+    return builder.create<arith::MulIOp>(loc, x, y);
+  return builder.create<arith::MulFOp>(loc, x, y);
+}
+
+namespace {
+
+// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
+// and linalg.matmul.
+//
+// A convolution operation can be written as a matrix-matrix multiplication by
+// unfolding the cross-correlation between input and filter and explicitly
+// copying the overlapping sliding window inputs.
+//
+// Consider 2D input X with single channel input and output and 2x2 filter W:
+// [x(0, 0)  , x(0, 1)  , ...,   x(0, n)  ]
+// [x(1, 0)  , x(1, 1)  , ...,   x(1, n)  ]
+// [.        ,  .       ,.   ,      .     ]            [w(0, 0), w(0, 1)]
+// [.        ,  .       , .  ,      .     ]    (conv)  [w(1, 0), w(1, 1)]
+// [.        ,  .       ,  . ,      .     ]
+// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
+//
+// The packed input data (img2col) is a matrix with |rows| = output spatial
+// size and |columns| = filter spatial size. To compute the output Y(i, j) we
+// need to calculate the dot product between the filter window at input
+// X(x, y) and the filter, which will look like the following, where r.h.s is
+// the img2col matrix and l.h.s is the flattened filter:
+//
+// clang-format off
+// [x(0, 0), x(0, 1), x(1, 0), x(1, 1)]
+// [x(0, 1), x(1, 1), x(0, 2), x(1, 2)] (matmul) [w(0, 0), w(0, 1), w(1, 0), w(1, 1)]
+// [x(0, 1), x(1, 1), x(0, 2), x(1, 2)]
+// [   .   ,    .   ,    .   ,    .   ]
+// clang-format on
+//
+// In general, for a 2D case with (N, H, W, C) input and (Kh, Kw, C, D) filter
+// and (N, Ho, Wo, D) output, the convolution is the matrix-matrix
+// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each of the N
+// inputs. For the case where N > 1 it is a batched matrix-matrix
+// multiplication.
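+//
+// As a concrete sketch, with the static shapes used in the accompanying test
+// (the SSA names are illustrative only):
+//
+//   %0 = linalg.conv_2d_nhwc_hwcf ... ins(%input, %filter ...) outs(%out ...)
+//     : tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32> -> tensor<1x14x14x16xf32>
+//
+// is rewritten, roughly, to:
+//
+//   %empty = tensor.empty() : tensor<1x14x14x3x3x4xf32>
+//   %col = linalg.generic ... ins(%input) outs(%empty)  // img2col packing
+//   %a = tensor.collapse_shape %col [[0, 1, 2], [3, 4, 5]]
+//     : tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
+//   %b = tensor.collapse_shape %filter [[0, 1, 2], [3]]
+//     : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+//   %m = linalg.matmul ins(%a, %b ...) outs(...) -> tensor<196x16xf32>
+//   %r = tensor.expand_shape %m [[0, 1, 2], [3]]
+//     : tensor<196x16xf32> into tensor<1x14x14x16xf32>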
+class ConvertConv2DNhwcHwcf final
+    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+    auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+    auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+
+    if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // TODO: Support dilation.
+    if (!hasAllOneValues(convOp.getDilations()))
+      return failure();
+
+    Value input = convOp.getInputs()[0];
+    Value filter = convOp.getInputs()[1];
+    Value output = convOp.getOutputs()[0];
+
+    auto filterShape = filterType.getShape();
+    auto outputShape = outputType.getShape();
+
+    const int n = outputShape[0];
+    const int oh = outputShape[1];
+    const int ow = outputShape[2];
+    const int oc = outputShape[3];
+    const int fh = filterShape[0];
+    const int fw = filterShape[1];
+    const int ic = filterShape[2];
+
+    auto loc = convOp.getLoc();
+
+    SmallVector<int64_t> colTensorShape = {n, oh, ow, fh, fw, ic};
+
+    Value colTensor = rewriter.create<tensor::EmptyOp>(
+        loc, colTensorShape, inputType.getElementType());
+
+    AffineExpr nDim, ohDim, owDim, khDim, kwDim, icDim;
+    bindDims(getContext(), nDim, ohDim, owDim, khDim, kwDim, icDim);
+
+    auto shSym = rewriter.getAffineConstantExpr(
+        convOp.getStrides().getValues<int64_t>()[0]);
+    auto swSym = rewriter.getAffineConstantExpr(
+        convOp.getStrides().getValues<int64_t>()[1]);
+
+    SmallVector<AffineExpr> inputExprs = {nDim, ohDim * shSym + khDim,
+                                          owDim * swSym + kwDim, icDim};
+
+    auto nloops = colTensorShape.size();
+
+    auto parallel = utils::IteratorType::parallel;
+    auto reduction = utils::IteratorType::reduction;
+    SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+    SmallVector<AffineMap> img2colIndexingMaps = {
+        AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
+        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+
+    auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+        loc, colTensor.getType(),
+        /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+        img2colIterators,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        });
+
+    SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+    SmallVector<ReassociationIndices> outputReassocIndices;
+    RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+    if (n == 1) {
+      img2ColTensorReassocIndices = {{0, 1, 2}, {3, 4, 5}};
+      outputReassocIndices = {{0, 1, 2}, {3}};
+
+      reshapedImg2ColTensorType = RankedTensorType::get(
+          {oh * ow, fh * fw * ic}, inputType.getElementType());
+      reshapedOutputType =
+          RankedTensorType::get({oh * ow, oc}, outputType.getElementType());
+    } else {
+      img2ColTensorReassocIndices = {{0}, {1, 2}, {3, 4, 5}};
+      outputReassocIndices = {{0}, {1, 2}, {3}};
+
+      reshapedImg2ColTensorType = RankedTensorType::get(
+          {n, oh * ow, fh * fw * ic}, inputType.getElementType());
+      reshapedOutputType =
+          RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+    }
+
+    SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
+    auto reshapedFilterType =
+        RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
+
+    Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+        img2ColTensorReassocIndices);
+
+    Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedFilterType, filter, filterReassocIndices);
+
+    Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedOutputType, output, outputReassocIndices);
+
+    Value result;
+    if (n == 1) {
+      auto matmulOp = rewriter.create<linalg::MatmulOp>(
+          loc, reshapedOutputType,
+          ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
+          ArrayRef<Value>{reshapedOutput});
+      result = matmulOp.getResults().front();
+    } else {
+      // For cases where batch is not 1, we need to keep the batch dimension
+      // separate. Because the filter does not share the same batch dimension,
+      // the batch dimension is only used in indexing the input and output.
+      // Thus we cannot use existing linalg named ops like linalg.batch_matmul;
+      // i.e. (B x) M x K * K x N = (B x) M x N.
+      AffineExpr bDim, mDim, nDim, kDim;
+      bindDims(getContext(), bDim, mDim, nDim, kDim);
+      auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, getContext());
+      auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, getContext());
+      auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, getContext());
+      SmallVector<utils::IteratorType> genericIterators = {
+          parallel, parallel, parallel, reduction};
+      bool isInt = outputType.getElementType().isa<IntegerType>();
+      auto genericOp = rewriter.create<linalg::GenericOp>(
+          loc, reshapedOutputType,
+          /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilter},
+          /*outputs=*/ValueRange{reshapedOutput},
+          ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+          [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+            Value mul = createMul(loc, args[0], args[1], isInt, nestedBuilder);
+            Value add = createAdd(loc, mul, args[2], isInt, nestedBuilder);
+            nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+          });
+      result = genericOp.getResults().front();
+    }
+
+    auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+        loc, outputType, result, outputReassocIndices);
+
+    rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+    return success();
+  }
+};
+
+// Similar to the conv pattern above, except that there is no reduction among
+// the input channels, so each convolution can be a matrix-vector product and,
+// by transposing both the input and the filter so that the channels are
+// outermost, the computation is a batched matrix-vector product.
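+//
+// As a sketch, with the static shapes used in the accompanying test: a
+// linalg.depthwise_conv_2d_nhwc_hwc on a tensor<1x114x114x16xf32> input with
+// a tensor<3x3x16xf32> filter is packed (after transposing so the channels
+// are outermost) into a tensor<16x12544x9xf32> img2col tensor and computed as
+//
+//   linalg.batch_matvec ins(tensor<16x12544x9xf32>, tensor<16x9xf32>)
+//                       outs(tensor<16x12544xf32>)
+//
+// followed by a tensor.expand_shape and a transpose back to NHWC.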
+class ConvertDepthwiseConv2DNhwcHwc final
+    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    auto inputType = convOp.getInputs()[0].getType().cast<RankedTensorType>();
+    auto filterType = convOp.getInputs()[1].getType().cast<RankedTensorType>();
+    auto outputType =
+        convOp.getOutputs()[0].getType().cast<RankedTensorType>();
+
+    if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // TODO: Support dilation.
+    if (!hasAllOneValues(convOp.getDilations()))
+      return failure();
+
+    auto loc = convOp.getLoc();
+
+    auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
+      auto operandTensorType = operand.getType().cast<RankedTensorType>();
+      auto nloops = indices.size();
+      auto inputShape = operandTensorType.getShape();
+
+      SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
+          llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
+            return rewriter.getAffineDimExpr(index);
+          }));
+
+      SmallVector<int64_t> targetShape = llvm::to_vector<4>(
+          llvm::map_range(indices, [&](int64_t index) -> int64_t {
+            return inputShape[index];
+          }));
+
+      Value outputTensor = rewriter.create<tensor::EmptyOp>(
+          loc, targetShape, operandTensorType.getElementType());
+
+      SmallVector<utils::IteratorType> loopAttributeTypes(
+          nloops, utils::IteratorType::parallel);
+
+      SmallVector<AffineMap> indexingMaps = {
+          inversePermutation(
+              AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
+          AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+
+      auto transposedOp = rewriter.create<linalg::GenericOp>(
+          loc, outputTensor.getType(),
+          /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
+          loopAttributeTypes,
+          [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+            nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+          });
+
+      return transposedOp.getResult(0);
+    };
+
+    Value input = convOp.getInputs()[0];
+    Value filter = convOp.getInputs()[1];
+    Value output = convOp.getOutputs()[0];
+
+    // Transpose the input and the filter so channels are outermost.
+    auto inputT = transposeOperand(input, {0, 3, 1, 2});
+    auto filterT = transposeOperand(filter, {2, 0, 1});
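+    // E.g. with the static shapes used in the accompanying test, the input
+    // tensor<1x114x114x16xf32> becomes tensor<1x16x114x114xf32> and the
+    // filter tensor<3x3x16xf32> becomes tensor<16x3x3xf32>.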
+    auto filterTShape = filterT.getType().cast<RankedTensorType>().getShape();
+    auto outputShape = outputType.getShape();
+
+    const int n = outputShape[0];
+    const int oh = outputShape[1];
+    const int ow = outputShape[2];
+    const int c = outputShape[3];
+    const int fh = filterTShape[1];
+    const int fw = filterTShape[2];
+
+    SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
+    Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
+
+    AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
+    bindDims(getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
+
+    auto shSym = rewriter.getAffineConstantExpr(
+        convOp.getStrides().getValues<int64_t>()[0]);
+    auto swSym = rewriter.getAffineConstantExpr(
+        convOp.getStrides().getValues<int64_t>()[1]);
+
+    SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
+                                          owDim * swSym + kwDim};
+
+    auto nloops = colTensorShape.size();
+
+    SmallVector<utils::IteratorType> loopAttributeTypes(
+        nloops, utils::IteratorType::parallel);
+
+    SmallVector<AffineMap> indexingMaps = {
+        AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
+        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+
+    Value colTensor = rewriter.create<tensor::EmptyOp>(
+        loc, colTensorShape, inputType.getElementType());
+
+    auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+        loc, colTensor.getType(),
+        /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
+        loopAttributeTypes,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        });
+
+    SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
+        {0, 1}, {2, 3}, {4, 5}};
+    SmallVector<ReassociationIndices> filterReassociationIndices = {{0},
+                                                                    {1, 2}};
+    SmallVector<ReassociationIndices> outputReassociationIndices = {{0, 1},
+                                                                    {2, 3}};
+
+    auto reshapedImg2ColTensorType = RankedTensorType::get(
+        {n * c, oh * ow, fh * fw}, inputType.getElementType());
+    auto reshapedFilterTensorType =
+        RankedTensorType::get({c, fh * fw}, filterType.getElementType());
+    auto reshapedOutputTensorType =
+        RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
+
+    Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+        img2ColTensorReassocIndices);
+    Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedFilterTensorType, filterT, filterReassociationIndices);
+    Value reshapedOutputTensor = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedOutputTensorType, transposedOutputTensor,
+        outputReassociationIndices);
+
+    auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
+        loc, TypeRange{reshapedOutputTensor.getType()},
+        ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
+        ValueRange{reshapedOutputTensor});
+
+    SmallVector<ReassociationIndices> batchMatVecReassociationIndices = {
+        {0, 1}, {2, 3}};
+
+    Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+        loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
+        batchMatVecReassociationIndices);
+
+    auto transposedResult =
+        transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
+
+    rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
+    return success();
+  }
+};
+
+// For nchw, because the channels are to the left of the image shape
+// dimensions, the position of the contraction dimension in the resulting
+// matmul is reversed. This swaps the LHS and RHS of the matmul when compared
+// with nhwc (i.e. (D, C x Kh x Kw) * (C x Kh x Kw, Ho x Wo)).
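+//
+// Sketch, with the static shapes used in the accompanying test: for a
+// linalg.conv_2d_nchw_fchw on a tensor<8x4x16x16xf32> input with a
+// tensor<16x4x3x3xf32> filter, the input is packed into a
+// tensor<8x4x3x3x14x14xf32> img2col tensor and collapsed to
+// tensor<8x36x196xf32>, the filter is collapsed to tensor<16x36xf32>, and the
+// batched matmul-like generic computes filter * input:
+// (16x36) * (8x36x196) -> (8x16x196).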
+class ConvertConv2DNchwFchw final
+    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
+                                PatternRewriter &rewriter) const override {
+    auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+    auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+    auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+
+    if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // TODO: Support dilation.
+    if (!hasAllOneValues(convOp.getDilations()))
+      return failure();
+
+    Value input = convOp.getInputs()[0];
+    Value filter = convOp.getInputs()[1];
+    Value output = convOp.getOutputs()[0];
+
+    auto filterShape = filterType.getShape();
+    auto outputShape = outputType.getShape();
+
+    const int n = outputShape[0];
+    const int oc = outputShape[1];
+    const int oh = outputShape[2];
+    const int ow = outputShape[3];
+    const int ic = filterShape[1];
+    const int fh = filterShape[2];
+    const int fw = filterShape[3];
+
+    auto loc = convOp.getLoc();
+
+    SmallVector<int64_t> colTensorShape = {n, ic, fh, fw, oh, ow};
+
+    Value colTensor = rewriter.create<tensor::EmptyOp>(
+        loc, colTensorShape, inputType.getElementType());
+
+    AffineExpr nDim, icDim, khDim, kwDim, ohDim, owDim;
+    bindDims(getContext(), nDim, icDim, khDim, kwDim, ohDim, owDim);
+
+    auto shSym = rewriter.getAffineConstantExpr(
+        convOp.getStrides().getValues<int64_t>()[0]);
+    auto swSym = rewriter.getAffineConstantExpr(
+        convOp.getStrides().getValues<int64_t>()[1]);
+
+    SmallVector<AffineExpr> inputExprs = {nDim, icDim, ohDim * shSym + khDim,
+                                          owDim * swSym + kwDim};
+
+    auto nloops = colTensorShape.size();
+
+    auto parallel = utils::IteratorType::parallel;
+    auto reduction = utils::IteratorType::reduction;
+    SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+    SmallVector<AffineMap> img2colIndexingMaps = {
+        AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
+        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+
+    auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+        loc, colTensor.getType(),
+        /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+        img2colIterators,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        });
+
+    SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+    auto reshapedFilterType =
+        RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType());
+    Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedFilterType, filter, filterReassocIndices);
+
+    SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+    SmallVector<ReassociationIndices> outputReassocIndices;
+    RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+    if (n == 1) {
+      img2ColTensorReassocIndices = {{0, 1, 2, 3}, {4, 5}};
+      outputReassocIndices = {{0, 1}, {2, 3}};
+
+      reshapedImg2ColTensorType = RankedTensorType::get(
+          {fh * fw * ic, oh * ow}, inputType.getElementType());
+      reshapedOutputType =
+          RankedTensorType::get({oc, oh * ow}, outputType.getElementType());
+    } else {
+      img2ColTensorReassocIndices = {{0}, {1, 2, 3}, {4, 5}};
+      outputReassocIndices = {{0}, {1}, {2, 3}};
+
+      reshapedImg2ColTensorType = RankedTensorType::get(
+          {n, fh * fw * ic, oh * ow}, inputType.getElementType());
+      reshapedOutputType =
+          RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
+    }
+
+    Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+        img2ColTensorReassocIndices);
+
+    Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedOutputType, output, outputReassocIndices);
+
+    Value result;
+    if (n == 1) {
+      auto matmulOp = rewriter.create<linalg::MatmulOp>(
+          loc, reshapedOutputType,
+          ArrayRef<Value>{reshapedFilter, reshapedImg2ColTensor},
+          ArrayRef<Value>{reshapedOutput});
+      result = matmulOp.getResults().front();
+    } else {
+      // For cases where batch is not 1, we need to keep the batch dimension
+      // separate. Because the filter does not share the same batch dimension,
+      // the batch dimension is only used in indexing the input and output.
+      // Thus we cannot use existing linalg named ops like linalg.batch_matmul;
+      // i.e. M x K * (B x) K x N = (B x) M x N.
+      AffineExpr bDim, mDim, nDim, kDim;
+      bindDims(getContext(), bDim, mDim, nDim, kDim);
+      auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, getContext());
+      auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, getContext());
+      auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, getContext());
+      SmallVector<utils::IteratorType> genericIterators = {
+          parallel, parallel, parallel, reduction};
+      bool isInt = outputType.getElementType().isa<IntegerType>();
+      auto genericOp = rewriter.create<linalg::GenericOp>(
+          loc, reshapedOutputType,
+          /*inputs=*/ValueRange{reshapedFilter, reshapedImg2ColTensor},
+          /*outputs=*/ValueRange{reshapedOutput},
+          ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+          [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+            Value mul = createMul(loc, args[0], args[1], isInt, nestedBuilder);
+            Value add = createAdd(loc, mul, args[2], isInt, nestedBuilder);
+            nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+          });
+      result = genericOp.getResults().front();
+    }
+
+    auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+        loc, outputType, result, outputReassocIndices);
+
+    rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+    return success();
+  }
+};
+} // end anonymous namespace
+
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
+                  ConvertConv2DNchwFchw>(context);
+}
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -0,0 +1,207 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+// Check that the im2col patterns are properly connected with the
+// transform dialect.
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: @conv_16433136
+// CHECK: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+// CHECK: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+// CHECK: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x14x14x3x3x4xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: #[[MAP0]]
+// CHECK-SAME: #[[MAP1]]
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
+// CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+// CHECK-SAME: [0, 1, 2], [3, 4, 5]
+// CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
+// CHECK-DAG: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME: [0, 1, 2], [3]
+// CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
+// CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]]
+// CHECK-SAME: [0, 1, 2], [3]
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.convert_conv2d_to_img2col %1
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
+// CHECK: @depthwise_conv_hwc_114x16x3
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x114x114x16xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x16xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x112x112x16xf32>
+// CHECK: %[[INPUT_T_INIT:.+]] = tensor.empty() : tensor<1x16x114x114xf32>
+// CHECK: %[[INPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<1x114x114x16xf32>) outs(%[[INPUT_T_INIT]] : tensor<1x16x114x114xf32>) {
+// CHECK-NEXT: ^bb0(%[[ARG3:.+]]: f32, %[[ARG4:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[ARG3]] : f32
+// CHECK-NEXT: } -> tensor<1x16x114x114xf32>
+// CHECK: %[[FILTER_T_INIT:.+]] = tensor.empty() : tensor<16x3x3xf32>
+// CHECK: %[[FILTER_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[FILTER]] : tensor<3x3x16xf32>) outs(%[[FILTER_T_INIT]] : tensor<16x3x3xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK: linalg.yield
+// CHECK: } -> tensor<16x3x3xf32>
+// CHECK: %[[INIT_OUTPUT_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112xf32>
+// CHECK: %[[OUTPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[OUTPUT]] : tensor<1x112x112x16xf32>) outs(%[[INIT_OUTPUT_TENSOR]] : tensor<1x16x112x112xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: } -> tensor<1x16x112x112xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112x3x3xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT_T]] : tensor<1x16x114x114xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x16x112x112x3x3xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: } -> tensor<1x16x112x112x3x3xf32>
+// CHECK: %[[COL_TENSOR_R:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+// CHECK-SAME: tensor<1x16x112x112x3x3xf32> into tensor<16x12544x9xf32>
+// CHECK: %[[FILTER_T_R:.+]] = tensor.collapse_shape %[[FILTER_T]]
+// CHECK-SAME: tensor<16x3x3xf32> into tensor<16x9xf32>
+// CHECK: %[[OUTPUT_T_R:.+]] = tensor.collapse_shape %[[OUTPUT_T]]
+// CHECK-SAME: tensor<1x16x112x112xf32> into tensor<16x12544xf32>
+// CHECK: %[[BMV_RESULT:.+]] = linalg.batch_matvec ins(%[[COL_TENSOR_R]], %[[FILTER_T_R]] : tensor<16x12544x9xf32>, tensor<16x9xf32>) outs(%[[OUTPUT_T_R]] : tensor<16x12544xf32>) -> tensor<16x12544xf32>
+// CHECK: %[[RESULT_R:.+]] = tensor.expand_shape %[[BMV_RESULT]]
+// CHECK-SAME: tensor<16x12544xf32> into tensor<1x16x112x112xf32>
+// CHECK: %[[RESULT_INIT:.+]] = tensor.empty() : tensor<1x112x112x16xf32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP6]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[RESULT_R]] : tensor<1x16x112x112xf32>) outs(%[[RESULT_INIT]] : tensor<1x112x112x16xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: } -> tensor<1x112x112x16xf32>
+// CHECK: return %[[RESULT]] : tensor<1x112x112x16xf32>
+func.func @depthwise_conv_hwc_114x16x3(%input: tensor<1x114x114x16xf32>, %filter: tensor<3x3x16xf32>, %output: tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } ins(%input, %filter : tensor<1x114x114x16xf32>, tensor<3x3x16xf32>) outs(%output : tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32>
+  return %0 : tensor<1x112x112x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %2 = transform.structured.convert_conv2d_to_img2col %arg1
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func.func @batch_nhwc_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>)
+// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x14x14x3x3x4xf32>
+// CHECK: %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
+// CHECK-SAME: outs(%[[IT]] : tensor<8x14x14x3x3x4xf32>)
+// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2], [3, 4, 5]] : tensor<8x14x14x3x3x4xf32> into tensor<8x196x36xf32>
+// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[CS_INPUT]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
+// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x196x16xf32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<8x196x16xf32>
+// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
+// CHECK: return %[[CS_FINAL]]
+func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>)
+    outs(%arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32>
+  return %0 : tensor<8x14x14x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %2 = transform.structured.convert_conv2d_to_img2col %arg1
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4 + d2, d5 + d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func.func @batch_nchw_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>)
+// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x4x3x3x14x14xf32>
+// CHECK: %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
+// CHECK-SAME: outs(%[[IT]] : tensor<8x4x3x3x14x14xf32>)
+// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
+// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2, 3], [4, 5]] : tensor<8x4x3x3x14x14xf32> into tensor<8x36x196xf32>
+// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[CS_FILTER]], %[[CS_INPUT]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
+// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<8x16x196xf32>
+// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
+// CHECK: return %[[CS_FINAL]]
+func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+  %0 = linalg.conv_2d_nchw_fchw
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+    outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+  return %0 : tensor<8x16x14x14xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %2 = transform.structured.convert_conv2d_to_img2col %arg1
+}