Index: mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
===================================================================
--- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1382,6 +1382,24 @@
   }
 };
 
+/// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
+template <typename TensorReshapeOp>
+class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
+public:
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
+    if (!splatOp)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
+        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
+    return success();
+  }
+};
+
 /// Reshape of a FromElements can be replaced with a FromElements of the
 /// result type
 template <typename TensorReshapeOp>
@@ -1523,6 +1541,7 @@
   results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
               ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
               FoldReshapeWithConstant<ExpandShapeOp>,
+              FoldReshapeWithSplat<ExpandShapeOp>,
               FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
               FoldDimOfCollapseShape>(context);
 }
@@ -1533,6 +1552,7 @@
       .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
            ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
            FoldReshapeWithConstant<CollapseShapeOp>,
+           FoldReshapeWithSplat<CollapseShapeOp>,
            FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
           context);
 }
@@ -1540,6 +1560,7 @@
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
 }
+
 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
 }
Index: mlir/test/Dialect/Tensor/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Tensor/canonicalize.mlir
+++ mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1013,9 +1013,34 @@
 // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32>
 // CHECK-NOT: tensor.expand_shape
 // CHECK: return %[[CST]]
+// -----
+func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
+  %c0 = tensor.splat %arg : tensor<2x4xf32>
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+      : tensor<2x4xf32> into tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+// CHECK-LABEL: @expand_shape_splat
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[CST:.*]] = tensor.splat %[[ARG0]] : tensor<2x2x2xf32>
+// CHECK-NOT: tensor.expand_shape
+// CHECK: return %[[CST]]
 
 // -----
+func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
+  %c0 = tensor.splat %arg : tensor<2x2x2xf32>
+  %0 = tensor.collapse_shape %c0 [[0], [1, 2]]
+      : tensor<2x2x2xf32> into tensor<2x4xf32>
+  return %0 : tensor<2x4xf32>
+}
+// CHECK-LABEL: @collapse_shape_splat
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[CST:.*]] = tensor.splat %[[ARG0]] : tensor<2x4xf32>
+// CHECK-NOT: tensor.collapse_shape
+// CHECK: return %[[CST]]
+
+// -----
 
 func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
   %c0 = arith.constant dense<42> : tensor<2x8xi16>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]