diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1824,6 +1824,7 @@
   }];
 
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3417,6 +3417,43 @@
   return Speculation::Speculatable;
 }
 
+/// unpack(pack(x)) -> x
+///
+/// Folds a `tensor.unpack` whose source is produced by a `tensor.pack` back to
+/// the value that was packed. The fold is only applied when the unpack exactly
+/// inverts the pack: matching types on both ends, identical `inner_dims_pos`,
+/// `outer_dims_perm` and tiles, and no padding value on the pack. Comparing
+/// only the packed types is not enough, e.g. packing with
+/// `inner_dims_pos = [0, 1]` and unpacking with `inner_dims_pos = [1, 0]` can
+/// produce identical types while rearranging the data.
+LogicalResult UnPackOp::canonicalize(UnPackOp unpackOp,
+                                     PatternRewriter &rewriter) {
+  auto packOp = unpackOp.getSource().getDefiningOp<PackOp>();
+  if (!packOp)
+    return failure();
+
+  // The replacement value must have the unpack's result type, and the packed
+  // type must be the same on both sides.
+  if (packOp.getDestType() != unpackOp.getSourceType() ||
+      packOp.getSourceType() != unpackOp.getDestType())
+    return failure();
+
+  // Stay conservative when the pack carries a padding value.
+  if (packOp.getPaddingValue())
+    return failure();
+
+  // The unpack must undo the same layout transformation the pack applied.
+  // Tiles are compared as OpFoldResults, so dynamic tiles only match when
+  // they are the very same SSA value.
+  if (packOp.getInnerDimsPos() != unpackOp.getInnerDimsPos() ||
+      packOp.getOuterDimsPerm() != unpackOp.getOuterDimsPerm() ||
+      packOp.getMixedTiles() != unpackOp.getMixedTiles())
+    return failure();
+
+  rewriter.replaceOp(unpackOp, packOp.getSource());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1678,3 +1678,29 @@
   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK: func.func @unpack_pack(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x64x58x58xf32>)
+// CHECK: return %[[ARG0]] : tensor<1x64x58x58xf32>
+func.func @unpack_pack(%arg0: tensor<1x64x58x58xf32>) -> tensor<1x64x58x58xf32> {
+  %tensor = tensor.empty() : tensor<1x2x58x58x32xf32>
+  %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %tensor : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
+  %tensor1 = tensor.empty() : tensor<1x64x58x58xf32>
+  %1 = tensor.unpack %0 inner_dims_pos = [1] inner_tiles = [32] into %tensor1 : tensor<1x2x58x58x32xf32> -> tensor<1x64x58x58xf32>
+  return %1 : tensor<1x64x58x58xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_pack_mismatched_inner_dims_pos(
+// CHECK: tensor.pack
+// CHECK: tensor.unpack
+func.func @unpack_pack_mismatched_inner_dims_pos(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %packed_init = tensor.empty() : tensor<2x2x32x32xf32>
+  %packed = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %packed_init : tensor<64x64xf32> -> tensor<2x2x32x32xf32>
+  %unpacked_init = tensor.empty() : tensor<64x64xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %unpacked_init : tensor<2x2x32x32xf32> -> tensor<64x64xf32>
+  return %unpacked : tensor<64x64xf32>
+}