diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -40,6 +40,10 @@ /// `tensor.collapse_shape` into other ops. void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns); +/// Populates `patterns` with patterns that fold tensor.empty with +/// tensor.[extract_slice|cast|expand_shape|collapse_shape]. +void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns); + } // namespace tensor } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -667,6 +668,7 @@ tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); + tensor::populateFoldTensorEmptyPatterns(patterns); memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -620,58 +620,6 @@ } }; -/// `tensor.empty` does not define any tensor contents, so a slice of a -/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`. -struct FoldEmptyTensorWithExtractSliceOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, - PatternRewriter &rewriter) const override { - if (!sliceOp.getSource().getDefiningOp()) - return failure(); - - // ExtractSliceOp may be rank-reducing; its dynamic sizes must be - // preserved as well as its result type. - auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(), - sliceOp.getType().getElementType(), - sliceOp.getType().getEncoding()); - rewriter.replaceOpWithNewOp(sliceOp, tensorType, - sliceOp.getSizes()); - return success(); - } -}; - -template -struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ReshapeOp reshapeOp, - PatternRewriter &rewriter) const override { - if (!reshapeOp.getSrc().template getDefiningOp()) - return failure(); - Location loc = reshapeOp.getLoc(); - ReifiedRankedShapedTypeDims resultShapes; - ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = - cast(reshapeOp.getOperation()); - if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter, - resultShapes)) || - !llvm::hasSingleElement(resultShapes)) - return failure(); - // TODO: Do not drop tensor type encoding. - Value emptyTensor = - rewriter.create(loc, getAsOpFoldResult(resultShapes[0]), - reshapeOp.getResultType().getElementType()); - if (emptyTensor.getType() != reshapeOp.getResultType()) { - rewriter.replaceOpWithNewOp( - reshapeOp, reshapeOp.getResultType(), emptyTensor); - } else { - rewriter.replaceOp(reshapeOp, emptyTensor); - } - return success(); - } -}; - struct FoldEmptyTensorWithDimOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -689,86 +637,12 @@ } }; -/// Canonicalize -/// -/// ```mlir -/// %0 = tensor.empty(%d0, %d1) : tensor -/// %1 = tensor.cast %0 : tensor to tensor<4x?xf32> -/// ``` -/// -/// into -/// -/// ```mlir -/// %0 = tensor.empty(%d1) : tensor<4x?xf32> -/// ``` -/// -/// This assumes the input program is correct in terms of its shape. So it is -/// safe to assume that `%d0` is in fact 4. -struct FoldEmptyTensorWithCastOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CastOp castOp, - PatternRewriter &rewriter) const override { - if (!canFoldIntoProducerOp(castOp)) - return failure(); - auto producer = castOp.getSource().getDefiningOp(); - if (!producer) - return failure(); - - auto resultType = castOp->getResult(0).getType().cast(); - ArrayRef resultShape = resultType.getShape(); - SmallVector currMixedSizes = producer.getMixedSizes(); - SmallVector newMixedSizes; - newMixedSizes.reserve(currMixedSizes.size()); - assert(resultShape.size() == currMixedSizes.size() && - "mismatch in result shape and sizes of empty op"); - for (auto it : llvm::zip(resultShape, currMixedSizes)) { - int64_t newDim = std::get<0>(it); - OpFoldResult currDim = std::get<1>(it); - // Case 1: The empty tensor dim is static. Check that the tensor cast - // result dim matches. - if (auto attr = currDim.dyn_cast()) { - if (ShapedType::isDynamic(newDim) || - newDim != attr.cast().getInt()) { - // Something is off, the cast result shape cannot be more dynamic - // than the empty tensor result shape (enforced by - // `canFoldIntoProducer`). Abort for now. - return rewriter.notifyMatchFailure( - producer, "mismatch in static value of shape of empty tensor " - "result and cast result"); - } - newMixedSizes.push_back(attr); - continue; - } - - // Case 2 : The tensor cast shape is static, but empty tensor result - // shape is dynamic. - if (!ShapedType::isDynamic(newDim)) { - newMixedSizes.push_back(rewriter.getIndexAttr(newDim)); - continue; - } - - // Case 3 : The tensor cast shape is dynamic and empty tensor result - // shape is dynamic. Use the dynamic value from the empty tensor op. - newMixedSizes.push_back(currDim); - } - - // TODO: Do not drop tensor encoding. - rewriter.replaceOpWithNewOp(castOp, newMixedSizes, - resultType.getElementType()); - return success(); - } -}; - } // namespace void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - FoldEmptyTensorWithReshapeOp, - ReplaceEmptyTensorStaticShapeDims>(context); + results.add( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTensorTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp + EmptyOpPatterns.cpp ExtractSliceFromReshapeUtils.cpp MergeConsecutiveInsertExtractSlicePatterns.cpp ReshapePatterns.cpp diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp @@ -0,0 +1,150 @@ +//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +using namespace mlir::tensor; + +namespace { + +/// Canonicalize +/// +/// ```mlir +/// %0 = tensor.empty(%d0, %d1) : tensor +/// %1 = tensor.cast %0 : tensor to tensor<4x?xf32> +/// ``` +/// +/// into +/// +/// ```mlir +/// %0 = tensor.empty(%d1) : tensor<4x?xf32> +/// ``` +/// +/// This assumes the input program is correct in terms of its shape. So it is +/// safe to assume that `%d0` is in fact 4. +struct FoldEmptyTensorWithCastOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CastOp castOp, + PatternRewriter &rewriter) const override { + if (!canFoldIntoProducerOp(castOp)) + return failure(); + auto producer = castOp.getSource().getDefiningOp(); + if (!producer) + return failure(); + + auto resultType = castOp->getResult(0).getType().cast(); + ArrayRef resultShape = resultType.getShape(); + SmallVector currMixedSizes = producer.getMixedSizes(); + SmallVector newMixedSizes; + newMixedSizes.reserve(currMixedSizes.size()); + assert(resultShape.size() == currMixedSizes.size() && + "mismatch in result shape and sizes of empty op"); + for (auto it : llvm::zip(resultShape, currMixedSizes)) { + int64_t newDim = std::get<0>(it); + OpFoldResult currDim = std::get<1>(it); + // Case 1: The empty tensor dim is static. Check that the tensor cast + // result dim matches. + if (auto attr = currDim.dyn_cast()) { + if (ShapedType::isDynamic(newDim) || + newDim != attr.cast().getInt()) { + // Something is off, the cast result shape cannot be more dynamic + // than the empty tensor result shape (enforced by + // `canFoldIntoProducer`). Abort for now. + return rewriter.notifyMatchFailure( + producer, "mismatch in static value of shape of empty tensor " + "result and cast result"); + } + newMixedSizes.push_back(attr); + continue; + } + + // Case 2 : The tensor cast shape is static, but empty tensor result + // shape is dynamic. + if (!ShapedType::isDynamic(newDim)) { + newMixedSizes.push_back(rewriter.getIndexAttr(newDim)); + continue; + } + + // Case 3 : The tensor cast shape is dynamic and empty tensor result + // shape is dynamic. Use the dynamic value from the empty tensor op. + newMixedSizes.push_back(currDim); + } + + // TODO: Do not drop tensor encoding. + rewriter.replaceOpWithNewOp(castOp, newMixedSizes, + resultType.getElementType()); + return success(); + } +}; + +template +struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReshapeOp reshapeOp, + PatternRewriter &rewriter) const override { + if (!reshapeOp.getSrc().template getDefiningOp()) + return failure(); + Location loc = reshapeOp.getLoc(); + ReifiedRankedShapedTypeDims resultShapes; + ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = + cast(reshapeOp.getOperation()); + if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter, + resultShapes)) || + !llvm::hasSingleElement(resultShapes)) + return failure(); + // TODO: Do not drop tensor type encoding. + Value emptyTensor = + rewriter.create(loc, getAsOpFoldResult(resultShapes[0]), + reshapeOp.getResultType().getElementType()); + if (emptyTensor.getType() != reshapeOp.getResultType()) { + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeOp.getResultType(), emptyTensor); + } else { + rewriter.replaceOp(reshapeOp, emptyTensor); + } + return success(); + } +}; + +/// `tensor.empty` does not define any tensor contents, so a slice of a +/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`. +struct FoldEmptyTensorWithExtractSliceOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, + PatternRewriter &rewriter) const override { + if (!sliceOp.getSource().getDefiningOp()) + return failure(); + + // ExtractSliceOp may be rank-reducing; its dynamic sizes must be + // preserved as well as its result type. + auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(), + sliceOp.getType().getElementType(), + sliceOp.getType().getEncoding()); + rewriter.replaceOpWithNewOp(sliceOp, tensorType, + sliceOp.getSizes()); + return success(); + } +}; + +} // namespace + +void mlir::tensor::populateFoldTensorEmptyPatterns( + RewritePatternSet &patterns) { + patterns.add, + FoldEmptyTensorWithReshapeOp>( + patterns.getContext()); +} diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -299,9 +299,10 @@ // CHECK-LABEL: func @fold_fill_reshape() func.func @fold_fill_reshape() -> tensor<6x4xf32> { %zero = arith.constant 0.0 : f32 - // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<6x4xf32> %empty = tensor.empty() : tensor<1x2x3x4xf32> - // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<6x4xf32>) -> tensor<6x4xf32> + // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape + // CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32) + // CHECK-SAME: outs(%[[COLLAPSE]] : tensor<6x4xf32>) %fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]] : tensor<1x2x3x4xf32> into tensor<6x4xf32> @@ -788,13 +789,25 @@ } // CHECK: func @fold_multi_use_generic_op_with_consumer // CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<2x3x4xf32> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 + +// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] + +// CHECK: %[[INIT1:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]], %[[DIM0]]) +// CHECK-NEXT: %[[INIT2:.+]] = tensor.empty(%[[DIM2]], %[[DIM1]], %[[DIM0]]) + // CHECK-DAG: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor to tensor<4x3x2xf32> -// CHECK-DAG: %[[INIT2:.+]] = tensor.empty() : tensor<3x2x4xf32> +// CHECK-DAG: %[[CAST_INIT1:.+]] = tensor.cast %[[INIT1]] +// CHECK-DAG: %[[CAST_INIT2:.+]] = tensor.cast %[[INIT2]] // CHECK: %[[GENERIC:.+]]:2 = linalg.generic // CHECK-SAME: ins(%[[CAST]] : -// CHECK-SAME: outs(%[[INIT2]], %[[INIT1]] : -// CHECK: %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 : tensor<3x2x4xf32> to tensor +// CHECK-SAME: outs(%[[CAST_INIT1]], %[[CAST_INIT2]] : +// CHECK: %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 +// CHECK-SAME: tensor<3x2x4xf32> to tensor // CHECK: return %[[RETURN_CAST]], %[[GENERIC]]#1 // ----- diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1538,63 +1538,6 @@ // ----- -func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> { - %0 = tensor.empty(%arg0) : tensor<6x5x?xf32> - %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] - : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> - return %1 : tensor<2x3x5x4x?x7xf32> -} -// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> -// CHECK: func @empty_reshape_expansion -// CHECK-SAME: %[[ARG0:.+]]: index -// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]]) -// CHECK-NEXT: return %[[INIT]] - -// ----- - -func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> { - %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32> - %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]] - : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> - return %1 : tensor<6x5x?xf32> -} -// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)> -// CHECK: func @empty_reshape_collapse -// CHECK-SAME: %[[ARG0:.+]]: index -// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]]) -// CHECK-NEXT: return %[[INIT]] - -// ----- - -func.func @fold_empty_tensor_with_slice - (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32> -{ - %0 = tensor.empty(%arg0) : tensor - %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1] - : tensor to tensor<5x?x20xf32> - return %1 : tensor<5x?x20xf32> -} -// CHECK: func @fold_empty_tensor_with_slice -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index -// CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]]) -// CHECK: return %[[T0]] - -// ----- - -func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> { - %0 = tensor.empty(%arg0) : tensor - %1 = tensor.cast %0 : tensor to tensor<1x12xf32> - return %1 : tensor<1x12xf32> -} -// CHECK: func @fold_empty_tensor_with_cast(%[[ARG0:.+]]: index) -// CHECK: %[[T0:.+]] = tensor.empty() : tensor<1x12xf32> -// CHECK: return %[[T0]] : tensor<1x12xf32> - -// ----- - func.func private @some_use(%i : index, %j : index) // CHECK-LABEL: func @empty_tensor_canonicalize @@ -1619,18 +1562,6 @@ // ----- -// CHECK-LABEL: func @rank_reducing_empty_tensor_extract -func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> { - // CHECK: tensor.empty() : tensor<2xf32> - %a = tensor.empty(%sz) : tensor - - // CHECK-NOT: extract - %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor to tensor<2xf32> - return %r: tensor<2xf32> -} - -// ----- - // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)> // CHECK-LABEL: func @dim_of_expand_shape( // CHECK-SAME: %[[t:.*]]: tensor diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-empty-op-folding %s | FileCheck %s --dump-input=always + +func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> { + %0 = tensor.empty(%arg0) : tensor<6x5x?xf32> + %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] + : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> + return %1 : tensor<2x3x5x4x?x7xf32> +} +// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> +// CHECK: func @empty_reshape_expansion +// CHECK-SAME: %[[ARG0:.+]]: index +// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32> +// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]] +// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]] +// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]]) +// CHECK-NEXT: return %[[INIT]] + +// ----- + +func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> { + %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32> + %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]] + : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> + return %1 : tensor<6x5x?xf32> +} +// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)> +// CHECK: func @empty_reshape_collapse +// CHECK-SAME: %[[ARG0:.+]]: index +// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<2x3x5x4x?x7xf32> +// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]] +// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]] +// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]]) +// CHECK-NEXT: return %[[INIT]] + +// ----- + +func.func @fold_empty_tensor_with_slice + (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32> +{ + %0 = tensor.empty(%arg0) : tensor + %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1] + : tensor to tensor<5x?x20xf32> + return %1 : tensor<5x?x20xf32> +} +// CHECK: func @fold_empty_tensor_with_slice +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]]) +// CHECK: return %[[T0]] + +// ----- + +func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> { + %0 = tensor.empty(%arg0) : tensor + %1 = tensor.cast %0 : tensor to tensor<1x12xf32> + return %1 : tensor<1x12xf32> +} +// CHECK: func @fold_empty_tensor_with_cast(%[[ARG0:.+]]: index) +// CHECK: %[[T0:.+]] = tensor.empty() : tensor<1x12xf32> +// CHECK: return %[[T0]] : tensor<1x12xf32> + +// ----- + +// CHECK-LABEL: func @rank_reducing_empty_tensor_extract +func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> { + // CHECK: tensor.empty() : tensor<2xf32> + %a = tensor.empty(%sz) : tensor + + // CHECK-NOT: extract + %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor to tensor<2xf32> + return %r: tensor<2xf32> +} diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -70,6 +70,10 @@ llvm::cl::desc("Test folding of expand_shape/collapse_shape"), llvm::cl::init(false)}; + Option testEmptyOpFolding{ + *this, "test-empty-op-folding", + llvm::cl::desc("Test folding of tensor.empty"), llvm::cl::init(false)}; + Option useForeach{ *this, "use-foreach", llvm::cl::desc( @@ -85,6 +89,12 @@ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } +static void applyEmptyOpFoldingPatterns(Operation *rootOp) { + RewritePatternSet patterns(rootOp->getContext()); + tensor::populateFoldTensorEmptyPatterns(patterns); + (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); +} + static void applySplitPaddingPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateSplitPaddingPatterns(patterns); @@ -264,6 +274,8 @@ applyFoldConsecutiveInsertExtractSlicePatterns(rootOp); if (testReassociativeReshapeFolding) applyReassociativeReshapeFoldingPatterns(rootOp); + if (testEmptyOpFolding) + applyEmptyOpFoldingPatterns(rootOp); if (testRewriteExtractSliceWithTiledCollapseShape) { if (failed( applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))