diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -87,11 +87,39 @@
   return packInfo;
 }
 
+static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
+                                             ArrayRef<AffineExpr> exprs) {
+  // Compute `outer_dims_perm`. See example:
+  // current exprs      : (d0, d1, d2, d3) -> (d2, d3)
+  // perm               : [0, 3, 1, 2]
+  // First map d2, d3 with their position in the array as:
+  // currentPositionTileLoops: dim | pos
+  //                           d2  | 0
+  //                           d3  | 1
+  // then scan `perm` in order and get the `outer_dims_perm`
+  // to be used, here it would be [1, 0].
+  assert(!perm.empty() && "expect perm not to be empty");
+  assert(!exprs.empty() && "expect exprs not to be empty");
+  if (exprs.size() == 1)
+    return {};
+  SmallVector<int64_t> outerDimsPerm;
+  DenseMap<int64_t, int64_t> currentPositionTileLoops;
+  for (auto [pos, expr] : llvm::enumerate(exprs)) {
+    unsigned posInDomain = expr.cast<AffineDimExpr>().getPosition();
+    currentPositionTileLoops[posInDomain] = pos;
+  }
+  for (int64_t loopIdx : perm) {
+    if (currentPositionTileLoops.count(loopIdx))
+      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
+  }
+  return outerDimsPerm;
+}
+
 /// Returns a tuple for packed operand and indexing_map with the assumptions:
 ///   1) The generic op is the producer of the pack op.
 ///   2) The generic op has only one result.
 /// If the operand is a scalar or packing dimensions are all irrelevant to the
-/// operand, the opreand and the updated indexing map will be returned.
+/// operand, the operand and the updated indexing map will be returned.
 /// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
 ///
 ///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
@@ -148,16 +176,26 @@
     exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
   }
 
-  // Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op.
-  // TODO: should we propagate the permutation of outer dims to the pack op?
+  // Step 2. Handle outer dim permutations.
   SmallVector<int64_t> outerDimsPerm;
   if (!packInfo.outerDimsOnDomainPerm.empty()) {
+    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
+
+    // Step 2.1: Fold transpose into the linalg.generic.
     SmallVector<int64_t> inversedOuterPerm =
         invertPermutationVector(packInfo.outerDimsOnDomainPerm);
     for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
       int64_t dimPos = exprs[i].cast<AffineDimExpr>().getPosition();
       exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
     }
+    // Step 2.2: Undo the transposition on `exprs` and propagate the
+    // transposition on the pack using outerDimsPerm.
+    if (!outerDimsPerm.empty()) {
+      SmallVector<AffineExpr> auxVec = exprs;
+      for (const auto &en : enumerate(outerDimsPerm))
+        auxVec[en.index()] = exprs[en.value()];
+      exprs = auxVec;
+    }
   }
 
   auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
@@ -254,9 +292,7 @@
     indexingMaps.push_back(packedIndexingMap);
   }
 
-  int64_t numLoops = genericOp.getNumLoops();
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
-  int64_t newNumLoops = numLoops + numInnerLoops;
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
   iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
@@ -265,24 +301,18 @@
   auto [packedOutOperand, packedOutIndexingMap] =
       getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
                                      opOperand);
-  SmallVector<AffineExpr> outExprs(
-      packedOutIndexingMap.getResults().drop_back(numInnerLoops));
-  // Apply transpose to the indexing map, because we'll replace the init operand
-  // with the destination of pack op.
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(outExprs, outerDimsPerm);
-  }
-  for (int i = 0; i < numInnerLoops; ++i)
-    outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
-  AffineMap outMap =
-      AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext());
-  indexingMaps.push_back(outMap);
+  indexingMaps.push_back(packedOutIndexingMap);
+  // We'll replace the init operand with the destination of pack op if the init
+  // operand has no users in the body of the linalg.generic (pure elementwise).
+  // If it has users we need to pack the init operand too and replace the init
+  // with the packing result.
+  Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
+                   ? packOp.getDest()
+                   : packedOutOperand;
 
   auto newGenericOp = rewriter.create<GenericOp>(
-      loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,
-      iterTypes, /*bodyBuild=*/nullptr,
-      linalg::getPrunedAttributeList(genericOp));
+      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
   rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                              newGenericOp.getRegion().begin());
   return newGenericOp;
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -96,17 +96,16 @@
     into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
   return %pack : tensor<16x4x32x16xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK:      func.func @elem_pack_transpose_outer_dims
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16xi32>
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
 // CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:     into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<4x16x32x16xi32>
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:     into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
 // CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[PACK_ARG0]]
 // CHECK-SAME:     outs(%[[DEST]]
@@ -131,17 +130,16 @@
     into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
   return %pack : tensor<16x4x16x32xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK:      func.func @elem_pack_transpose_inner_and_outer_dims
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
 // CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME:     into %[[ARG0_EMPTY]]
 // CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[PACK_ARG0]]
 // CHECK-SAME:     outs(%[[DEST]]
@@ -285,7 +283,7 @@
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1)>
-func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
+func.func @transpose_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
 {
   %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
   %transpose = linalg.generic {
@@ -308,3 +306,61 @@
     into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
   return %4 : tensor<200x4x16x100x16x32xi32>
 }
+
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
+// CHECK:      func.func @transpose_pack_with_outer_dims
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<200x4x16x100x16x32xi32>
+// CHECK:        %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:     outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME:     into %[[ARG0_EMPTY]]
+// CHECK:        %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
+// CHECK:        %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:     into %[[ARG2_EMPTY]]
+// CHECK:        %[[RES:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP]]]
+// CHECK-SAME:     ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME:     outs(%[[DEST]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg4 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %empty = tensor.empty() : tensor<16x4x32x16xi32>
+  %pack = tensor.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK:      #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK:      func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:        %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:     into %[[ARG0_EMPTY]]
+// CHECK:        %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:        %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:     into %[[ARG1_EMPTY]]
+// CHECK:        %[[RES:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:     ins(%[[PACKED_ARG0]]
+// CHECK-SAME:     outs(%[[PACKED_ARG1]]