diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1534,21 +1534,20 @@
 };
 
 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
+class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if 'extractStridedSliceOp' operand is not defined by a
-    // ConstantOp.
+    // Return if 'ExtractOp' operand is not defined by a splat vector ConstantOp.
     auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
     if (!constantOp)
       return failure();
-    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
-    if (!dense)
+    auto splat = constantOp.getValue().dyn_cast<SplatElementsAttr>();
+    if (!splat)
       return failure();
-    Attribute newAttr = dense.getSplatValue<Attribute>();
+    Attribute newAttr = splat.getSplatValue<Attribute>();
     if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
       newAttr = DenseElementsAttr::get(vecDstType, newAttr);
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
@@ -1556,11 +1555,59 @@
   }
 };
 
+// Pattern to rewrite ExtractOp(non-splat 1-D vector ConstantOp) -> ConstantOp.
+class ExtractOp1DVectorConstantFolder final
+    : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Return if 'ExtractOp' operand is not defined by a 1D vector ConstantOp.
+    auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
+    if (!constantOp)
+      return failure();
+
+    auto vecTy = constantOp.getType().cast<VectorType>();
+    Type elemTy = vecTy.getElementType();
+    if (!vecTy.hasStaticShape())
+      return failure();
+    if (vecTy.getRank() != 1)
+      return failure();
+    if (!elemTy.isIntOrIndexOrFloat())
+      return failure();
+
+    TypedAttr vectorAttr = constantOp.getValue();
+    // The splat case is handled by `ExtractOpSplatConstantFolder`.
+    auto dense = vectorAttr.dyn_cast<DenseElementsAttr>();
+    if (!dense || dense.isSplat())
+      return failure();
+
+    auto positionAttr = extractOp.getPosition()[0].cast<IntegerAttr>();
+    int64_t position = positionAttr.getInt();
+    // Read only the requested element through the (random-access) value
+    // iterator rather than copying the whole constant into a temporary vector.
+    Attribute newAttr;
+    if (elemTy.isIntOrIndex()) {
+      APInt value = *(dense.value_begin<APInt>() + position);
+      newAttr = IntegerAttr::get(extractOp.getType(), value);
+    } else if (elemTy.isa<FloatType>()) {
+      APFloat value = *(dense.value_begin<APFloat>() + position);
+      newAttr = FloatAttr::get(extractOp.getType(), value);
+    }
+    assert(newAttr && "Unhandled case");
+
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpConstantFolder>(context);
+  results.add<ExtractOpSplatConstantFolder,
+              ExtractOp1DVectorConstantFolder>(context);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1367,11 +1367,11 @@
 
 // -----
 
-// CHECK-LABEL: extract_constant
-// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
-// CHECK: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
-func.func @extract_constant() -> (vector<7xf32>, i32) {
+// CHECK-LABEL: func.func @extract_splat_constant
+// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
+// CHECK-NEXT: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
+func.func @extract_splat_constant() -> (vector<7xf32>, i32) {
   %cst = arith.constant dense<2.000000e+00> : vector<29x7xf32>
   %cst_1 = arith.constant dense<1> : vector<4x37x9xi32>
   %0 = vector.extract %cst[2] : vector<29x7xf32>
@@ -1381,6 +1381,23 @@
 
 // -----
 
+// CHECK-LABEL: func.func @extract_1d_constant
+// CHECK-DAG: %[[I32CST:.*]] = arith.constant 3 : i32
+// CHECK-DAG: %[[IDXCST:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[F32CST:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-NEXT: return %[[I32CST]], %[[IDXCST]], %[[F32CST]] : i32, index, f32
+func.func @extract_1d_constant() -> (i32, index, f32) {
+  %icst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
+  %e = vector.extract %icst[2] : vector<4xi32>
+  %idx_cst = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+  %f = vector.extract %idx_cst[1] : vector<3xindex>
+  %fcst = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<3xf32>
+  %g = vector.extract %fcst[0] : vector<3xf32>
+  return %e, %f, %g : i32, index, f32
+}
+
+// -----
+
 // CHECK-LABEL: extract_extract_strided
 // CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
 // CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>