diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -798,20 +798,45 @@ } }; +/// Folds a reshape (expand_shape/collapse_shape) of a FromElements op into a +/// single FromElements op that directly produces the reshaped result type. +template +struct FoldReshapeWithFromElements : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, + PatternRewriter &rewriter) const override { + auto fromElements = + reshapeOp.src().template getDefiningOp(); + if (!fromElements) + return failure(); + + auto shapedTy = reshapeOp.getType().template cast(); + // Only rewrite when the result type has a fully static shape. + if (!shapedTy.hasStaticShape()) + return failure(); + // Forward the source elements unchanged; reshapes keep element order. + rewriter.replaceOpWithNewOp(reshapeOp, reshapeOp.getType(), + fromElements.elements()); + return success(); + } +}; + } // namespace void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, CollapseMixedReshapeOps, - FoldReshapeWithConstant>(context); + FoldReshapeWithConstant, + FoldReshapeWithFromElements>(context); } void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, CollapseMixedReshapeOps, - FoldReshapeWithConstant>(context); + FoldReshapeWithConstant, + FoldReshapeWithFromElements>(context); } OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1178,3 +1178,25 @@ return %0 : tensor<2x3x4xf32> } + +// ----- + +// CHECK-LABEL: func @fold_collapse_shape_from_elements +func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor { + // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor + // CHECK: return %[[FROM]] : tensor + %0 = tensor.from_elements %arg0 : tensor<1xi32> + %1 = tensor.collapse_shape 
%0 [] : tensor<1xi32> into tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @fold_expand_shape_from_elements +func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> { + // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32> + // CHECK: return %[[FROM]] : tensor<1xi32> + %0 = tensor.from_elements %arg0 : tensor + %1 = tensor.expand_shape %0 [] : tensor into tensor<1xi32> + return %1 : tensor<1xi32> +}