diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -41,6 +41,49 @@
   return builder.create<arith::MulFOp>(loc, x, y);
 }
 
+// Unrolls the given composite `index` into a set of subindices with maximum
+// iteration ranges specified by `factors` according to the following
+// assumptions:
+//   1. The iteration range for `index` is [0, f1 * f2 * ... * fn] i.e. the
+//      product of the given list of factors
+//   2. The iterators corresponding to the entries in `factors` are ordered from
+//      slowest to fastest varying
+// Each subindex is then computed as:
+//   subindex[i] = floor( (index % (fi * ... * fn)) / (fi-1 * ... * fn) )
+static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
+                                      ArrayRef<int64_t> factors) {
+  assert(factors.size() >= 1 && "empty factor list");
+  SmallVector<Value> indices(factors.size());
+  int64_t runningProd = 1;
+  for (int i = factors.size() - 1, end = 0; i >= end; i--) {
+    Value unrolledIndex = index;
+    if (i > 0) {
+      Value modBase = b.create<arith::ConstantOp>(
+          loc, b.getIndexAttr(runningProd * factors[i]));
+      unrolledIndex = b.create<arith::RemUIOp>(loc, unrolledIndex, modBase);
+    }
+    if (runningProd > 1) {
+      Value divDenom =
+          b.create<arith::ConstantOp>(loc, b.getIndexAttr(runningProd));
+      unrolledIndex = b.create<arith::DivUIOp>(loc, unrolledIndex, divDenom);
+    }
+    runningProd *= factors[i];
+    indices[i] = unrolledIndex;
+  }
+  return indices;
+}
+
+// Given indices corresponding to iterators in the output (oIndex) and filter
+// (fIndex) for a convolution, compute the convolved index for the
+// input as `oIndex * stride + fIndex`.
+static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
+                               Value fIndex, int64_t stride) {
+  Value strideVal = b.create<arith::ConstantOp>(loc, b.getIndexAttr(stride));
+  Value convIndex = b.create<arith::MulIOp>(loc, oIndex, strideVal);
+  return b.create<arith::AddIOp>(loc, convIndex, fIndex);
+}
+
 FailureOr<std::pair<Operation *, Operation *>>
 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
@@ -68,32 +111,34 @@
   ArrayRef<int64_t> filterShape = filterType.getShape();
   ArrayRef<int64_t> outputShape = outputType.getShape();
 
-  int n = outputShape[0];
-  int oh = outputShape[1];
-  int ow = outputShape[2];
-  int oc = outputShape[3];
-  int fh = filterShape[0];
-  int fw = filterShape[1];
-  int ic = filterShape[2];
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[0];
+  int64_t fw = filterShape[1];
+  int64_t ic = filterShape[2];
 
   Location loc = convOp.getLoc();
 
-  SmallVector<int64_t> colTensorShape = {n, oh, ow, fh, fw, ic};
+  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+  RankedTensorType reshapedOutputType =
+      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
 
+  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
   Value colTensor = rewriter.create<tensor::EmptyOp>(
       loc, colTensorShape, inputType.getElementType());
 
-  AffineExpr nDim, ohDim, owDim, khDim, kwDim, icDim;
-  bindDims(context, nDim, ohDim, owDim, khDim, kwDim, icDim);
-
-  AffineExpr shSym = rewriter.getAffineConstantExpr(
-      convOp.getStrides().getValues<int64_t>()[0]);
-  AffineExpr swSym = rewriter.getAffineConstantExpr(
-      convOp.getStrides().getValues<int64_t>()[1]);
-
-  SmallVector<AffineExpr> inputExprs = {nDim, ohDim * shSym + khDim,
-                                        owDim * swSym + kwDim, icDim};
-
+  // Convert the input to a (BMK) column tensor.
   auto nloops = colTensorShape.size();
 
   auto parallel = utils::IteratorType::parallel;
@@ -101,85 +146,68 @@
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
 
   SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::get(nloops, 0, inputExprs, context),
       AffineMap::getMultiDimIdentityMap(nloops, context)};
 
   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
       loc, colTensor.getType(),
-      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        // Get the iterators named based on the matmul (batch, m, k).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value> mIndices = unrollIndex(
+            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = mIndices[0];
+        auto owIndex = mIndices[1];
+
+        SmallVector<Value> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+        auto fhIndex = kIndices[0];
+        auto fwIndex = kIndices[1];
+        auto icIndex = kIndices[2];
+
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+                              convOp.getStrides().getValues<int64_t>()[0]);
+        Value wIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+                              convOp.getStrides().getValues<int64_t>()[1]);
+
+        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            loc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
       });
 
-  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
-  SmallVector<ReassociationIndices> outputReassocIndices;
-  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
-  if (n == 1) {
-    img2ColTensorReassocIndices = {{0, 1, 2}, {3, 4, 5}};
-    outputReassocIndices = {{0, 1, 2}, {3}};
-
-    reshapedImg2ColTensorType = RankedTensorType::get(
-        {oh * ow, fh * fw * ic}, inputType.getElementType());
-    reshapedOutputType =
-        RankedTensorType::get({oh * ow, oc}, outputType.getElementType());
-  } else {
-    img2ColTensorReassocIndices = {{0}, {1, 2}, {3, 4, 5}};
-    outputReassocIndices = {{0}, {1, 2}, {3}};
-
-    reshapedImg2ColTensorType = RankedTensorType::get(
-        {n, oh * ow, fh * fw * ic}, inputType.getElementType());
-    reshapedOutputType =
-        RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
-  }
-
-  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
-  auto
reshapedFilterType =
-      RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
-
-  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
-      img2ColTensorReassocIndices);
-
-  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedFilterType, filter, filterReassocIndices);
-
-  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedOutputType, output, outputReassocIndices);
-
-  Value result;
-  if (n == 1) {
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, reshapedOutputType,
-        ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
-        ArrayRef<Value>{reshapedOutput});
-    result = matmulOp.getResults().front();
-  } else {
-    // For cases where batch is not 1, we need to keep the batch dimension
-    // separate. Because the filter does not share the same batch dimension,
-    // the batch dimension is only used in indexing the input and output. Thus
-    // we cannot use existing linalg named ops like linalg.batch_matmul.
-    // i.e. (B x) M x K * K x N = (B x) M x N
-    AffineExpr bDim, mDim, nDim, kDim;
-    bindDims(context, bDim, mDim, nDim, kDim);
-    auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
-    auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
-    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
-    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
-                                                         parallel, reduction};
-
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, reshapedOutputType,
-        /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilter},
-        /*outputs=*/ValueRange{reshapedOutput},
-        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
-        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          Value mul = createMul(loc, args[0], args[1], nestedBuilder);
-          Value add = createAdd(loc, mul, args[2], nestedBuilder);
-          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
-        });
-    result = genericOp.getResults().front();
-  }
+  // Because the filter does not share the same batch dimension,
+  // the batch dimension is only used in indexing
the input and output. Thus
+  // we cannot use existing linalg named ops like linalg.batch_matmul.
+  // i.e. (B x) M x K * K x N = (B x) M x N
+  AffineExpr bDim, mDim, nDim, kDim;
+  bindDims(context, bDim, mDim, nDim, kDim);
+  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
+  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                       parallel, reduction};
+
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, reshapedOutputType,
+      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
+      /*outputs=*/ValueRange{reshapedOutput},
+      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+        Value add = createAdd(loc, mul, args[2], nestedBuilder);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+      });
+  Value result = genericOp.getResults().front();
 
   auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
       loc, outputType, result, outputReassocIndices);
@@ -367,33 +395,33 @@
   auto filterShape = filterType.getShape();
   auto outputShape = outputType.getShape();
 
-  int n = outputShape[0];
-  int oc = outputShape[1];
-  int oh = outputShape[2];
-  int ow = outputShape[3];
-  int ic = filterShape[1];
-  int fh = filterShape[2];
-  int fw = filterShape[3];
+  int64_t n = outputShape[0];
+  int64_t oc = outputShape[1];
+  int64_t oh = outputShape[2];
+  int64_t ow = outputShape[3];
+  int64_t ic = filterShape[1];
+  int64_t fh = filterShape[2];
+  int64_t fw = filterShape[3];
 
   auto loc = convOp.getLoc();
-
-  SmallVector<int64_t> colTensorShape = {n, ic, fh, fw, oh, ow};
-
-  Value colTensor = rewriter.create<tensor::EmptyOp>(
-      loc, colTensorShape, inputType.getElementType());
-
   MLIRContext *context = rewriter.getContext();
 
-  AffineExpr nDim, icDim, khDim, kwDim, ohDim, owDim;
-  bindDims(context, nDim, icDim, khDim, kwDim, ohDim, owDim);
+  SmallVector<ReassociationIndices> filterReassocIndices =
{{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
 
-  auto shSym = rewriter.getAffineConstantExpr(
-      convOp.getStrides().getValues<int64_t>()[0]);
-  auto swSym = rewriter.getAffineConstantExpr(
-      convOp.getStrides().getValues<int64_t>()[1]);
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
+  auto reshapedOutputType =
+      RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
 
-  SmallVector<AffineExpr> inputExprs = {nDim, icDim, ohDim * shSym + khDim,
-                                        owDim * swSym + kwDim};
+  // Convert the input to a (BKN) tensor.
+  SmallVector<int64_t> colTensorShape = {n, ic * fh * fw, oh * ow};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
 
   auto nloops = colTensorShape.size();
 
@@ -402,83 +430,67 @@
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
 
   SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::get(nloops, 0, inputExprs, context),
       AffineMap::getMultiDimIdentityMap(nloops, context)};
 
   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
       loc, colTensor.getType(),
-      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        // Get the iterators named based on the matmul (batch, m, k).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+        Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
+        auto icIndex = kIndices[0];
+        auto fhIndex = kIndices[1];
+        auto fwIndex = kIndices[2];
+
+        SmallVector<Value> nIndices = unrollIndex(
+            nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = nIndices[0];
+        auto owIndex = nIndices[1];
+
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+                              convOp.getStrides().getValues<int64_t>()[0]);
+        Value wIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+                              convOp.getStrides().getValues<int64_t>()[1]);
+
+        // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
+        SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            loc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
       });
 
-  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
-  auto reshapedFilterType =
-      RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType());
-  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedFilterType, filter, filterReassocIndices);
-
-  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
-  SmallVector<ReassociationIndices> outputReassocIndices;
-  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
-  if (n == 1) {
-    img2ColTensorReassocIndices = {{0, 1, 2, 3}, {4, 5}};
-    outputReassocIndices = {{0, 1}, {2, 3}};
-
-    reshapedImg2ColTensorType = RankedTensorType::get(
-        {fh * fw * ic, oh * ow}, inputType.getElementType());
-    reshapedOutputType =
-        RankedTensorType::get({oc, oh * ow}, outputType.getElementType());
-  } else {
-    img2ColTensorReassocIndices = {{0}, {1, 2, 3}, {4, 5}};
-    outputReassocIndices = {{0}, {1}, {2, 3}};
-
-    reshapedImg2ColTensorType = RankedTensorType::get(
-        {n, fh * fw * ic, oh * ow}, inputType.getElementType());
-    reshapedOutputType =
-        RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
-  }
-
-  Value reshapedImg2ColTensor =
rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
-      img2ColTensorReassocIndices);
-
-  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedOutputType, output, outputReassocIndices);
-
-  Value result;
-  if (n == 1) {
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, reshapedOutputType,
-        ArrayRef<Value>{reshapedFilter, reshapedImg2ColTensor},
-        ArrayRef<Value>{reshapedOutput});
-    result = matmulOp.getResults().front();
-  } else {
-    // For cases where batch is not 1, we need to keep the batch dimension
-    // separate. Because the filter does not share the same batch dimension,
-    // the batch dimension is only used in indexing the input and output. Thus
-    // we cannot use existing linalg named ops like linalg.batch_matmul.
-    // i.e. M x K * (B x) K x N = (B x) M x N
-    AffineExpr bDim, mDim, nDim, kDim;
-    bindDims(context, bDim, mDim, nDim, kDim);
-    auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
-    auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
-    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
-    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
-                                                         parallel, reduction};
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, reshapedOutputType,
-        /*inputs=*/ValueRange{reshapedFilter, reshapedImg2ColTensor},
-        /*outputs=*/ValueRange{reshapedOutput},
-        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
-        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          Value mul = createMul(loc, args[0], args[1], nestedBuilder);
-          Value add = createAdd(loc, mul, args[2], nestedBuilder);
-          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
-        });
-    result = genericOp.getResults().front();
-  }
+  // Because the filter does not share the same batch dimension,
+  // the batch dimension is only used in indexing the input and output. Thus
+  // we cannot use existing linalg named ops like linalg.batch_matmul.
+  // i.e.
M x K * (B x) K x N = (B x) M x N
+  AffineExpr bDim, mDim, nDim, kDim;
+  bindDims(context, bDim, mDim, nDim, kDim);
+  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
+  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
+  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                       parallel, reduction};
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, reshapedOutputType,
+      /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
+      /*outputs=*/ValueRange{reshapedOutput},
+      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+        Value add = createAdd(loc, mul, args[2], nestedBuilder);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
+  Value result = genericOp.getResults().front();
 
   auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
       loc, outputType, result, outputReassocIndices);
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -29,36 +29,71 @@
 // CHECK: IR printer: tensor_producer
 // CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>]
-// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
-// CHECK: linalg.yield %[[IN_DATA]] : f32
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+
+// Collapsed indices.
+// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
+// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
+// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
+
+// Unrolled output shape indices.
+// CHECK: %[[C14:.+]] = arith.constant 14 : index +// CHECK: %[[OWINDEX:.+]] = arith.remui %[[MINDEX]], %[[C14]] : index +// CHECK: %[[C14_1:.+]] = arith.constant 14 : index +// CHECK: %[[OHINDEX:.+]] = arith.divui %[[MINDEX]], %[[C14_1]] : index + +// Unrolled filter shape indices. +// CHECK: %[[C4:.+]] = arith.constant 4 : index +// CHECK: %[[ICINDEX:.+]] = arith.remui %[[KINDEX]], %[[C4]] : index +// CHECK: %[[C12:.+]] = arith.constant 12 : index +// CHECK: %[[FWREM:.+]] = arith.remui %[[KINDEX]], %[[C12]] : index +// CHECK: %[[C4_2:.+]] = arith.constant 4 : index +// CHECK: %[[FWINDEX:.+]] = arith.divui %[[FWREM]], %[[C4_2]] : index +// CHECK: %[[C12_3:.+]] = arith.constant 12 : index +// CHECK: %[[FHINDEX:.+]] = arith.divui %[[KINDEX]], %[[C12_3]] : index + +// Compute input indices. +// CHECK: %[[SH:.+]] = arith.constant 1 : index +// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index +// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index +// CHECK: %[[SW:.+]] = arith.constant 1 : index +// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index +// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index +// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract +// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32> +// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32 // CHECK: IR printer: transformed -// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32> +// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, 
d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> // CHECK: @conv_16433136 -// CHECK: %[[INPUT:.+]]: tensor<1x16x16x4xf32> -// CHECK: %[[FILTER:.+]]: tensor<3x3x4x16xf32> -// CHECK: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32> -// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x14x14x3x3x4xf32> +// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32> +// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32> +// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32> +// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32> +// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32> +// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32> // CHECK: %[[COL_TENSOR:.+]] = linalg.generic // CHECK-SAME: #[[MAP0]] +// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32) +// CHECK: linalg.yield %{{.+}} : f32 +// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic // CHECK-SAME: #[[MAP1]] -// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32) -// CHECK: linalg.yield %[[IN_DATA]] : f32 -// CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = tensor.collapse_shape %[[COL_TENSOR]] -// CHECK-SAME: [0, 1, 2], [3, 4, 5] -// CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32> -// CHECK-DAG: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] -// CHECK-SAME: [0, 1, 2], [3] -// CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32> -// CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] -// CHECK-SAME: [0, 1, 2], [3] -// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>) -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], 
[3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32> +// CHECK-SAME: #[[MAP2]] +// CHECK-SAME: #[[MAP3]] +// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<36x16xf32>) +// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>) +// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32) +// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32 +// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32 +// CHECK: linalg.yield %[[ADD]] : f32 +// CHECK: } -> tensor<1x196x16xf32> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> // CHECK: return %[[RESULT]] func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { @@ -156,27 +191,24 @@ // ----- -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> // CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> // CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> // CHECK: func.func @batch_nhwc_conv // CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>) -// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x14x14x3x3x4xf32> +// CHECK-DAG: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32> +// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32> +// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x196x36xf32> // CHECK: %[[IMG2COL:.+]] = linalg.generic -// 
CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INPUT]] : tensor<8x16x16x4xf32>) -// CHECK-SAME: outs(%[[IT]] : tensor<8x14x14x3x3x4xf32>) -// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2], [3, 4, 5]] : tensor<8x14x14x3x3x4xf32> into tensor<8x196x36xf32> -// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32> -// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32> +// CHECK-SAME: indexing_maps = [#[[MAP]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: outs(%[[IT]] : tensor<8x196x36xf32>) // CHECK: %[[MATMUL:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] -// CHECK-SAME: ins(%[[CS_INPUT]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>) +// CHECK-SAME: ins(%[[IMG2COL]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>) // CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x196x16xf32>) // CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32): // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32 @@ -201,27 +233,55 @@ // ----- -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4 + d2, d5 + d3)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)> // CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> // CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> // CHECK: func.func @batch_nchw_conv // 
CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>) -// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x4x3x3x14x14xf32> +// CHECK-DAG: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32> +// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32> +// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x36x196xf32> // CHECK: %[[IMG2COL:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>) -// CHECK-SAME: outs(%[[IT]] : tensor<8x4x3x3x14x14xf32>) -// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32> -// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2, 3], [4, 5]] : tensor<8x4x3x3x14x14xf32> into tensor<8x36x196xf32> -// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32> +// CHECK-SAME: indexing_maps = [#[[MAP]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: outs(%[[IT]] : tensor<8x36x196xf32>) +// Collapsed indices. +// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index +// CHECK: %[[KINDEX:.+]] = linalg.index 1 : index +// CHECK: %[[NINDEX:.+]] = linalg.index 2 : index + +// Unrolled filter shape indices. 
+// CHECK: %[[C3:.+]] = arith.constant 3 : index +// CHECK: %[[FWINDEX:.+]] = arith.remui %[[KINDEX]], %[[C3]] : index +// CHECK: %[[C9:.+]] = arith.constant 9 : index +// CHECK: %[[FHREM:.+]] = arith.remui %[[KINDEX]], %[[C9]] : index +// CHECK: %[[C3_1:.+]] = arith.constant 3 : index +// CHECK: %[[FHINDEX:.+]] = arith.divui %[[FHREM]], %[[C3_1]] : index +// CHECK: %[[C9_2:.+]] = arith.constant 9 : index +// CHECK: %[[ICINDEX:.+]] = arith.divui %[[KINDEX]], %[[C9_2]] : index + +// Unrolled output shape indices. +// CHECK: %[[C14:.+]] = arith.constant 14 : index +// CHECK: %[[OWINDEX:.+]] = arith.remui %[[NINDEX]], %[[C14]] : index +// CHECK: %[[C14_3:.+]] = arith.constant 14 : index +// CHECK: %[[OHINDEX:.+]] = arith.divui %[[NINDEX]], %[[C14_3]] : index + +// Compute input indices. +// CHECK: %[[SH:.+]] = arith.constant 1 : index +// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index +// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index +// CHECK: %[[SW:.+]] = arith.constant 1 : index +// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index +// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index +// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract +// CHECK-SAME: %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32> +// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32 // CHECK: %[[MATMUL:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] -// CHECK-SAME: ins(%[[CS_FILTER]], %[[CS_INPUT]] : tensor<16x36xf32>, tensor<8x36x196xf32>) +// CHECK-SAME: ins(%[[CS_FILTER]], %[[IMG2COL]] : tensor<16x36xf32>, tensor<8x36x196xf32>) // CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x16x196xf32>) // CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32): // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32