diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1740,6 +1740,82 @@
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::func::FuncOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp
+//===----------------------------------------------------------------------===//
+
+def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
+    "structured.convert_conv2d_to_img2col",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Convert linalg.conv_2d_xxx into linalg.generic (for img2col packing)
+    and linalg.matmul.
+
+    A convolution operation can be written as a matrix-matrix multiplication by
+    unfolding the cross-correlation between input and filter and explicitly
+    copying overlapped sliding window inputs.
+
+    Consider a 2D input X with a single input and output channel and a 2x2
+    filter W:
+    ```
+    [x(0, 0)  , x(0, 1)  , ...,   x(0, n)  ]
+    [x(1, 0)  , x(1, 1)  , ...,   x(1, n)  ]
+    [.        , .        , .  ,   .        ]            [w(0, 0), w(0, 1)]
+    [.        , .        , .  ,   .        ]    (conv)  [w(1, 0), w(1, 1)]
+    [.        , .        , .  ,   .        ]
+    [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
+    ```
+
+    The packed input data (img2col) is a matrix with |rows| = output spatial
+    size and |columns| = filter spatial size. To compute the output Y(i, j) we
+    need to calculate the dot product between the filter window at input
+    X(x, y) and the filter, which looks like the following, where the l.h.s is
+    the img2col matrix and the r.h.s is the flattened filter:
+    ```
+    [x(0, 0), x(0, 1), x(1, 0), x(1, 1)]
+    [x(0, 1), x(0, 2), x(1, 1), x(1, 2)] (matmul) [w(0, 0), w(0, 1), w(1, 0), w(1, 1)]
+    [x(0, 2), x(0, 3), x(1, 2), x(1, 3)]
+    [   .   ,    .   ,    .   ,    .   ]
+    ```
+
+    In general, for the 2D case with an (N, H, W, C) input and a (Kh, Kw, C, D)
+    filter producing an (N, Ho, Wo, D) output, the convolution is the
+    matrix-matrix multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for
+    each of the N inputs. For the case where N > 1 it is a batched
+    matrix-matrix multiplication.
+
+    Returns two handles:
+    - One on the operation that produces the img2col tensor.
+    - One on the final operation of the sequence that replaces the original
+      convolution.
+
+    #### Return modes:
+
+    Returns a definite failure if target is not isolated from above.
+    Returns a silenceable failure if the pattern application failed.
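+
+    #### Example:
+
+    A sketch of the intended usage, mirroring the accompanying tests (the
+    handle names are illustrative and the handles are typed as
+    `!pdl.operation`, as in those tests):
+    ```
+    %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
+      : (!pdl.operation) -> !pdl.operation
+    %img2col_tensor, %transformed =
+      transform.structured.convert_conv2d_to_img2col %conv
+      : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+    ```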
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$img2col_tensor,
+                      TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+      "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -879,6 +879,64 @@
 void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);
 
+/// Populates patterns to transform linalg.conv_2d_xxx operations into
+/// linalg.generic (for img2col packing) and linalg.matmul.
+/// \see rewriteInIm2Col for more details.
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
+
+/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
+/// and linalg.matmul.
+///
+/// A convolution operation can be written as a matrix-matrix multiplication by
+/// unfolding the cross-correlation between input and filter and explicitly
+/// copying overlapped sliding window inputs.
+///
+/// Consider a 2D input X with a single input and output channel and a 2x2
+/// filter W:
+///   [x(0, 0)  , x(0, 1)  , ...,   x(0, n)  ]
+///   [x(1, 0)  , x(1, 1)  , ...,   x(1, n)  ]
+///   [.        , .        , .  ,   .        ]            [w(0, 0), w(0, 1)]
+///   [.        , .        , .  ,   .        ]    (conv)  [w(1, 0), w(1, 1)]
+///   [.        , .        , .  ,   .        ]
+///   [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
+///
+/// The packed input data (img2col) is a matrix with |rows| = output spatial
+/// size and |columns| = filter spatial size. To compute the output Y(i, j) we
+/// need to calculate the dot product between the filter window at input
+/// X(x, y) and the filter, which looks like the following, where the l.h.s is
+/// the img2col matrix and the r.h.s is the flattened filter:
+///
+///   [x(0, 0), x(0, 1), x(1, 0), x(1, 1)]
+///   [x(0, 1), x(0, 2), x(1, 1), x(1, 2)] (matmul) [w(0, 0), w(0, 1), w(1, 0), w(1, 1)]
+///   [x(0, 2), x(0, 3), x(1, 2), x(1, 3)]
+///   [   .   ,    .   ,    .   ,    .   ]
+///
+/// In general, for the 2D case with an (N, H, W, C) input and a (Kh, Kw, C, D)
+/// filter producing an (N, Ho, Wo, D) output, the convolution is the
+/// matrix-matrix multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for
+/// each of the N inputs. For the case where N > 1 it is a batched
+/// matrix-matrix multiplication.
+///
+/// On success, returns both the operation that produces the img2col tensor and
+/// the final operation of the sequence that replaces the original convolution.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
+
+/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp, except that there
+/// is no reduction among the input channels, so each convolution can be a
+/// matrix-vector product. By transposing both the input and the filter so that
+/// the channels are outermost, the computation becomes a batched matrix-vector
+/// product.
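+///
+/// For example (a concrete instance mirroring the depthwise test in
+/// convert-conv2d-to-img2col.mlir): for a (1, 114, 114, 16) input, a
+/// (3, 3, 16) filter and a (1, 112, 112, 16) output, the packed input
+/// collapses to shape (16, 12544, 9), the transposed filter to (16, 9), and
+/// the linalg.batch_matvec produces a (16, 12544) result, which is then
+/// expanded and transposed back to NHWC.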
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter,
+                linalg::DepthwiseConv2DNhwcHwcOp convOp);
+
+/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp, except that,
+/// because the channels are to the left of the image shape dimensions, the
+/// position of the contraction dimension in the resulting matmul is reversed.
+/// This swaps the LHS and RHS of the matmul when compared with nhwc, i.e.
+/// (D, C x Kh x Kw) * (C x Kh x Kw, Ho x Wo).
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp);
+
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -3072,6 +3073,40 @@
   results.push_back(target);
   return DiagnosedSilenceableFailure::success();
 }
+
+//===----------------------------------------------------------------------===//
+// ConvertConv2DToImg2ColOp.
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
+    linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
+          target)
+          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+            return rewriteInIm2Col(rewriter, op);
+          })
+          .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
+            return rewriteInIm2Col(rewriter, op);
+          })
+          .Case([&](linalg::Conv2DNchwFchwOp op) {
+            return rewriteInIm2Col(rewriter, op);
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+  // Handle to the operation producing the img2col tensor.
+  results.push_back(maybeTransformed->first);
+  // Handle to the operation that replaces the original convolution.
+  results.push_back(maybeTransformed->second);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   Bufferize.cpp
   ConstantFold.cpp
   ConvertToDestinationStyle.cpp
+  ConvertConv2DToImg2Col.cpp
   DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -0,0 +1,540 @@
+//===- ConvertConv2DToImg2Col.cpp - im2col implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <utility>
+
+namespace mlir {
+namespace linalg {
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+  return llvm::all_of(
+      attr, [](APInt element) { return element.getSExtValue() == 1; });
+}
+
+static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
+  bool isInt = x.getType().isa<IntegerType>();
+  if (isInt)
+    return builder.create<arith::AddIOp>(loc, x, y);
+  return builder.create<arith::AddFOp>(loc, x, y);
+}
+
+static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
+  bool isInt = x.getType().isa<IntegerType>();
+  if (isInt)
+    return builder.create<arith::MulIOp>(loc, x, y);
+  return builder.create<arith::MulFOp>(loc, x, y);
+}
+
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
+  auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+  auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  MLIRContext *context = rewriter.getContext();
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
+  int n = outputShape[0];
+  int oh = outputShape[1];
+  int ow = outputShape[2];
+  int oc = outputShape[3];
+  int fh = filterShape[0];
+  int fw = filterShape[1];
+  int ic = filterShape[2];
+
+  Location loc = convOp.getLoc();
+
+  SmallVector<int64_t> colTensorShape = {n, oh, ow, fh, fw, ic};
+
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  AffineExpr nDim, ohDim, owDim, khDim, kwDim, icDim;
+  bindDims(context, nDim, ohDim, owDim, khDim, kwDim, icDim);
+
+  AffineExpr shSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[0]);
+  AffineExpr swSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[1]);
+
+  SmallVector<AffineExpr> inputExprs = {nDim, ohDim * shSym + khDim,
+                                        owDim * swSym + kwDim, icDim};
+
+  auto nloops = colTensorShape.size();
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::get(nloops, 0, inputExprs, context),
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+  SmallVector<ReassociationIndices> outputReassocIndices;
+  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+  if (n == 1) {
+    img2ColTensorReassocIndices = {{0, 1, 2}, {3, 4, 5}};
+    outputReassocIndices = {{0, 1, 2}, {3}};
+
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {oh * ow, fh * fw * ic}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({oh * ow, oc}, outputType.getElementType());
+  } else {
+    img2ColTensorReassocIndices = {{0}, {1, 2}, {3, 4, 5}};
+    outputReassocIndices = {{0}, {1, 2}, {3}};
+
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {n, oh * ow, fh * fw * ic}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  }
+
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
+
+  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+      img2ColTensorReassocIndices);
+
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+
+  Value result;
+  if (n == 1) {
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, reshapedOutputType,
+        ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
+        ArrayRef<Value>{reshapedOutput});
+    result = matmulOp.getResults().front();
+  } else {
+    // For cases where batch is not 1, we need to keep the batch dimension
+    // separate. Because the filter does not share the same batch dimension,
+    // the batch dimension is only used in indexing the input and output. Thus
+    // we cannot use existing linalg named ops like linalg.batch_matmul.
+    // i.e. (B x) M x K * K x N = (B x) M x N
+    AffineExpr bDim, mDim, nDim, kDim;
+    bindDims(context, bDim, mDim, nDim, kDim);
+    auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+    auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
+    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                         parallel, reduction};
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, reshapedOutputType,
+        /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilter},
+        /*outputs=*/ValueRange{reshapedOutput},
+        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+          Value add = createAdd(loc, mul, args[2], nestedBuilder);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+        });
+    result = genericOp.getResults().front();
+  }
+
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter,
+                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
+  auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+  auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  Location loc = convOp.getLoc();
+
+  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
+    auto operandTensorType = operand.getType().cast<RankedTensorType>();
+    auto nloops = indices.size();
+    ArrayRef<int64_t> inputShape = operandTensorType.getShape();
+
+    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
+        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
+          return rewriter.getAffineDimExpr(index);
+        }));
+
+    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
+        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
+
+    Value outputTensor = rewriter.create<tensor::EmptyOp>(
+        loc, targetShape, operandTensorType.getElementType());
+
+    SmallVector<utils::IteratorType> loopAttributeTypes(
+        nloops, utils::IteratorType::parallel);
+
+    SmallVector<AffineMap> indexingMaps = {
+        inversePermutation(
+            AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
+        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+
+    auto transposedOp = rewriter.create<linalg::GenericOp>(
+        loc, outputTensor.getType(),
+        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
+        loopAttributeTypes,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        });
+
+    return transposedOp.getResult(0);
+  };
+
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  // Transpose the input and the filter so that channels are outermost.
+  Value inputT = transposeOperand(input, {0, 3, 1, 2});
+  Value filterT = transposeOperand(filter, {2, 0, 1});
+  ArrayRef<int64_t> filterTShape =
+      filterT.getType().cast<RankedTensorType>().getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
+  int n = outputShape[0];
+  int oh = outputShape[1];
+  int ow = outputShape[2];
+  int c = outputShape[3];
+  int fh = filterTShape[1];
+  int fw = filterTShape[2];
+
+  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
+  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
+
+  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
+  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
+
+  AffineExpr shSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[0]);
+  AffineExpr swSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[1]);
+
+  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
+                                        owDim * swSym + kwDim};
+
+  auto nloops = colTensorShape.size();
+
+  SmallVector<utils::IteratorType> loopAttributeTypes(
+      nloops, utils::IteratorType::parallel);
+
+  SmallVector<AffineMap> indexingMaps = {
+      AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
+      AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
+      loopAttributeTypes,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
+      {0, 1}, {2, 3}, {4, 5}};
+  SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
+  SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
+                                                                 {2, 3}};
+
+  auto reshapedImg2ColTensorType = RankedTensorType::get(
+      {n * c, oh * ow, fh * fw}, inputType.getElementType());
+  auto reshapedFilterTensorType =
+      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
+  auto reshapedOutputTensorType =
+      RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
+
+  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+      img2ColTensorReassocIndices);
+  Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
+  Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputTensorType, transposedOutputTensor,
+      outputReassociationIndice);
+
+  auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
+      loc, TypeRange{reshapedoutputTensor.getType()},
+      ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
+      ValueRange{reshapedoutputTensor});
+
+  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
+                                                                      {2, 3}};
+
+  Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+      loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
+      batchMatVecReassociationIndice);
+
+  Value transposedResult =
+      transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
+  return std::make_pair(img2ColTensor.getOperation(),
+                        transposedResult.getDefiningOp());
+}
+
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
+  auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
+  auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
+  auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  auto filterShape = filterType.getShape();
+  auto outputShape = outputType.getShape();
+
+  int n = outputShape[0];
+  int oc = outputShape[1];
+  int oh = outputShape[2];
+  int ow = outputShape[3];
+  int ic = filterShape[1];
+  int fh = filterShape[2];
+  int fw = filterShape[3];
+
+  auto loc = convOp.getLoc();
+
+  SmallVector<int64_t> colTensorShape = {n, ic, fh, fw, oh, ow};
+
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  MLIRContext *context = rewriter.getContext();
+
+  AffineExpr nDim, icDim, khDim, kwDim, ohDim, owDim;
+  bindDims(context, nDim, icDim, khDim, kwDim, ohDim, owDim);
+
+  auto shSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[0]);
+  auto swSym = rewriter.getAffineConstantExpr(
+      convOp.getStrides().getValues<int64_t>()[1]);
+
+  SmallVector<AffineExpr> inputExprs = {nDim, icDim, ohDim * shSym + khDim,
+                                        owDim * swSym + kwDim};
+
+  auto nloops = colTensorShape.size();
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::get(nloops, 0, inputExprs, context),
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
+  SmallVector<ReassociationIndices> outputReassocIndices;
+  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
+  if (n == 1) {
+    img2ColTensorReassocIndices = {{0, 1, 2, 3}, {4, 5}};
+    outputReassocIndices = {{0, 1}, {2, 3}};
+
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {fh * fw * ic, oh * ow}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({oc, oh * ow}, outputType.getElementType());
+  } else {
+    img2ColTensorReassocIndices = {{0}, {1, 2, 3}, {4, 5}};
+    outputReassocIndices = {{0}, {1}, {2, 3}};
+
+    reshapedImg2ColTensorType = RankedTensorType::get(
+        {n, fh * fw * ic, oh * ow}, inputType.getElementType());
+    reshapedOutputType =
+        RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
+  }
+
+  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+      img2ColTensorReassocIndices);
+
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+
+  Value result;
+  if (n == 1) {
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, reshapedOutputType,
+        ArrayRef<Value>{reshapedFilter, reshapedImg2ColTensor},
+        ArrayRef<Value>{reshapedOutput});
+    result = matmulOp.getResults().front();
+  } else {
+    // For cases where batch is not 1, we need to keep the batch dimension
+    // separate. Because the filter does not share the same batch dimension,
+    // the batch dimension is only used in indexing the input and output. Thus
+    // we cannot use existing linalg named ops like linalg.batch_matmul.
+    // i.e. M x K * (B x) K x N = (B x) M x N
+    AffineExpr bDim, mDim, nDim, kDim;
+    bindDims(context, bDim, mDim, nDim, kDim);
+    auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
+    auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
+    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                         parallel, reduction};
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, reshapedOutputType,
+        /*inputs=*/ValueRange{reshapedFilter, reshapedImg2ColTensor},
+        /*outputs=*/ValueRange{reshapedOutput},
+        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+          Value add = createAdd(loc, mul, args[2], nestedBuilder);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+        });
+    result = genericOp.getResults().front();
+  }
+
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
+namespace {
+
+class ConvertConv2DNhwcHwcf final
+    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+public:
+  using OpRewritePattern<linalg::Conv2DNhwcHwcfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(rewriteInIm2Col(rewriter, convOp)))
+      return failure();
+    return success();
+  }
+};
+
+class ConvertDepthwiseConv2DNhwcHwc final
+    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
+public:
+  using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(rewriteInIm2Col(rewriter, convOp)))
+      return failure();
+    return success();
+  }
+};
+
+class ConvertConv2DNchwFchw final
+    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
+public:
+  using OpRewritePattern<linalg::Conv2DNchwFchwOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(rewriteInIm2Col(rewriter, convOp)))
+      return failure();
+    return success();
+  }
+};
+} // end anonymous namespace
+
+void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
+                  ConvertConv2DNchwFchw>(context);
+}
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -0,0 +1,245 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// Check that the im2col patterns are properly connected with the
+// transform dialect.
+
+// Non-static shapes are not supported.
+// Check that we emit an error.
+// TODO: Hook up the rewriter errors in the transform dialect.
+func.func @conv_non_static(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // expected-note@below {{when applied to this op}}
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<?x?x?x?xf32>, tensor<3x3x4x16xf32>)
+    outs(%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  // expected-error@below {{failed to apply}}
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+// Check that we get the proper handles for the img2col tensor producer
+// and the final operation.
+
+// CHECK: IR printer: tensor_producer
+// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>]
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
+
+// CHECK: IR printer: transformed
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: @conv_16433136
+// CHECK: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+// CHECK: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+// CHECK: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x14x14x3x3x4xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: #[[MAP0]]
+// CHECK-SAME: #[[MAP1]]
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
+// CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+// CHECK-SAME: [0, 1, 2], [3, 4, 5]
+// CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
+// CHECK-DAG: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME: [0, 1, 2], [3]
+// CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
+// CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]]
+// CHECK-SAME: [0, 1, 2], [3]
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !pdl.operation
+  transform.print %transformed {name = "transformed"}: !pdl.operation
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
+// CHECK: @depthwise_conv_hwc_114x16x3
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x114x114x16xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x16xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x112x112x16xf32>
+// CHECK: %[[INPUT_T_INIT:.+]] = tensor.empty() : tensor<1x16x114x114xf32>
+// CHECK: %[[INPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<1x114x114x16xf32>) outs(%[[INPUT_T_INIT]] : tensor<1x16x114x114xf32>) {
+// CHECK-NEXT: ^bb0(%[[ARG3:.+]]: f32, %[[ARG4:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[ARG3]] : f32
+// CHECK-NEXT: } -> tensor<1x16x114x114xf32>
+// CHECK: %[[FILTER_T_INIT:.+]] = tensor.empty() : tensor<16x3x3xf32>
+// CHECK: %[[FILTER_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[FILTER]] : tensor<3x3x16xf32>) outs(%[[FILTER_T_INIT]] : tensor<16x3x3xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK: linalg.yield
+// CHECK: } -> tensor<16x3x3xf32>
+// CHECK: %[[INIT_OUTPUT_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112xf32>
+// CHECK: %[[OUTPUT_T:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[OUTPUT]] : tensor<1x112x112x16xf32>) outs(%[[INIT_OUTPUT_TENSOR]] : tensor<1x16x112x112xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: } -> tensor<1x16x112x112xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112x3x3xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT_T]] : tensor<1x16x114x114xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x16x112x112x3x3xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: } -> tensor<1x16x112x112x3x3xf32>
+// CHECK: %[[COL_TENSOR_R:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
+// CHECK-SAME: tensor<1x16x112x112x3x3xf32> into tensor<16x12544x9xf32>
+// CHECK: %[[FILTER_T_R:.+]] = tensor.collapse_shape %[[FILTER_T]]
+// CHECK-SAME: tensor<16x3x3xf32> into tensor<16x9xf32>
+// CHECK: %[[OUTPUT_T_R:.+]] = tensor.collapse_shape %[[OUTPUT_T]]
+// CHECK-SAME: tensor<1x16x112x112xf32> into tensor<16x12544xf32>
+// CHECK: %[[BMV_RESULT:.+]] = linalg.batch_matvec ins(%[[COL_TENSOR_R]], %[[FILTER_T_R]] : tensor<16x12544x9xf32>, tensor<16x9xf32>) outs(%[[OUTPUT_T_R]] : tensor<16x12544xf32>) -> tensor<16x12544xf32>
+// CHECK: %[[RESULT_R:.+]] = tensor.expand_shape %[[BMV_RESULT]]
+// CHECK-SAME: tensor<16x12544xf32> into tensor<1x16x112x112xf32>
+// CHECK: %[[RESULT_INIT:.+]] = tensor.empty() : tensor<1x112x112x16xf32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP6]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[RESULT_R]] : tensor<1x16x112x112xf32>) outs(%[[RESULT_INIT]] : tensor<1x112x112x16xf32>) {
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: } -> tensor<1x112x112x16xf32>
+// CHECK: return %[[RESULT]] : tensor<1x112x112x16xf32>
+func.func @depthwise_conv_hwc_114x16x3(%input: tensor<1x114x114x16xf32>, %filter: tensor<3x3x16xf32>, %output: tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<1> : tensor<2xi64>
+    } ins(%input, %filter : tensor<1x114x114x16xf32>, tensor<3x3x16xf32>) outs(%output : tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32>
+  return %0 : tensor<1x112x112x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.depthwise_conv_2d_nhwc_hwc"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func.func @batch_nhwc_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>)
+// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x14x14x3x3x4xf32>
+// CHECK: %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
+// CHECK-SAME: outs(%[[IT]] : tensor<8x14x14x3x3x4xf32>)
+// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2], [3, 4, 5]] : tensor<8x14x14x3x3x4xf32> into tensor<8x196x36xf32>
+// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[CS_INPUT]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
+// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x196x16xf32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<8x196x16xf32>
+// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
+// CHECK: return %[[CS_FINAL]]
+func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>)
+    outs(%arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32>
+  return %0 : tensor<8x14x14x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4 + d2, d5 + d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func.func @batch_nchw_conv
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>)
+// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x4x3x3x14x14xf32>
+// CHECK: %[[IMG2COL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
+// CHECK-SAME: outs(%[[IT]] : tensor<8x4x3x3x14x14xf32>)
+// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
+// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2, 3], [4, 5]] : tensor<8x4x3x3x14x14xf32> into tensor<8x36x196xf32>
+// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[CS_FILTER]], %[[CS_INPUT]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
+// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
+// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<8x16x196xf32>
+// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
+// CHECK: return %[[CS_FINAL]]
+func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+  %0 = linalg.conv_2d_nchw_fchw
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+    outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+  return %0 : tensor<8x16x14x14xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}