diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1649,6 +1649,7 @@
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_BitCastOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1770,13 +1770,39 @@
   }
 };
 
+// Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+class StridedSliceConstantFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless the sliced vector is defined by a ConstantOp whose value
+    // is a splat; the result is then a splat of the same value.
+    auto constantOp =
+        extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
+    if (!constantOp)
+      return failure();
+    auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
+    if (!dense)
+      return failure();
+    auto newAttr = DenseElementsAttr::get(
+        extractStridedSliceOp.getType().cast<VectorType>(),
+        dense.getSplatValue());
+    rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
-  // ConstantMaskOp.
-  results.insert<StridedSliceConstantMaskFolder>(context);
+  // ConstantMaskOp and ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+  results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2560,6 +2586,36 @@
   return {};
 }
 
+namespace {
+// Pattern to rewrite a ShapeCastOp(splat ConstantOp) -> ConstantOp.
+class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
+    if (!constantOp)
+      return failure();
+    // Only handle splat for now.
+    auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
+    if (!dense)
+      return failure();
+    auto newAttr = DenseElementsAttr::get(
+        shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
+    rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
+    return success();
+  }
+};
+
+} // namespace
+
+void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  // Pattern to rewrite a ShapeCastOp(splat ConstantOp) -> ConstantOp.
+  results.insert<ShapeCastConstantFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // VectorBitCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -580,3 +580,37 @@
   %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
   return %2 : vector<4x16xi32>
 }
+
+// -----
+
+// CHECK-LABEL: shape_cast_constant
+// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
+// CHECK: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
+func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+  %cst = constant dense<2.000000e+00> : vector<5x4x2xf32>
+  %cst_1 = constant dense<1> : vector<12x2xi32>
+  %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
+  %1 = vector.shape_cast %cst_1 : vector<12x2xi32> to vector<3x4x2xi32>
+  return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_constant
+// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<12x2xf32>
+// CHECK: %[[CST1:.*]] = constant dense<1> : vector<2x13x3xi32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<12x2xf32>, vector<2x13x3xi32>
+func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
+  %cst = constant dense<2.000000e+00> : vector<29x7xf32>
+  %cst_1 = constant dense<1> : vector<4x37x9xi32>
+  %0 = vector.extract_strided_slice %cst
+    {offsets = [2, 3], sizes = [12, 2], strides = [1, 1]}
+      : vector<29x7xf32> to vector<12x2xf32>
+  %1 = vector.extract_strided_slice %cst_1
+    {offsets = [1, 2, 5], sizes = [2, 13, 3], strides = [1, 1, 1]}
+      : vector<4x37x9xi32> to vector<2x13x3xi32>
+  return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
+}
+
+