diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -434,14 +434,30 @@
   GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest,
                                              packedOutIndexingMap, packInfo);
 
-  auto unPackOp = unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
+  // If the output element type for the generic differs from the source
+  // unpack op, we need to create a new destination tensor.
+  auto loc = genericOp.getLoc();
+  Value unPackDest = producerUnPackOp.getDest();
+  auto genericOutElementType = getElementTypeOrSelf(genericOp.getResult(0));
+  if (producerUnPackOp.getDestType().getElementType() !=
+      genericOutElementType) {
+    SmallVector<OpFoldResult> unPackMixedSizes;
+    if (auto unPackEmpty = unPackDest.getDefiningOp<tensor::EmptyOp>())
+      unPackMixedSizes = unPackEmpty.getMixedSizes();
+    else
+      unPackMixedSizes = tensor::getMixedSizes(rewriter, loc, unPackDest);
+
+    unPackDest = rewriter.create<tensor::EmptyOp>(loc, unPackMixedSizes,
+                                                  genericOutElementType);
+  }
+
   // Insert an unPackOp right after the packed generic.
   Value unPackOpRes =
       rewriter
           .create<tensor::UnPackOp>(
-              genericOp.getLoc(),
+              loc,
               newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
-              unPackOp.getDest(), producerUnPackOp.getInnerDimsPos(),
+              unPackDest, producerUnPackOp.getInnerDimsPos(),
               producerUnPackOp.getMixedTiles(),
               producerUnPackOp.getOuterDimsPerm())
           .getResult();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -441,6 +441,46 @@
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
+func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
+  ^bb0(%in: f32, %out: f16):
+    %3 = arith.truncf %in : f32 to f16
+    linalg.yield %3 : f16
+  } -> tensor<12x56x56x64xf16>
+  return %2 : tensor<12x56x56x64xf16>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_element_type_change
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:       %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK:       %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
+// CHECK:       %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
+// CHECK:       %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG1_PACK_EMPTY]]
+// CHECK:       %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK:       %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG0_PACK_EMPTY]]
+// CHECK:       %[[RES:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:  ins(%[[ARG0_PACK]]
+// CHECK-SAME:  outs(%[[ARG1_PACK]]
+// CHECK:       %[[ARG0_NEW_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf16>
+// CHECK:       %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG0_NEW_EMPTY_UNPACK]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
 func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
   %init = tensor.empty() : tensor<12x56x56x64xf32>
   %0 = tensor.empty() : tensor<12x56x56x64xf32>
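
Note: a condensed before/after sketch of the rewrite this patch fixes, derived from the test above (SSA names here are for exposition only and are not part of the patch). Before propagation, the unpack result feeds a truncating element-wise generic whose init has a narrower element type:

  %u = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %empty_f32 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %r = linalg.generic ... ins(%u : tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) { arith.truncf } -> tensor<12x56x56x64xf16>

After the unpack is pushed below the packed generic, the trailing tensor.unpack must write into an f16 destination, but the producer unpack's destination is f32; reusing it (the old behavior) would build a type-mismatched tensor.unpack. The pattern therefore materializes a fresh tensor.empty with the generic's result element type, taking sizes from the old destination (from its tensor.empty's mixed sizes when available, otherwise via tensor::getMixedSizes):

  %res   = linalg.generic ... ins(%packed_in : tensor<12x2x56x56x32xf32>) outs(%packed_init : tensor<12x2x56x56x32xf16>) -> tensor<12x2x56x56x32xf16>
  %dest  = tensor.empty() : tensor<12x56x56x64xf16>
  %final = tensor.unpack %res outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<12x2x56x56x32xf16> -> tensor<12x56x56x64xf16>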