diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -142,10 +142,10 @@
 LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
                                       SmallVector<Value> &result);
 
-/// Function to control the folding of constant and extract slice
+/// Function to control the folding of constant and extract slice.
 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
 
-/// Patterns to fold the extract slice op with its constant operand
+/// Patterns to fold the extract slice op with its constant operand.
 void populateFoldConstantExtractSlicePatterns(
     RewritePatternSet &patterns,
     const ControlConstantExtractSliceFusionFn &controlFn =
@@ -155,6 +155,9 @@
           return false;
         });
 
+/// Patterns to simplify tensor.pack.
+void populateSimplifyTensorPack(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1790,6 +1790,8 @@
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3017,6 +3017,44 @@
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+/// Packing a one-dimensional tensor can be expressed as an expand shape op.
+struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+                     Type newOperandType, ArrayAttr reassociation) const {
+    if (operand.getType() == newOperandType)
+      return operand;
+    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+                                                  reassociation);
+  }
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType sourceType = packOp.getSourceType();
+    RankedTensorType destType = packOp.getDestType();
+    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
+      return failure();
+    auto reassociation =
+        getReassociationIndicesForReshape(sourceType, destType);
+    if (!reassociation)
+      return failure();
+    Value expanded = insertExpand(
+        rewriter, packOp.getLoc(), packOp.getSource(), destType,
+        getReassociationIndicesAttribute(rewriter, *reassociation));
+    rewriter.replaceOp(packOp, expanded);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateSimplifyTensorPack(RewritePatternSet &patterns) {
+  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+}
+
 template <typename OpTy>
 static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder,
@@ -3376,6 +3414,41 @@
   return Speculation::Speculatable;
 }
 
+/// Returns true if `inner_dims_pos` and `outer_dims_perm` target the same
+/// dimensions for pack and unpack.
+static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
+  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
+    return false;
+  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
+}
+
+/// Returns true if pack and unpack have the same tiles, i.e., the same SSA
+/// values or the same integer constants.
+static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
+  auto packTiles = packOp.getMixedTiles();
+  auto unPackTiles = unPackOp.getMixedTiles();
+  if (packTiles.size() != unPackTiles.size())
+    return false;
+  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
+    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
+      return false;
+  }
+  return true;
+}
+
+/// Fold pack(unpack(x)) to x.
+LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
+  UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
+  if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
+    return failure();
+  if (packOp.getPaddingValue() ||
+      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+      !haveSameTiles(packOp, unPackOp))
+    return failure();
+  rewriter.replaceOp(packOp, unPackOp.getSource());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // UnPackOp
 //===----------------------------------------------------------------------===//
@@ -3433,16 +3506,16 @@
 }
 
 /// pack(unpack(x)) -> x
-LogicalResult UnPackOp::canonicalize(UnPackOp unpackOp,
+LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
-  PackOp packOp = unpackOp.getSource().getDefiningOp<PackOp>();
-  if (!packOp || packOp.getDestType() != unpackOp.getSourceType())
-    return failure();
-  if (packOp.getInnerDimsPos() != unpackOp.getInnerDimsPos())
+  PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>();
+  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
     return failure();
-  if (packOp.getOuterDimsPerm() != unpackOp.getOuterDimsPerm())
+  if (packOp.getPaddingValue() ||
+      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+      !haveSameTiles(packOp, unPackOp))
     return failure();
-  rewriter.replaceOp(unpackOp, packOp.getSource());
+  rewriter.replaceOp(unPackOp, packOp.getSource());
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1704,3 +1704,73 @@
 <128x128xf32>
   return %unpacked : tensor<128x128xf32>
 }
+
+// -----
+
+// Chain: NCnc -> NC -> NCnc
+// CHECK: func.func @pack_unpack(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// Chain: NCnc -> NC -> NCnc
+// CHECK: func.func @pack_unpack(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
+// CHECK: return %[[T]] : tensor<16x16x8x8xf32>
+func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  return %packed : tensor<16x16x8x8xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_unpack_same_tiles(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK: return %[[T]] : tensor<?x?x?x?xf32>
+func.func @pack_unpack_same_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
+                                  %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %packed : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_unpack_different_tiles(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
+func.func @pack_unpack_different_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
+                                       %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile2, %tile1] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %packed : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_unpack_dynamic_with_padding(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
+func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
+                                            %tile1: index, %tile2: index, %pad: f32) -> tensor<?x?x?x?xf32> {
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
+  %packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %packed : tensor<?x?x?x?xf32>
+}
diff --git a/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir b/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-patterns" %s | FileCheck %s
+
+// CHECK: func.func @single_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
+func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
+  %empty = tensor.empty() : tensor<8x32xf32>
+  %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
+  return %0 : tensor<8x32xf32>
+}
+
+// -----
+
+// CHECK: func.func @single_dim_packing_with_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
+func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
+  %empty = tensor.empty() : tensor<8x32xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
+  return %0 : tensor<8x32xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -85,6 +85,11 @@
         "Use the scf.foreach_thread operation when generating loop nests for "
         "the extract_slice of collapse_shape pattern"),
       llvm::cl::init(false)};
+
+  Option<bool> testSimplifyPackPatterns{
+      *this, "test-simplify-pack-patterns",
+      llvm::cl::desc("Test patterns to simplify tensor.pack"),
+      llvm::cl::init(false)};
 };
 
 } // namespace
@@ -134,6 +139,12 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applySimplifyPackPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateSimplifyTensorPack(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -277,6 +288,8 @@
 void TestTensorTransforms::runOnOperation() {
   Operation *rootOp = getOperation();
+  if (testSimplifyPackPatterns)
+    applySimplifyPackPatterns(rootOp);
   if (testSplitPaddingPatterns)
     applySplitPaddingPatterns(rootOp);
   if (testFoldConstantExtractSlice)
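
Note: the test pass above is the in-tree driver for these rewrites. For downstream users, here is a minimal sketch of how a custom pass might wire up the new entry point, assuming an MLIR checkout of the same vintage as this patch; `runPackSimplifications` is an illustrative helper name, not part of the patch.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

/// Illustrative sketch only: gather the pack simplification added by this
/// patch together with the pack/unpack canonicalizations and apply them
/// greedily on `rootOp`.
static void runPackSimplifications(Operation *rootOp) {
  MLIRContext *ctx = rootOp->getContext();
  RewritePatternSet patterns(ctx);
  // New in this patch: rank-1 tensor.pack without padding becomes
  // tensor.expand_shape.
  tensor::populateSimplifyTensorPack(patterns);
  // pack(unpack(x)) -> x and unpack(pack(x)) -> x run as canonicalizations.
  tensor::PackOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::UnPackOp::getCanonicalizationPatterns(patterns, ctx);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

Design note: because the folds are exposed through `hasCanonicalizeMethod = 1`, any `-canonicalize` run picks them up for free; only the rank-1 pack-to-expand_shape rewrite requires explicit opt-in via `populateSimplifyTensorPack`.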