diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -41,37 +42,17 @@ return builder.create(loc, x, y); } -// Unrolls the given composite `index` into a set of subindices with maximum -// iteration ranges specified by `factors` according to the following -// assumptions: -// 1. The iteration range for `index` is [0, f1 * f2 * ... * fn] i.e. the -// product of the given list of factors -// 2. The iterators corresponding to the entries in `factors` are ordered from -// slowest to fastest varying -// Each subindex is then computed as: -// subindex[i] = floor( (index % (fi * ... * fn)) / (fi-1 * ... * fn) ) -static SmallVector unrollIndex(OpBuilder &b, Location loc, - Value index, - ArrayRef factors) { +// Delinearizes the given composite `index` by the basis specified in `factors`.
+static SmallVector unrollIndex(OpBuilder &b, Location loc, Value index, + ArrayRef factors) { assert(factors.size() >= 1 && "empty factor list"); - SmallVector indices(factors.size()); - int64_t runningProd = 1; - for (int i = factors.size() - 1, end = 0; i >= end; i--) { - Value unrolledIndex = index; - if (i > 0) { - Value modBase = b.create( - loc, b.getIndexAttr(runningProd * factors[i])); - unrolledIndex = b.create(loc, unrolledIndex, modBase); - } - if (runningProd > 1) { - Value divDenom = - b.create(loc, b.getIndexAttr(runningProd)); - unrolledIndex = b.create(loc, unrolledIndex, divDenom); - } - runningProd *= factors[i]; - indices[i] = unrolledIndex; - } - return indices; + SmallVector basis; + for (int64_t f : factors) + basis.push_back(b.create(loc, b.getIndexAttr(f))); + FailureOr> multiIndex = + delinearizeIndex(b, loc, index, basis); + assert(succeeded(multiIndex) && "Failed to delinearize img2col index"); + return *multiIndex; } // Given indices corresponding to iterators in the output (oIndex) and filter @@ -79,9 +60,10 @@ // input as `oIndex * stride + fIndex`. static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, Value fIndex, int64_t stride) { - Value strideVal = b.create(loc, b.getIndexAttr(stride)); - Value convIndex = b.create(loc, oIndex, strideVal); - return b.create(loc, convIndex, fIndex); + AffineExpr oExpr, fExpr; + bindSymbols(b.getContext(), oExpr, fExpr); + AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr); + return makeComposedAffineApply(b, loc, convMap, ValueRange{oIndex, fIndex}); } FailureOr> @@ -159,12 +141,12 @@ Value kIndex = nestedBuilder.create(loc, 2); // Recover the original iteration indices from the problem/input sizes.
- SmallVector mIndices = unrollIndex( + SmallVector mIndices = unrollIndex( nestedBuilder, nestedLoc, mIndex, ArrayRef{oh, ow}); auto ohIndex = mIndices[0]; auto owIndex = mIndices[1]; - SmallVector kIndices = unrollIndex( + SmallVector kIndices = unrollIndex( nestedBuilder, nestedLoc, kIndex, ArrayRef{fh, fw, ic}); auto fhIndex = kIndices[0]; auto fwIndex = kIndices[1]; @@ -443,13 +425,13 @@ Value nIndex = nestedBuilder.create(loc, 2); // Recover the original iteration indices from the problem/input sizes. - SmallVector kIndices = unrollIndex( + SmallVector kIndices = unrollIndex( nestedBuilder, nestedLoc, kIndex, ArrayRef{ic, fh, fw}); auto icIndex = kIndices[0]; auto fhIndex = kIndices[1]; auto fwIndex = kIndices[2]; - SmallVector nIndices = unrollIndex( + SmallVector nIndices = unrollIndex( nestedBuilder, nestedLoc, nIndex, ArrayRef{oh, ow}); auto ohIndex = nIndices[0]; auto owIndex = nIndices[1]; diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir --- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir +++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir @@ -37,29 +37,12 @@ // CHECK: %[[MINDEX:.+]] = linalg.index 1 : index // CHECK: %[[KINDEX:.+]] = linalg.index 2 : index -// Unrolled output shape indices. -// CHECK: %[[C14:.+]] = arith.constant 14 : index -// CHECK: %[[OWINDEX:.+]] = arith.remui %[[MINDEX]], %[[C14]] : index -// CHECK: %[[C14_1:.+]] = arith.constant 14 : index -// CHECK: %[[OHINDEX:.+]] = arith.divui %[[MINDEX]], %[[C14_1]] : index +// Input channel and convolved spatial indices, computed with affine.apply. +// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]]) +// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]]) +// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]]) -// Unrolled filter shape indices.
-// CHECK: %[[C4:.+]] = arith.constant 4 : index -// CHECK: %[[ICINDEX:.+]] = arith.remui %[[KINDEX]], %[[C4]] : index -// CHECK: %[[C12:.+]] = arith.constant 12 : index -// CHECK: %[[FWREM:.+]] = arith.remui %[[KINDEX]], %[[C12]] : index -// CHECK: %[[C4_2:.+]] = arith.constant 4 : index -// CHECK: %[[FWINDEX:.+]] = arith.divui %[[FWREM]], %[[C4_2]] : index -// CHECK: %[[C12_3:.+]] = arith.constant 12 : index -// CHECK: %[[FHINDEX:.+]] = arith.divui %[[KINDEX]], %[[C12_3]] : index - -// Compute input indices. -// CHECK: %[[SH:.+]] = arith.constant 1 : index -// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index -// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index -// CHECK: %[[SW:.+]] = arith.constant 1 : index -// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index -// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index +// Extract from the input tensor. // CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract // CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32> // CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32 @@ -234,6 +217,13 @@ // ----- // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + +// Affine maps used by the im2col input index computation. +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 9)> +// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1) -> (d0 floordiv 14 + (d1 mod 9) floordiv 3)> +// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 - (d0 floordiv 14) * 14 - (d1 floordiv 3) * 3)> + + // CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)> // CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> // CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> @@ -252,29 +242,12 @@ // CHECK: %[[KINDEX:.+]] = linalg.index 1 : index // CHECK: %[[NINDEX:.+]] = linalg.index 2 : index -// Unrolled filter shape indices.
-// CHECK: %[[C3:.+]] = arith.constant 3 : index -// CHECK: %[[FWINDEX:.+]] = arith.remui %[[KINDEX]], %[[C3]] : index -// CHECK: %[[C9:.+]] = arith.constant 9 : index -// CHECK: %[[FHREM:.+]] = arith.remui %[[KINDEX]], %[[C9]] : index -// CHECK: %[[C3_1:.+]] = arith.constant 3 : index -// CHECK: %[[FHINDEX:.+]] = arith.divui %[[FHREM]], %[[C3_1]] : index -// CHECK: %[[C9_2:.+]] = arith.constant 9 : index -// CHECK: %[[ICINDEX:.+]] = arith.divui %[[KINDEX]], %[[C9_2]] : index - -// Unrolled output shape indices. -// CHECK: %[[C14:.+]] = arith.constant 14 : index -// CHECK: %[[OWINDEX:.+]] = arith.remui %[[NINDEX]], %[[C14]] : index -// CHECK: %[[C14_3:.+]] = arith.constant 14 : index -// CHECK: %[[OHINDEX:.+]] = arith.divui %[[NINDEX]], %[[C14_3]] : index +// Input channel and convolved spatial indices, computed with affine.apply. +// CHECK: %[[ICINDEX:.+]] = affine.apply #[[MAP1]](%[[KINDEX]]) +// CHECK: %[[CONVH:.+]] = affine.apply #[[MAP7]](%[[NINDEX]], %[[KINDEX]]) +// CHECK: %[[CONVW:.+]] = affine.apply #[[MAP8]](%[[NINDEX]], %[[KINDEX]]) -// Compute input indices. -// CHECK: %[[SH:.+]] = arith.constant 1 : index -// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index -// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index -// CHECK: %[[SW:.+]] = arith.constant 1 : index -// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index -// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index +// Extract from the input tensor. // CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract // CHECK-SAME: %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32> // CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32