diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -734,9 +735,30 @@
   return success();
 }
 
+/// Reshape of a splat constant can be replaced with a constant of the result
+/// type.
+struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    DenseElementsAttr attr;
+    if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
+      return failure();
+    if (!attr || !attr.isSplat())
+      return failure();
+    // A splat buffer stores one element regardless of shape, so it can be
+    // reused verbatim as the splat buffer of the reshaped result type.
+    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
+        reshapeOp.getResultType(), attr.getRawData(), /*isSplatBuffer=*/true);
+    rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr);
+    return success();
+  }
+};
+
 void TensorReshapeOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<CollapseReshapeOps<TensorReshapeOp>>(context);
+  results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -203,3 +203,60 @@
 // CHECK-NOT: linalg.copy
 // CHECK-NEXT: linalg.generic
 
+// -----
+
+func @reshape_splat_constant_int32() -> tensor<2x4x2xi32>
+{
+  %c0 = constant dense<42> : tensor<2x8xi32>
+  %0 = linalg.tensor_reshape %c0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>]
+       : tensor<2x8xi32> into tensor<2x4x2xi32>
+  return %0 : tensor<2x4x2xi32>
+}
+// CHECK-LABEL: @reshape_splat_constant_int32
+// CHECK: %[[CST:.*]] = constant dense<42> : tensor<2x4x2xi32>
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: return %[[CST]]
+
+func @reshape_splat_constant_int16() -> tensor<2x4x2xi16>
+{
+  %c0 = constant dense<42> : tensor<2x8xi16>
+  %0 = linalg.tensor_reshape %c0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>]
+       : tensor<2x8xi16> into tensor<2x4x2xi16>
+  return %0 : tensor<2x4x2xi16>
+}
+// CHECK-LABEL: @reshape_splat_constant_int16
+// CHECK: %[[CST:.*]] = constant dense<42> : tensor<2x4x2xi16>
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: return %[[CST]]
+
+func @reshape_splat_constant_float32() -> tensor<2x4x2xf32>
+{
+  %c0 = constant dense<42.0> : tensor<2x8xf32>
+  %0 = linalg.tensor_reshape %c0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>]
+       : tensor<2x8xf32> into tensor<2x4x2xf32>
+  return %0 : tensor<2x4x2xf32>
+}
+// CHECK-LABEL: @reshape_splat_constant_float32
+// CHECK: %[[CST:.*]] = constant dense<4.200000e+01> : tensor<2x4x2xf32>
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: return %[[CST]]
+
+func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
+{
+  %c0 = constant dense<42.0> : tensor<2x8xf64>
+  %0 = linalg.tensor_reshape %c0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>]
+       : tensor<2x8xf64> into tensor<2x4x2xf64>
+  return %0 : tensor<2x4x2xf64>
+}
+// CHECK-LABEL: @reshape_splat_constant_float64
+// CHECK: %[[CST:.*]] = constant dense<4.200000e+01> : tensor<2x4x2xf64>
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: return %[[CST]]