diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -40,6 +40,10 @@
 /// `tensor.collapse_shape` into other ops.
 void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that fold tensor.empty with
+/// tensor.[extract_slice|expand_shape|collapse_shape].
+void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -667,6 +668,7 @@
   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+  tensor::populateFoldTensorEmptyPatterns(patterns);
   memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -620,58 +620,6 @@
   }
 };
 
-/// `tensor.empty` does not define any tensor contents, so a slice of a
-/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
-struct FoldEmptyTensorWithExtractSliceOp
-    : public OpRewritePattern<ExtractSliceOp> {
-  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
-                                PatternRewriter &rewriter) const override {
-    if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
-      return failure();
-
-    // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
-    // preserved as well as its result type.
-    auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
-                                            sliceOp.getType().getElementType(),
-                                            sliceOp.getType().getEncoding());
-    rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
-                                         sliceOp.getSizes());
-    return success();
-  }
-};
-
-template <typename ReshapeOp>
-struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
-  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
-      return failure();
-    Location loc = reshapeOp.getLoc();
-    ReifiedRankedShapedTypeDims resultShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
-    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
-                                                          resultShapes)) ||
-        !llvm::hasSingleElement(resultShapes))
-      return failure();
-    // TODO: Do not drop tensor type encoding.
-    Value emptyTensor =
-        rewriter.create<EmptyOp>(loc, getAsOpFoldResult(resultShapes[0]),
-                                 reshapeOp.getResultType().getElementType());
-    if (emptyTensor.getType() != reshapeOp.getResultType()) {
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(
-          reshapeOp, reshapeOp.getResultType(), emptyTensor);
-    } else {
-      rewriter.replaceOp(reshapeOp, emptyTensor);
-    }
-    return success();
-  }
-};
-
 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
   using OpRewritePattern<DimOp>::OpRewritePattern;
 
@@ -765,9 +713,6 @@
 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
-              FoldEmptyTensorWithExtractSliceOp,
-              FoldEmptyTensorWithReshapeOp<ExpandShapeOp>,
-              FoldEmptyTensorWithReshapeOp<CollapseShapeOp>,
               ReplaceEmptyTensorStaticShapeDims>(context);
 }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
   ReshapePatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -0,0 +1,79 @@
+//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+template <typename ReshapeOp>
+struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
+  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
+      return failure();
+    Location loc = reshapeOp.getLoc();
+    ReifiedRankedShapedTypeDims resultShapes;
+    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
+        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
+    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
+                                                          resultShapes)) ||
+        !llvm::hasSingleElement(resultShapes))
+      return failure();
+    // TODO: Do not drop tensor type encoding.
+    Value emptyTensor =
+        rewriter.create<EmptyOp>(loc, getAsOpFoldResult(resultShapes[0]),
+                                 reshapeOp.getResultType().getElementType());
+    if (emptyTensor.getType() != reshapeOp.getResultType()) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          reshapeOp, reshapeOp.getResultType(), emptyTensor);
+    } else {
+      rewriter.replaceOp(reshapeOp, emptyTensor);
+    }
+    return success();
+  }
+};
+
+/// `tensor.empty` does not define any tensor contents, so a slice of a
+/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
+struct FoldEmptyTensorWithExtractSliceOp
+    : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
+      return failure();
+
+    // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
+    // preserved as well as its result type.
+    auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
+                                            sliceOp.getType().getElementType(),
+                                            sliceOp.getType().getEncoding());
+    rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
+                                         sliceOp.getSizes());
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateFoldTensorEmptyPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldEmptyTensorWithExtractSliceOp,
+               FoldEmptyTensorWithReshapeOp<ExpandShapeOp>,
+               FoldEmptyTensorWithReshapeOp<CollapseShapeOp>>(
+      patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -299,9 +299,10 @@
 // CHECK-LABEL: func @fold_fill_reshape()
 func.func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<6x4xf32>
   %empty = tensor.empty() : tensor<1x2x3x4xf32>
-  // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<6x4xf32>) -> tensor<6x4xf32>
+  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
+  // CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32)
+  // CHECK-SAME: outs(%[[COLLAPSE]] : tensor<6x4xf32>)
   %fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
   %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]]
       : tensor<1x2x3x4xf32> into tensor<6x4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1538,52 +1538,6 @@
 
 // -----
 
-func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
-  %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
-  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
-      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
-  return %1 : tensor<2x3x5x4x?x7xf32>
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
-// CHECK: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
-// CHECK-NEXT: return %[[INIT]]
-
-// -----
-
-func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
-  %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32>
-  %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
-      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
-  return %1 : tensor<6x5x?xf32>
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
-// CHECK: func @empty_reshape_collapse
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
-// CHECK-NEXT: return %[[INIT]]
-
-// -----
-
-func.func @fold_empty_tensor_with_slice
-  (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
-{
-  %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
-  %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
-    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
-  return %1 : tensor<5x?x20xf32>
-}
-// CHECK: func @fold_empty_tensor_with_slice
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]])
-// CHECK: return %[[T0]]
-
-// -----
-
 func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
   %0 = tensor.empty(%arg0) : tensor<?x12xf32>
   %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
@@ -1619,18 +1573,6 @@
 
 // -----
 
-// CHECK-LABEL: func @rank_reducing_empty_tensor_extract
-func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
-  // CHECK: tensor.empty() : tensor<2xf32>
-  %a = tensor.empty(%sz) : tensor<?x2xf32>
-
-  // CHECK-NOT: extract
-  %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
-  return %r: tensor<2xf32>
-}
-
-// -----
-
 // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
 // CHECK-LABEL: func @dim_of_expand_shape(
 // CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-empty-op-folding %s | FileCheck %s
+
+func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+  %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
+      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  return %1 : tensor<2x3x5x4x?x7xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+// CHECK: func @empty_reshape_expansion
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
+// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]]
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-NEXT: return %[[INIT]]
+
+// -----
+
+func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
+  %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
+      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  return %1 : tensor<6x5x?xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+// CHECK: func @empty_reshape_collapse
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<2x3x5x4x?x7xf32>
+// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]]
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-NEXT: return %[[INIT]]
+
+// -----
+
+func.func @fold_empty_tensor_with_slice
+  (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
+{
+  %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
+  %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
+    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
+  return %1 : tensor<5x?x20xf32>
+}
+// CHECK: func @fold_empty_tensor_with_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]])
+// CHECK: return %[[T0]]
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_empty_tensor_extract
+func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
+  // CHECK: tensor.empty() : tensor<2xf32>
+  %a = tensor.empty(%sz) : tensor<?x2xf32>
+
+  // CHECK-NOT: extract
+  %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
+  return %r: tensor<2xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -70,6 +70,10 @@
       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
       llvm::cl::init(false)};
 
+  Option<bool> testEmptyOpFolding{
+      *this, "test-empty-op-folding",
+      llvm::cl::desc("Test folding of tensor.empty"), llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -85,6 +89,12 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyEmptyOpFoldingPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateFoldTensorEmptyPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 static void applySplitPaddingPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateSplitPaddingPatterns(patterns);
@@ -264,6 +274,8 @@
     applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
   if (testReassociativeReshapeFolding)
     applyReassociativeReshapeFoldingPatterns(rootOp);
+  if (testEmptyOpFolding)
+    applyEmptyOpFoldingPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
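
For illustration, the rewrite that FoldEmptyTensorWithExtractSliceOp performs, shown on the fold_empty_tensor_with_slice test from fold-empty-op.mlir above. The folded form below is a sketch inferred from the pattern code and the CHECK lines, not output copied from mlir-opt:

  // Input, run through:
  //   mlir-opt -test-tensor-transform-patterns=test-empty-op-folding
  func.func @fold_empty_tensor_with_slice
    (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
  {
    // tensor.empty carries no contents, only a shape.
    %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
    // The slice keeps static sizes 5 and 20 and the dynamic size %arg1.
    %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
      : tensor<?x10x40xf32> to tensor<5x?x20xf32>
    return %1 : tensor<5x?x20xf32>
  }

  // Expected folded form: the pattern rebuilds the slice's result type and
  // forwards its dynamic sizes, so the chain collapses to one smaller
  // tensor.empty.
  func.func @fold_empty_tensor_with_slice
    (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
  {
    %0 = tensor.empty(%arg1) : tensor<5x?x20xf32>
    return %0 : tensor<5x?x20xf32>
  }

The cast and dim folds stay in EmptyOp::getCanonicalizationPatterns; only the extract_slice/expand_shape/collapse_shape folds become opt-in via populateFoldTensorEmptyPatterns, which DropUnitDims.cpp and the test pass above now request explicitly.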