[mlir][vector] Fold vector.extract of a sub-vector from a non-splat constant

Previously the non-splat constant folder for vector.extract bailed out unless
the number of extraction positions equalled the source vector rank, so it only
folded scalar extractions. This change removes that restriction:
- The given positions are padded with trailing zeros up to the source rank.
- The padded positions are linearized.
- The `resVecTy.getNumElements()` values starting at that offset are packed
  into a new DenseElementsAttr of the result vector type.
This presumes linearize()/computeStrides() describe a row-major layout, so that
those values form one contiguous run. TODO confirm.

Scalar extractions still fold to the single element at the linearized offset.

The patch also adds an include of mlir/IR/BuiltinAttributes.h. It adds
canonicalize.mlir tests for 2-D and 3-D sub-vector extraction, and for
sub-vectors whose extracted values happen to be all equal.

NOTE(review): the removed guard said it existed "to limit the size of the
generated constants". Extracting a large sub-vector now materializes a new
constant that duplicates those elements. Confirm this trade-off is intended.

NOTE(review): this copy of the patch appears to have lost its line breaks and
its template arguments (e.g. `cast()`, `dyn_cast()`, `value_begin()`,
`replaceOpWithNewOp(`, `llvm::SmallVector completePositions`). It will not
apply or compile as-is; regenerate it from the original commit.

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -1623,24 +1624,33 @@ return failure(); auto vecTy = sourceVector.getType().cast(); - ArrayAttr positions = extractOp.getPosition(); if (vecTy.isScalable()) return failure(); - // Do not allow extracting sub-vectors to limit the size of the generated - // constants. - if (vecTy.getRank() != static_cast(positions.size())) - return failure(); // The splat case is handled by `ExtractOpSplatConstantFolder`. auto dense = vectorCst.dyn_cast(); if (!dense || dense.isSplat()) return failure(); - // Calculate the linearized position. - int64_t elemPosition = - linearize(getI64SubArray(positions), computeStrides(vecTy.getShape())); - Attribute elementValue = *(dense.value_begin() + elemPosition); - rewriter.replaceOpWithNewOp(extractOp, elementValue); + // Calculate the linearized position of the contiguous chunk of elements to + // extract.
+ llvm::SmallVector completePositions(vecTy.getRank(), 0); + llvm::copy(getI64SubArray(extractOp.getPosition()), + completePositions.begin()); + int64_t elemBeginPosition = + linearize(completePositions, computeStrides(vecTy.getShape())); + auto denseValuesBegin = dense.value_begin() + elemBeginPosition; + + Attribute newAttr; + if (auto resVecTy = extractOp.getType().dyn_cast()) { + SmallVector elementValues( + denseValuesBegin, denseValuesBegin + resVecTy.getNumElements()); + newAttr = DenseElementsAttr::get(resVecTy, elementValues); + } else { + newAttr = *denseValuesBegin; + } + + rewriter.replaceOpWithNewOp(extractOp, newAttr); return success(); } }; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1471,6 +1471,19 @@ // ----- +// CHECK-LABEL: func.func @extract_vector_2d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[3, 4, 5]> : vector<3xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]] : vector<3xi32>, vector<3xi32> +func.func @extract_vector_2d_constant() -> (vector<3xi32>, vector<3xi32>) { + %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32> + %a = vector.extract %cst[0] : vector<2x3xi32> + %b = vector.extract %cst[1] : vector<2x3xi32> + return %a, %b : vector<3xi32>, vector<3xi32> +} + +// ----- + // CHECK-LABEL: func.func @extract_3d_constant // CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[BCST:.*]] = arith.constant 1 : i32 @@ -1488,6 +1501,38 @@ // ----- +// CHECK-LABEL: func.func @extract_vector_3d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\], \[4, 5\]\]}}> : vector<3x2xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[6, 7\], \[8, 9\], \[10, 11\]\]}}> : vector<3x2xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<[8,
9]> : vector<2xi32> +// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<[10, 11]> : vector<2xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : vector<3x2xi32>, vector<3x2xi32>, vector<2xi32>, vector<2xi32> +func.func @extract_vector_3d_constant() -> (vector<3x2xi32>, vector<3x2xi32>, vector<2xi32>, vector<2xi32>) { + %cst = arith.constant dense<[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]> : vector<2x3x2xi32> + %a = vector.extract %cst[0] : vector<2x3x2xi32> + %b = vector.extract %cst[1] : vector<2x3x2xi32> + %c = vector.extract %cst[1, 1] : vector<2x3x2xi32> + %d = vector.extract %cst[1, 2] : vector<2x3x2xi32> + return %a, %b, %c, %d : vector<3x2xi32>, vector<3x2xi32>, vector<2xi32>, vector<2xi32> +} + +// ----- + +// CHECK-LABEL: func.func @extract_splat_vector_3d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<4> : vector<2xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<5> : vector<2xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<2xi32>, vector<2xi32>, vector<2xi32> +func.func @extract_splat_vector_3d_constant() -> (vector<2xi32>, vector<2xi32>, vector<2xi32>) { + %cst = arith.constant dense<[[[0, 0], [1, 1], [2, 2]], [[3, 3], [4, 4], [5, 5]]]> : vector<2x3x2xi32> + %a = vector.extract %cst[0, 0] : vector<2x3x2xi32> + %b = vector.extract %cst[1, 1] : vector<2x3x2xi32> + %c = vector.extract %cst[1, 2] : vector<2x3x2xi32> + return %a, %b, %c : vector<2xi32>, vector<2xi32>, vector<2xi32> +} + +// ----- + // CHECK-LABEL: extract_extract_strided // CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16> // CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>