diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1790,6 +1790,8 @@
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3376,6 +3376,92 @@
   return Speculation::Speculatable;
 }
 
+namespace {
+
+/// Fold pack(unpack(x)) to x.
+///
+/// The fold is only valid when the pack exactly undoes the unpack: both ops
+/// must use the same inner tile sizes, inner dimension positions and outer
+/// dimension permutation, and the pack must not pad. With a padding value the
+/// pack would overwrite elements that the unpack dropped from partial tiles,
+/// so the result would differ from x.
+struct PackOfUnPack : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  /// Returns true if both ops use the same inner tile sizes. Static sizes are
+  /// compared by value, dynamic sizes by SSA value.
+  static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
+    SmallVector<OpFoldResult> packTiles = packOp.getMixedTiles();
+    SmallVector<OpFoldResult> unPackTiles = unPackOp.getMixedTiles();
+    if (packTiles.size() != unPackTiles.size())
+      return false;
+    for (auto [packTile, unPackTile] : llvm::zip(packTiles, unPackTiles))
+      if (!isEqualConstantIntOrValue(packTile, unPackTile))
+        return false;
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
+    if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
+      return failure();
+    // A padding value would overwrite elements the unpack dropped.
+    if (packOp.getPaddingValue())
+      return failure();
+    if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
+      return failure();
+    if (packOp.getOuterDimsPerm() != unPackOp.getOuterDimsPerm())
+      return failure();
+    // Equal packed types do not imply equal tile sizes when tiles are dynamic.
+    if (!haveSameTiles(packOp, unPackOp))
+      return failure();
+    rewriter.replaceOp(packOp, unPackOp.getSource());
+    return success();
+  }
+};
+
+/// Packing a one-dimensional tensor without padding can be expressed as an
+/// expand_shape op.
+struct PackToExpandShape : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  /// Reshapes `operand` to `newOperandType` with an expand_shape op, or
+  /// returns `operand` unchanged when the types already match.
+  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+                     Type newOperandType, ArrayAttr reassociation) const {
+    Type operandType = operand.getType();
+    if (operandType == newOperandType)
+      return operand;
+    return rewriter.create<ExpandShapeOp>(loc, newOperandType, operand,
+                                          reassociation);
+  }
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    ShapedType sourceType = packOp.getSourceType();
+    ShapedType destType = packOp.getDestType();
+    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
+      return failure();
+    auto reassociation =
+        getReassociationIndicesForReshape(sourceType, destType);
+    if (!reassociation)
+      return failure();
+    Value expanded = insertExpand(
+        rewriter, packOp.getLoc(), packOp.getSource(), destType,
+        getReassociationIndicesAttribute(rewriter, *reassociation));
+    rewriter.replaceOp(packOp, expanded);
+    return success();
+  }
+};
+
+} // end namespace
+
+void PackOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *ctx) {
+  results.add<PackOfUnPack, PackToExpandShape>(ctx);
+}
+
 //===----------------------------------------------------------------------===//
 // UnPackOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1704,3 +1704,75 @@
 <128x128xf32>
   return %unpacked : tensor<128x128xf32>
 }
+
+// -----
+
+// CHECK: func.func @single_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
+func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
+  %alloc = tensor.empty() : tensor<8x32xf32>
+  %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %alloc : tensor<256xf32> -> tensor<8x32xf32>
+  return %0 : tensor<8x32xf32>
+}
+
+// -----
+
+// Chain NCnc -> NC -> NC -> NCnc
+// CHECK: func.func @pack_unpack(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// Chain NCnc -> NC -> NC -> NCnc
+// CHECK: func.func @pack_unpack(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
+// CHECK: return %[[T]] : tensor<16x16x8x8xf32>
+func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  return %packed : tensor<16x16x8x8xf32>
+}
+
+// -----
+
+// The inner tile sizes differ, so the pack does not undo the unpack even
+// though the packed types are identical.
+// CHECK-LABEL: func.func @pack_unpack_different_dynamic_tiles(
+// CHECK:         %[[UNPACKED:.+]] = tensor.unpack
+// CHECK:         %[[PACKED:.+]] = tensor.pack %[[UNPACKED]]
+// CHECK:         return %[[PACKED]]
+func.func @pack_unpack_different_dynamic_tiles(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty(%tile2, %tile1) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile2, %tile1] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// The unpack drops the trailing elements of the last tiles and the pack
+// refills them with %pad, so the result is not %t.
+// CHECK-LABEL: func.func @pack_unpack_with_padding(
+// CHECK:         %[[UNPACKED:.+]] = tensor.unpack
+// CHECK:         %[[PACKED:.+]] = tensor.pack %[[UNPACKED]]
+// CHECK:         return %[[PACKED]]
+func.func @pack_unpack_with_padding(%t: tensor<16x16x8x8xf32>, %pad: f32) -> tensor<16x16x8x8xf32> {
+  %tensor_empty = tensor.empty() : tensor<127x127xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<127x127xf32>
+  %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %unpacked padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<127x127xf32> -> tensor<16x16x8x8xf32>
+  return %packed : tensor<16x16x8x8xf32>
+}