diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1824,6 +1824,7 @@
   }];
 
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3419,6 +3419,20 @@
   return Speculation::Speculatable;
 }
 
+/// unpack(pack(x)) -> x
+LogicalResult UnPackOp::canonicalize(UnPackOp unpackOp,
+                                     PatternRewriter &rewriter) {
+  PackOp packOp = unpackOp.getSource().getDefiningOp<PackOp>();
+  if (!packOp || packOp.getDestType() != unpackOp.getSourceType())
+    return failure();
+  if (packOp.getInnerDimsPos() != unpackOp.getInnerDimsPos())
+    return failure();
+  if (packOp.getOuterDimsPerm() != unpackOp.getOuterDimsPerm())
+    return failure();
+  rewriter.replaceOp(unpackOp, packOp.getSource());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1678,3 +1678,62 @@
   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
+
+// -----
+
+// Chain: NC -> NCnc -> NCnc -> NC
+// CHECK: func.func @unpack_pack(
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
+// CHECK: return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
+
+// -----
+
+// Chain: NC -> NCcn -> NCnc -> NC
+// CHECK: func.func @unpack_pack(
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
+// CHECK-NOT: return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %t inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
+
+// -----
+
+// Chain: NC -> CNcn -> NCnc -> NC
+// CHECK: func.func @unpack_pack(
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
+// CHECK-NOT: return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %t outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
+
+// -----
+
+// Chain: NC -> NCnc -> NCnc -> NC
+// CHECK: func.func @unpack_pack(
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
+// CHECK: return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}