diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1623,7 +1623,6 @@ return failure(); auto vecTy = sourceVector.getType().cast(); - Type elemTy = vecTy.getElementType(); ArrayAttr positions = extractOp.getPosition(); if (vecTy.isScalable()) return failure(); @@ -1631,36 +1630,17 @@ // constants. if (vecTy.getRank() != static_cast(positions.size())) return failure(); - // TODO: Handle more element types, e.g., complex values. - if (!elemTy.isIntOrIndexOrFloat()) - return failure(); // The splat case is handled by `ExtractOpSplatConstantFolder`. auto dense = vectorCst.dyn_cast(); if (!dense || dense.isSplat()) return failure(); - // Calculate the flattened position. - int64_t elemPosition = 0; - int64_t innerElems = 1; - for (auto [dimSize, positionInDim] : - llvm::reverse(llvm::zip(vecTy.getShape(), positions))) { - int64_t positionVal = positionInDim.cast().getInt(); - elemPosition += positionVal * innerElems; - innerElems *= dimSize; - } - - Attribute newAttr; - if (vecTy.getElementType().isIntOrIndex()) { - auto values = to_vector(dense.getValues()); - newAttr = IntegerAttr::get(extractOp.getType(), values[elemPosition]); - } else if (vecTy.getElementType().isa()) { - auto values = to_vector(dense.getValues()); - newAttr = FloatAttr::get(extractOp.getType(), values[elemPosition]); - } - assert(newAttr && "Unhandled case"); - - rewriter.replaceOpWithNewOp(extractOp, newAttr); + // Calculate the linearized position. + int64_t elemPosition = + linearize(getI64SubArray(positions), computeStrides(vecTy.getShape())); + Attribute elementValue = *(dense.value_begin() + elemPosition); + rewriter.replaceOpWithNewOp(extractOp, elementValue); return success(); } };