diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -31,11 +31,14 @@
     "apply_patterns.tensor.fold_tensor_empty",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that reassociative reshapes (tensor.collapse_shape /
-    tensor.expand_shape) should be folded with inverse rank expansions / rank
-    reductions (via tensor.insert_slice / tensor.extract_slice).
+    Indicates that tensor.extract_slice and reassociative reshapes should be
+    folded into tensor.empty.
+
+    If `fold_single_use_only` is set to "true", only tensor.empty ops that
+    have a single use are folded.
   }];
+  let arguments = (ins DefaultValuedAttr<BoolAttr, "false">:$fold_single_use_only);
   let assemblyFormat = "attr-dict";
 }
 
 def ApplyFoldIntoPackAndUnpackPatternsOp : Op<Transform_Dialect,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -19,16 +19,29 @@
 template <typename ReshapeOp>
 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
-  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+  FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
+                               bool foldSingleUseOnly = false)
+      : OpRewritePattern<ReshapeOp>(ctx, benefit),
+        foldSingleUseOnly(foldSingleUseOnly) {}
 
   LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
+    // Check for tensor.empty source.
+    auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
+    if (!emptyOp)
       return failure();
+
+    // Check for single use.
+    if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
+      return failure();
+
+    // Reify result shape.
     Location loc = reshapeOp.getLoc();
     ReifiedRankedShapedTypeDims resultShapes;
     if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
         !llvm::hasSingleElement(resultShapes))
       return failure();
+
+    // Create new tensor.empty op.
     // TODO: Do not drop tensor type encoding.
     Value emptyTensor = rewriter.create<EmptyOp>(
         loc, resultShapes[0], reshapeOp.getResultType().getElementType());
@@ -40,21 +53,34 @@
     }
     return success();
   }
+
+private:
+  bool foldSingleUseOnly = false;
 };
 
-/// `tensor.empty` does not define any tensor contents, so a slice of a
-/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
+/// tensor.empty does not define any tensor contents, so a slice of a
+/// tensor.empty can be folded to a smaller tensor.empty.
 struct FoldEmptyTensorWithExtractSliceOp
     : public OpRewritePattern<ExtractSliceOp> {
-  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+  FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
+                                    PatternBenefit benefit = 1,
+                                    bool foldSingleUseOnly = false)
+      : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
+        foldSingleUseOnly(foldSingleUseOnly) {}
 
   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
-    if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
+    // Check for tensor.empty source.
+    auto emptyOp = sliceOp.getSource().getDefiningOp<EmptyOp>();
+    if (!emptyOp)
+      return failure();
+
+    // Check for single use.
+    if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
       return failure();
 
-    // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
-    // preserved as well as its result type.
+    // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
+    // its dynamic sizes must be preserved as well as its result type.
     auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
                                             sliceOp.getType().getElementType(),
                                             sliceOp.getType().getEncoding());
@@ -62,14 +88,17 @@
                                          sliceOp.getSizes());
     return success();
   }
+
+private:
+  bool foldSingleUseOnly = false;
 };
 
 } // namespace
 
-void mlir::tensor::populateFoldTensorEmptyPatterns(
-    RewritePatternSet &patterns) {
+void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
+                                                   bool foldSingleUseOnly) {
   patterns.add<FoldEmptyTensorWithExtractSliceOp,
                FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
                FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
-      patterns.getContext());
+      patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
 }
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -1,4 +1,14 @@
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-empty-op-folding %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.apply_patterns to %module_op {
+    transform.apply_patterns.tensor.fold_tensor_empty
+  } : !transform.any_op
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
 func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
   %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
@@ -6,34 +16,28 @@
       : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
-// CHECK: func @empty_reshape_expansion
+// CHECK-LABEL: func @empty_reshape_expansion
 // CHECK-SAME: %[[ARG0:.+]]: index
 // CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
 // CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]]
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
 // CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
 // CHECK-NEXT: return %[[INIT]]
 
-// -----
-
 func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
   %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
       : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
   return %1 : tensor<6x5x?xf32>
 }
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
-// CHECK: func @empty_reshape_collapse
+// CHECK-LABEL: func @empty_reshape_collapse
 // CHECK-SAME: %[[ARG0:.+]]: index
 // CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<2x3x5x4x?x7xf32>
 // CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]]
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP2]]()[%[[DIM]]]
 // CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
 // CHECK-NEXT: return %[[INIT]]
 
-// -----
-
 func.func @fold_empty_tensor_with_slice
   (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
 {
@@ -42,14 +46,12 @@
     : tensor<?x10x40xf32> to tensor<5x?x20xf32>
   return %1 : tensor<5x?x20xf32>
 }
-// CHECK: func @fold_empty_tensor_with_slice
+// CHECK-LABEL: func @fold_empty_tensor_with_slice
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]])
 // CHECK: return %[[T0]]
-// -----
-
 // CHECK-LABEL: func @rank_reducing_empty_tensor_extract
 func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
   // CHECK: tensor.empty() : tensor<2xf32>
   %a = tensor.empty(%sz) : tensor<?x2xf32>
@@ -59,3 +61,28 @@
   %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
   return %r: tensor<2xf32>
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.apply_patterns to %module_op {
+    transform.apply_patterns.tensor.fold_tensor_empty
+      {fold_single_use_only = true}
+  } : !transform.any_op
+}
+
+func.func @double_use_of_tensor_empty(%arg0: index, %arg1: index)
+  -> (tensor<5x?x20xf32>, tensor<5x?x20xf32>)
+{
+  %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
+  %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
+    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
+  %2 = tensor.extract_slice %0[1, 1, 1] [5, %arg1, 20] [1, 1, 1]
+    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
+  return %1, %2 : tensor<5x?x20xf32>, tensor<5x?x20xf32>
+}
+// CHECK-LABEL: func @double_use_of_tensor_empty(
+// CHECK: tensor.empty{{.*}} : tensor<?x10x40xf32>
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -72,10 +72,6 @@
       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
       llvm::cl::init(false)};
 
-  Option<bool> testEmptyOpFolding{
-      *this, "test-empty-op-folding",
-      llvm::cl::desc("Test folding of tensor.empty"), llvm::cl::init(false)};
-
   Option<bool> testFoldIntoPackAndUnpack{
       *this, "test-fold-into-pack-and-unpack",
       llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
@@ -106,12 +102,6 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
-static void applyEmptyOpFoldingPatterns(Operation *rootOp) {
-  RewritePatternSet patterns(rootOp->getContext());
-  tensor::populateFoldTensorEmptyPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
-}
-
 static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
@@ -383,8 +373,6 @@
     applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
   if (testReassociativeReshapeFolding)
     applyReassociativeReshapeFoldingPatterns(rootOp);
-  if (testEmptyOpFolding)
-    applyEmptyOpFoldingPatterns(rootOp);
  if (testFoldIntoPackAndUnpack)
     applyFoldIntoPackAndUnpackPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
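Beyond the `fold_single_use_only` attribute exercised in the transform-dialect test above, downstream code can pass the new `foldSingleUseOnly` flag to `populateFoldTensorEmptyPatterns` directly. A minimal C++ sketch (the helper `foldEmptyTensorOps` and the greedy-driver boilerplate are illustrative, not part of this patch):

```cpp
// Illustrative only -- not part of the patch. Shows how a downstream pass
// could enable the new single-use-only folding behavior from C++.
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult foldEmptyTensorOps(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  // Fold tensor.empty only when it has exactly one use; multi-use
  // tensor.empty ops are left in place.
  tensor::populateFoldTensorEmptyPatterns(patterns,
                                          /*foldSingleUseOnly=*/true);
  return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
```

Since the flag defaults to false (in the pattern constructors and in the `fold_single_use_only` attribute), existing callers of `populateFoldTensorEmptyPatterns` keep the previous unconditional folding behavior.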