diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1534,21 +1534,22 @@
 };
 
 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
+class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // Return if 'ExtractOp' operand is not defined by a splat vector
     // ConstantOp.
-    auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
-    if (!constantOp)
+    Value sourceVector = extractOp.getVector();
+    Attribute vectorCst;
+    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
       return failure();
-    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
-    if (!dense)
+    auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
+    if (!splat)
       return failure();
-    Attribute newAttr = dense.getSplatValue<Attribute>();
+    Attribute newAttr = splat.getSplatValue<Attribute>();
     if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
       newAttr = DenseElementsAttr::get(vecDstType, newAttr);
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
@@ -1556,11 +1557,67 @@
   }
 };
 
+// Pattern to rewrite a ExtractOp(vector<...xT> ConstantOp)[...] -> ConstantOp,
+// where the position array specifies a scalar element.
+class ExtractOpScalarVectorConstantFolder final
+    : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Return if 'ExtractOp' operand is not defined by a compatible vector
+    // ConstantOp.
+    Value sourceVector = extractOp.getVector();
+    Attribute vectorCst;
+    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
+      return failure();
+
+    auto vecTy = sourceVector.getType().cast<VectorType>();
+    Type elemTy = vecTy.getElementType();
+    ArrayAttr positions = extractOp.getPosition();
+    if (vecTy.isScalable())
+      return failure();
+    // Do not allow extracting sub-vectors to limit the size of the generated
+    // constants.
+    if (vecTy.getRank() != static_cast<int64_t>(positions.size()))
+      return failure();
+    // TODO: Handle more element types, e.g., complex values.
+    if (!elemTy.isIntOrIndexOrFloat())
+      return failure();
+
+    // The splat case is handled by `ExtractOpSplatConstantFolder`.
+    auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
+    if (!dense || dense.isSplat())
+      return failure();
+
+    // Calculate the flattened position.
+    int64_t elemPosition = 0;
+    int64_t innerElems = 1;
+    for (auto [dimSize, positionInDim] :
+         llvm::reverse(llvm::zip(vecTy.getShape(), positions))) {
+      int64_t positionVal = positionInDim.cast<IntegerAttr>().getInt();
+      elemPosition += positionVal * innerElems;
+      innerElems *= dimSize;
+    }
+
+    // Only the addressed element is needed, so index straight into the
+    // constant's storage instead of copying every element out. For the
+    // int/index/float element types accepted above, the element attribute is
+    // an IntegerAttr/FloatAttr that can be used directly by the ConstantOp.
+    Attribute newAttr = *(dense.value_begin<Attribute>() + elemPosition);
+
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
+  results.add<ExtractOpSplatConstantFolder, ExtractOpScalarVectorConstantFolder,
+              ExtractOpFromBroadcast>(context);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@
-1367,11 +1367,11 @@
 
 // -----
 
-// CHECK-LABEL: extract_constant
-// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
-// CHECK: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
-func.func @extract_constant() -> (vector<7xf32>, i32) {
+// CHECK-LABEL: func.func @extract_splat_constant
+// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
+// CHECK-NEXT: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
+func.func @extract_splat_constant() -> (vector<7xf32>, i32) {
   %cst = arith.constant dense<2.000000e+00> : vector<29x7xf32>
   %cst_1 = arith.constant dense<1> : vector<4x37x9xi32>
   %0 = vector.extract %cst[2] : vector<29x7xf32>
@@ -1381,6 +1381,57 @@
 
 // -----
 
+// CHECK-LABEL: func.func @extract_1d_constant
+// CHECK-DAG: %[[I32CST:.*]] = arith.constant 3 : i32
+// CHECK-DAG: %[[IDXCST:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[F32CST:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-NEXT: return %[[I32CST]], %[[IDXCST]], %[[F32CST]] : i32, index, f32
+func.func @extract_1d_constant() -> (i32, index, f32) {
+  %icst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
+  %e = vector.extract %icst[2] : vector<4xi32>  // Element 2 holds 3.
+  %idx_cst = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+  %f = vector.extract %idx_cst[1] : vector<3xindex>  // Element 1 holds 1.
+  %fcst = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<3xf32>
+  %g = vector.extract %fcst[0] : vector<3xf32>  // Element 0 holds 2.0.
+  return %e, %f, %g : i32, index, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_2d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[BCST:.*]] = arith.constant 2 : i32
+// CHECK-DAG: %[[CCST:.*]] = arith.constant 3 : i32
+// CHECK-DAG: %[[DCST:.*]] = arith.constant 5 : i32
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : i32, i32, i32, i32
+func.func @extract_2d_constant() -> (i32, i32, i32, i32) {
+  %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>  // Each value equals its row-major flat index.
+  %a = vector.extract %cst[0, 0] : vector<2x3xi32>  // Flat index 0*3 + 0 = 0.
+  %b = vector.extract %cst[0, 2] : vector<2x3xi32>  // Flat index 0*3 + 2 = 2.
+  %c = vector.extract %cst[1, 0] : vector<2x3xi32>  // Flat index 1*3 + 0 = 3.
+  %d = vector.extract %cst[1, 2] : vector<2x3xi32>  // Flat index 1*3 + 2 = 5.
+  return %a, %b, %c, %d : i32, i32, i32, i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_3d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[BCST:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[CCST:.*]] = arith.constant 9 : i32
+// CHECK-DAG: %[[DCST:.*]] = arith.constant 10 : i32
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : i32, i32, i32, i32
+func.func @extract_3d_constant() -> (i32, i32, i32, i32) {
+  %cst = arith.constant dense<[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]> : vector<2x3x2xi32>  // Each value equals its row-major flat index.
+  %a = vector.extract %cst[0, 0, 0] : vector<2x3x2xi32>  // Flat index 0*6 + 0*2 + 0 = 0.
+  %b = vector.extract %cst[0, 0, 1] : vector<2x3x2xi32>  // Flat index 0*6 + 0*2 + 1 = 1.
+  %c = vector.extract %cst[1, 1, 1] : vector<2x3x2xi32>  // Flat index 1*6 + 1*2 + 1 = 9.
+  %d = vector.extract %cst[1, 2, 0] : vector<2x3x2xi32>  // Flat index 1*6 + 2*2 + 0 = 10.
+  return %a, %b, %c, %d : i32, i32, i32, i32
+}
+
+// -----
+
 // CHECK-LABEL: extract_extract_strided
 // CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
 // CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>