diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -229,6 +229,8 @@ AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand); llvm::DenseMap domainDimToOperandDim; SmallVector exprs(origIndexingMap.getResults()); + + // If the OpOperand is a scalar or a zero-rank tensor, no need to pack. if (genericOp.isScalar(opOperand) || exprs.empty()) return std::make_tuple(opOperand->get(), AffineMap::get(numLoops, 0, exprs, b.getContext())); @@ -293,10 +295,10 @@ return std::make_tuple(packedOperand, indexingMap); } -/// Pack an element-wise genericOp and return it. -static GenericOp packElementWiseOp(RewriterBase &rewriter, GenericOp genericOp, - Value dest, AffineMap packedOutIndexingMap, - const PackInfo &packInfo) { +/// Pack a linalg genericOp and return it. +static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, + Value dest, AffineMap packedOutIndexingMap, + const PackInfo &packInfo) { Location loc = genericOp.getLoc(); SmallVector inputOperands; SmallVector indexingMaps; @@ -442,8 +444,8 @@ .getDefiningOp()) { dest = packOpDest; } - return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap, - *packInfo); + return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, + *packInfo); } /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. @@ -468,7 +470,7 @@ ControlPropagationFn controlFn; }; -// TODO: Relax this restriction. We should unpack an elementwise also +// TODO: Relax this restriction. We should unpack a generic op also // in the presence of multiple unpack ops as producers. /// Return the unpacked operand, if present, for the current generic op. 
static FailureOr getUnPackedOperand(GenericOp genericOp) { @@ -486,7 +488,7 @@ return unPackedOperand; } -/// Push down a tensor.unpack op through elementwise generic op. +/// Push down a tensor.unpack op through a generic op. /// The new generic op works on packed domain; pack ops are created for input /// and output operands. A tensor.unpack op is inserted right after the packed /// generic. E.g. @@ -560,8 +562,8 @@ } // Pack the genericOp. - GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest, - packedOutIndexingMap, *packInfo); + GenericOp newGenericOp = + packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo); Value newResult = newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -706,15 +706,16 @@ #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1)> -func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{ - %init = tensor.empty() : tensor<128x256xi32> +func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>, + %arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{ %elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<128x256x32xi32>) - outs(%init : tensor<128x256xi32>) { + outs(%arg1 : tensor<128x256xi32>) { ^bb0(%arg3: i32, %arg4: i32): %4 = arith.addi %arg3, %arg4 : i32 linalg.yield %4 : i32 } -> tensor<128x256xi32> + %dest = tensor.empty() : tensor<4x16x16x32xi32> %pack = tensor.pack %elem inner_dims_pos = [1, 0] inner_tiles = [16, 32] @@ -725,7 +726,11 @@ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> // CHECK: 
func.func @reduction_pack_transpose_inner_dims // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32> +// CHECK: %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]] +// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] +// CHECK-SAME: into %[[ARG1_EMPTY]] // CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16x32xi32> // CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]] // CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] @@ -734,14 +739,14 @@ // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel"] // CHECK-SAME: ins(%[[PACK_ARG0]] -// CHECK-SAME: outs(%[[DEST]] +// CHECK-SAME: outs(%[[PACK_ARG1]] // CHECK: return %[[RED]] : tensor<4x16x16x32xi32> // ----- -func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>) -> tensor<4x16x100x16x32xi32> +func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, + %arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32> { - %init_reduction = tensor.empty() : tensor<100x128x256xi32> %reduction = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0)>, @@ -773,7 +778,11 @@ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK: %[[INIT_EMPTY:.+]] = tensor.empty() : tensor<4x16x100x16x32xi32> +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] +// CHECK: %[[ARG3_EMPTY:.+]] = tensor.empty() : tensor<4x16x100x16x32xi32> +// CHECK: %[[PACKED_ARG3:.+]] = tensor.pack %[[ARG3]] +// CHECK-SAME: outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 32] +// CHECK-SAME: into %[[ARG3_EMPTY]] // CHECK: %[[ARG0_EMPTY:.+]]
= tensor.empty() : tensor<4x16x200x100x16x32xi32> // CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]] // CHECK-SAME: outer_dims_perm = [1, 3, 2, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32] @@ -785,16 +794,16 @@ // CHECK: %[[RES:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] // CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]] -// CHECK-SAME: outs(%[[INIT_EMPTY]] +// CHECK-SAME: outs(%[[PACKED_ARG3]] // ----- #map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)> -func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>) -> tensor<16x540x960xi32>{ +func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>, + %filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{ %init = tensor.empty() : tensor<16x540x960xi32> - %filter = tensor.empty() : tensor<2x2xi32> %empty = tensor.empty() : tensor<1x16x1080x1920xi32> %unpack = tensor.unpack %arg0 inner_dims_pos = [1] @@ -804,7 +813,7 @@ ins(%unpack, %filter : tensor<1x16x1080x1920xi32>, tensor<2x2xi32>) outs(%init : tensor<16x540x960xi32>) { ^bb0(%in: i32, %in_1: i32, %out: i32): - %max = arith.maxui %in, %out : i32 + %max = arith.maxui %in, %in_1 : i32 linalg.yield %max : i32 } -> tensor<16x540x960xi32> return %pool : tensor<16x540x960xi32> @@ -814,7 +823,7 @@ // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d3, d6)> // CHECK: func.func @unpack_different_destination_shape // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[FILTER:.+]] = tensor.empty() : tensor<2x2xi32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32> // CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32> // CHECK: %[[PACK_ARG0:.+]] = tensor.pack @@ -823,7 +832,7 @@ // CHECK: %[[POOL:.+]] = 
linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] -// CHECK-SAME: ins(%[[PACK_ARG0]], %[[FILTER]] +// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]] // CHECK-SAME: outs(%[[INIT]] // CHECK: %[[UNPACK_NEW_DEST:.+]] = tensor.empty() : tensor<16x540x960xi32> // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[POOL]]