diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -573,7 +573,7 @@ PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, InferTypeOpAdaptorWithIsCompatible]>, - Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>, + Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>, Results<(outs AnyType)> { let summary = "extract operation"; let description = [{ @@ -589,7 +589,6 @@ ``` }]; let builders = [ - OpBuilder<(ins "Value":$source, "ArrayRef":$position)>, // Convenience builder which assumes the values in `position` are defined by // ConstantIndexOp. OpBuilder<(ins "Value":$source, "ValueRange":$position)> @@ -689,7 +688,7 @@ PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, AllTypesMatch<["dest", "res"]>]>, - Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>, + Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>, Results<(outs AnyVectorOfAnyRank:$res)> { let summary = "insert operation"; let description = [{ @@ -711,8 +710,6 @@ }]; let builders = [ - OpBuilder<(ins "Value":$source, "Value":$dest, - "ArrayRef":$position)>, // Convenience builder which assumes all values are constant indices. OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)> ]; diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -807,8 +807,7 @@ Value el = rewriter.create(loc, loadedElType, op.getSource(), newIndices); - result = rewriter.create(loc, el, result, - rewriter.getI64ArrayAttr(i)); + result = rewriter.create(loc, el, result, i); } } else { if (auto vecType = dyn_cast(loadedElType)) { @@ -832,7 +831,7 @@ Value el = rewriter.create(op.getLoc(), loadedElType, op.getSource(), newIndices); result = rewriter.create( - op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx})); + op.getLoc(), el, result, ArrayRef{i, innerIdx}); } } } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1025,44 +1025,37 @@ auto loc = extractOp->getLoc(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter->convertType(resultType); - auto positionArrayAttr = extractOp.getPosition(); + ArrayRef positionArray = extractOp.getPosition(); // Bail if result type cannot be lowered. if (!llvmResultType) return failure(); // Extract entire vector. Should be handled by folder, but just to be safe. - if (positionArrayAttr.empty()) { + if (positionArray.empty()) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } // One-shot extraction of vector from array (only requires extractvalue). if (isa(resultType)) { - SmallVector indices; - for (auto idx : positionArrayAttr.getAsRange()) - indices.push_back(idx.getInt()); Value extracted = rewriter.create( - loc, adaptor.getVector(), indices); + loc, adaptor.getVector(), positionArray); rewriter.replaceOp(extractOp, extracted); return success(); } // Potential extraction of 1-D vector from array. Value extracted = adaptor.getVector(); - auto positionAttrs = positionArrayAttr.getValue(); - if (positionAttrs.size() > 1) { - SmallVector nMinusOnePosition; - for (auto idx : positionAttrs.drop_back()) - nMinusOnePosition.push_back(cast(idx).getInt()); - extracted = rewriter.create(loc, extracted, - nMinusOnePosition); + if (positionArray.size() > 1) { + extracted = rewriter.create( + loc, extracted, positionArray.drop_back()); } // Remaining extraction of element from 1-D LLVM vector - auto position = cast(positionAttrs.back()); auto i64Type = IntegerType::get(rewriter.getContext(), 64); - auto constant = rewriter.create(loc, i64Type, position); + auto constant = + rewriter.create(loc, i64Type, positionArray.back()); extracted = rewriter.create(loc, extracted, constant); rewriter.replaceOp(extractOp, extracted); @@ -1147,7 +1140,7 @@ auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); - auto positionArrayAttr = insertOp.getPosition(); + ArrayRef positionArray = insertOp.getPosition(); // Bail if result type cannot be lowered. if (!llvmResultType) @@ -1155,7 +1148,7 @@ // Overwrite entire vector with value. Should be handled by folder, but // just to be safe. - if (positionArrayAttr.empty()) { + if (positionArray.empty()) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } @@ -1163,36 +1156,32 @@ // One-shot insertion of a vector into an array (only requires insertvalue). if (isa(sourceType)) { Value inserted = rewriter.create( - loc, adaptor.getDest(), adaptor.getSource(), - LLVM::convertArrayToIndices(positionArrayAttr)); + loc, adaptor.getDest(), adaptor.getSource(), positionArray); rewriter.replaceOp(insertOp, inserted); return success(); } // Potential extraction of 1-D vector from array. Value extracted = adaptor.getDest(); - auto positionAttrs = positionArrayAttr.getValue(); - auto position = cast(positionAttrs.back()); auto oneDVectorType = destVectorType; - if (positionAttrs.size() > 1) { + if (positionArray.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); extracted = rewriter.create( - loc, extracted, - LLVM::convertArrayToIndices(positionAttrs.drop_back())); + loc, extracted, positionArray.drop_back()); } // Insertion of an element into a 1-D LLVM vector. auto i64Type = IntegerType::get(rewriter.getContext(), 64); - auto constant = rewriter.create(loc, i64Type, position); + auto constant = + rewriter.create(loc, i64Type, positionArray.back()); Value inserted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, adaptor.getSource(), constant); // Potential insertion of resulting 1-D vector into array. - if (positionAttrs.size() > 1) { + if (positionArray.size() > 1) { inserted = rewriter.create( - loc, adaptor.getDest(), inserted, - LLVM::convertArrayToIndices(positionAttrs.drop_back())); + loc, adaptor.getDest(), inserted, positionArray.drop_back()); } rewriter.replaceOp(insertOp, inserted); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -886,10 +886,9 @@ /// vector::InsertOp, return that operation's indices. void getInsertionIndices(TransferReadOp xferOp, SmallVector &indices) const { - if (auto insertOp = getInsertOp(xferOp)) { - for (Attribute attr : insertOp.getPosition()) - indices.push_back(dyn_cast(attr).getInt()); - } + if (auto insertOp = getInsertOp(xferOp)) + indices.assign(insertOp.getPosition().begin(), + insertOp.getPosition().end()); } /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds @@ -1013,10 +1012,9 @@ /// indices. void getExtractionIndices(TransferWriteOp xferOp, SmallVector &indices) const { - if (auto extractOp = getExtractOp(xferOp)) { - for (Attribute attr : extractOp.getPosition()) - indices.push_back(dyn_cast(attr).getInt()); - } + if (auto extractOp = getExtractOp(xferOp)) + indices.assign(extractOp.getPosition().begin(), + extractOp.getPosition().end()); } /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -152,7 +152,7 @@ return success(); } - int32_t id = getFirstIntValue(extractOp.getPosition()); + int32_t id = extractOp.getPosition()[0]; rewriter.replaceOpWithNewOp( extractOp, adaptor.getVector(), id); return success(); @@ -232,7 +232,7 @@ return success(); } - int32_t id = getFirstIntValue(insertOp.getPosition()); + int32_t id = insertOp.getPosition()[0]; rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), id); return success(); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -385,8 +385,7 @@ } else { // This means we are reducing all the dimensions, and all reduction // dimensions are of size 1. So a simple extraction would do. - auto zeroAttr = - rewriter.getI64ArrayAttr(SmallVector(shape.size(), 0)); + SmallVector zeroAttr(shape.size(), 0); if (mask) mask = rewriter.create(loc, rewriter.getI1Type(), mask, zeroAttr); @@ -560,12 +559,10 @@ result = rewriter.create(loc, reductionOp.getVector()); } else { if (mask) { - mask = rewriter.create(loc, rewriter.getI1Type(), mask, - rewriter.getI64ArrayAttr(0)); + mask = rewriter.create(loc, rewriter.getI1Type(), mask, 0); } result = rewriter.create(loc, reductionOp.getType(), - reductionOp.getVector(), - rewriter.getI64ArrayAttr(0)); + reductionOp.getVector(), 0); } if (Value acc = reductionOp.getAcc()) @@ -1129,18 +1126,11 @@ // ExtractOp //===----------------------------------------------------------------------===// -void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, - Value source, ArrayRef position) { - build(builder, result, source, getVectorSubscriptAttr(builder, position)); -} - // Convenience builder which assumes the values are constant indices. void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, Value source, ValueRange position) { - SmallVector positionConstants = - llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { - return getConstantIntValue(pos).value(); - })); + SmallVector positionConstants = llvm::to_vector(llvm::map_range( + position, [](Value pos) { return getConstantIntValue(pos).value(); })); build(builder, result, source, positionConstants); } @@ -1175,15 +1165,13 @@ } LogicalResult vector::ExtractOp::verify() { - auto positionAttr = getPosition().getValue(); - if (positionAttr.size() > - static_cast(getSourceVectorType().getRank())) + ArrayRef position = getPosition(); + if (position.size() > static_cast(getSourceVectorType().getRank())) return emitOpError( "expected position attribute of rank no greater than vector rank"); - for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = llvm::dyn_cast(en.value()); - if (!attr || attr.getInt() < 0 || - attr.getInt() >= getSourceVectorType().getDimSize(en.index())) + for (const auto &en : llvm::enumerate(position)) { + if (en.value() < 0 || + en.value() >= getSourceVectorType().getDimSize(en.index())) return emitOpError("expected position attribute #") << (en.index() + 1) << " to be a non-negative integer smaller than the corresponding " @@ -1207,18 +1195,18 @@ SmallVector globalPosition; ExtractOp currentOp = extractOp; - auto extrPos = extractVector(currentOp.getPosition()); + ArrayRef extrPos = currentOp.getPosition(); globalPosition.append(extrPos.rbegin(), extrPos.rend()); while (ExtractOp nextOp = currentOp.getVector().getDefiningOp()) { currentOp = nextOp; - auto extrPos = extractVector(currentOp.getPosition()); + ArrayRef extrPos = currentOp.getPosition(); globalPosition.append(extrPos.rbegin(), extrPos.rend()); } extractOp.setOperand(currentOp.getVector()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); std::reverse(globalPosition.begin(), globalPosition.end()); - extractOp.setPositionAttr(b.getI64ArrayAttr(globalPosition)); + extractOp.setPosition(globalPosition); return success(); } @@ -1329,7 +1317,8 @@ sentinels.reserve(vectorRank - extractedRank); for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i) sentinels.push_back(-(i + 1)); - extractPosition = extractVector(extractOp.getPosition()); + extractPosition.assign(extractOp.getPosition().begin(), + extractOp.getPosition().end()); llvm::append_range(extractPosition, sentinels); } @@ -1349,9 +1338,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos( Value &res) { - auto insertedPos = extractVector(nextInsertOp.getPosition()); - if (ArrayRef(insertedPos) != - llvm::ArrayRef(extractPosition).take_front(extractedRank)) + ArrayRef insertedPos = nextInsertOp.getPosition(); + if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank)) return failure(); // Case 2.a. early-exit fold. res = nextInsertOp.getSource(); @@ -1364,7 +1352,7 @@ /// This method updates the internal state. LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { - auto insertedPos = extractVector(nextInsertOp.getPosition()); + ArrayRef insertedPos = nextInsertOp.getPosition(); if (!isContainedWithin(insertedPos, extractPosition)) return failure(); // Set leading dims to zero. @@ -1390,9 +1378,7 @@ return Value(); // Otherwise, fold by updating the op inplace and return its result. OpBuilder b(extractOp.getContext()); - extractOp->setAttr( - extractOp.getPositionAttrName(), - b.getI64ArrayAttr(ArrayRef(extractPosition).take_front(extractedRank))); + extractOp.setPosition(ArrayRef(extractPosition).take_front(extractedRank)); extractOp.getVectorMutable().assign(source); return extractOp.getResult(); } @@ -1422,7 +1408,7 @@ // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel // values. This is a more difficult case and we bail. - auto insertedPos = extractVector(nextInsertOp.getPosition()); + ArrayRef insertedPos = nextInsertOp.getPosition(); if (isContainedWithin(extractPosition, insertedPos) || intersectsWhereNonNegative(extractPosition, insertedPos)) return Value(); @@ -1487,7 +1473,7 @@ // extract position to `0` when extracting from the source operand. llvm::SetVector broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims(); - auto extractPos = extractVector(extractOp.getPosition()); + SmallVector extractPos(extractOp.getPosition()); for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i) if (broadcastedUnitDims.contains(i)) extractPos[i] = 0; @@ -1498,7 +1484,7 @@ // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); extractOp.setOperand(source); - extractOp.setPositionAttr(b.getI64ArrayAttr(extractPos)); + extractOp.setPosition(extractPos); return extractOp.getResult(); } @@ -1537,7 +1523,7 @@ } // Extract the strides associated with the extract op vector source. Then use // this to calculate a linearized position for the extract. - auto extractedPos = extractVector(extractOp.getPosition()); + SmallVector extractedPos(extractOp.getPosition()); std::reverse(extractedPos.begin(), extractedPos.end()); SmallVector strides; int64_t stride = 1; @@ -1563,7 +1549,7 @@ SmallVector newPosition = delinearize(position, newStrides); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp.setPositionAttr(b.getI64ArrayAttr(newPosition)); + extractOp.setPosition(newPosition); extractOp.setOperand(shapeCastOp.getSource()); return extractOp.getResult(); } @@ -1603,14 +1589,14 @@ if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() - sliceOffsets.size()) return Value(); - auto extractedPos = extractVector(extractOp.getPosition()); + SmallVector extractedPos(extractOp.getPosition()); assert(extractedPos.size() >= sliceOffsets.size()); for (size_t i = 0, e = sliceOffsets.size(); i < e; i++) extractedPos[i] = extractedPos[i] + sliceOffsets[i]; extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp.setPositionAttr(b.getI64ArrayAttr(extractedPos)); + extractOp.setPosition(extractedPos); return extractOp.getResult(); } @@ -1635,7 +1621,7 @@ if (destinationRank > insertOp.getSourceVectorType().getRank()) return Value(); auto insertOffsets = extractVector(insertOp.getOffsets()); - auto extractOffsets = extractVector(extractOp.getPosition()); + ArrayRef extractOffsets = extractOp.getPosition(); if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) { return llvm::cast(attr).getInt() != 1; @@ -1675,7 +1661,7 @@ extractOp.getVectorMutable().assign(insertOp.getSource()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp.setPositionAttr(b.getI64ArrayAttr(offsetDiffs)); + extractOp.setPosition(offsetDiffs); return extractOp.getResult(); } // If the chunk extracted is disjoint from the chunk inserted, keep @@ -1795,7 +1781,7 @@ // Calculate the linearized position of the continuous chunk of elements to // extract. llvm::SmallVector completePositions(vecTy.getRank(), 0); - copy(getI64SubArray(extractOp.getPosition()), completePositions.begin()); + copy(extractOp.getPosition(), completePositions.begin()); int64_t elemBeginPosition = linearize(completePositions, computeStrides(vecTy.getShape())); auto denseValuesBegin = dense.value_begin() + elemBeginPosition; @@ -2288,14 +2274,6 @@ // InsertOp //===----------------------------------------------------------------------===// -void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, - Value dest, ArrayRef position) { - result.addOperands({source, dest}); - auto positionAttr = getVectorSubscriptAttr(builder, position); - result.addTypes(dest.getType()); - result.addAttribute(InsertOp::getPositionAttrName(result.name), positionAttr); -} - // Convenience builder which assumes the values are constant indices. void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, Value dest, ValueRange position) { @@ -2307,25 +2285,24 @@ } LogicalResult InsertOp::verify() { - auto positionAttr = getPosition().getValue(); + ArrayRef position = getPosition(); auto destVectorType = getDestVectorType(); - if (positionAttr.size() > static_cast(destVectorType.getRank())) + if (position.size() > static_cast(destVectorType.getRank())) return emitOpError( "expected position attribute of rank no greater than dest vector rank"); auto srcVectorType = llvm::dyn_cast(getSourceType()); if (srcVectorType && - (static_cast(srcVectorType.getRank()) + positionAttr.size() != + (static_cast(srcVectorType.getRank()) + position.size() != static_cast(destVectorType.getRank()))) return emitOpError("expected position attribute rank + source rank to " "match dest vector rank"); if (!srcVectorType && - (positionAttr.size() != static_cast(destVectorType.getRank()))) + (position.size() != static_cast(destVectorType.getRank()))) return emitOpError( "expected position attribute rank to match the dest vector rank"); - for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = llvm::dyn_cast(en.value()); - if (!attr || attr.getInt() < 0 || - attr.getInt() >= destVectorType.getDimSize(en.index())) + for (const auto &en : llvm::enumerate(position)) { + int64_t attr = en.value(); + if (attr < 0 || attr >= destVectorType.getDimSize(en.index())) return emitOpError("expected position attribute #") << (en.index() + 1) << " to be a non-negative integer smaller than the corresponding " @@ -2412,7 +2389,7 @@ // Calculate the linearized position of the continuous chunk of elements to // insert. llvm::SmallVector completePositions(destTy.getRank(), 0); - copy(getI64SubArray(op.getPosition()), completePositions.begin()); + copy(op.getPosition(), completePositions.begin()); int64_t insertBeginPosition = linearize(completePositions, computeStrides(destTy.getShape())); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -91,10 +91,8 @@ return val; Type lowType = VectorType::Builder(type).dropDim(0); // At extraction dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, lowType, val, posAttr); - } + if (index == 0) + return rewriter.create(loc, lowType, val, pos); // Unroll leading dimensions. VectorType vType = cast(lowType); Type resType = VectorType::Builder(type).dropDim(index); @@ -102,11 +100,10 @@ Value result = rewriter.create( loc, resVectorType, rewriter.getZeroAttr(resVectorType)); for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = rewriter.create(loc, vType, val, posAttr); + Value ext = rewriter.create(loc, vType, val, d); Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, resVectorType, load, result, - posAttr); + result = + rewriter.create(loc, resVectorType, load, result, d); } return result; } @@ -120,20 +117,17 @@ if (index == -1) return val; // At insertion dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, type, val, result, posAttr); - } + if (index == 0) + return rewriter.create(loc, type, val, result, pos); // Unroll leading dimensions. Type lowType = VectorType::Builder(type).dropDim(0); VectorType vType = cast(lowType); Type insType = VectorType::Builder(vType).dropDim(0); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = rewriter.create(loc, vType, result, posAttr); - Value ins = rewriter.create(loc, insType, val, posAttr); + Value ext = rewriter.create(loc, vType, result, d); + Value ins = rewriter.create(loc, insType, val, d); Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, type, sto, result, posAttr); + result = rewriter.create(loc, type, sto, result, d); } return result; } @@ -823,10 +817,8 @@ newRhs = rewriter.create(loc, newRhs, rhsTranspose); SmallVector lhsOffsets(lhsReductionDims.size(), 0); SmallVector rhsOffsets(rhsReductionDims.size(), 0); - newLhs = rewriter.create( - loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets)); - newRhs = rewriter.create( - loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets)); + newLhs = rewriter.create(loc, newLhs, lhsOffsets); + newRhs = rewriter.create(loc, newRhs, rhsOffsets); std::optional result = createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), contractOp.getKind(), rewriter, isInt); @@ -1167,21 +1159,20 @@ Value result = rewriter.create( loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { - auto pos = rewriter.getI64ArrayAttr(d); - Value x = rewriter.create(loc, op.getLhs(), pos); + Value x = rewriter.create(loc, op.getLhs(), d); Value a = rewriter.create(loc, rhsType, x); Value r = nullptr; if (acc) - r = rewriter.create(loc, acc, pos); + r = rewriter.create(loc, acc, d); Value extrMask; if (mask) - extrMask = rewriter.create(loc, mask, pos); + extrMask = rewriter.create(loc, mask, d); std::optional m = createContractArithOp( loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); if (!m.has_value()) return failure(); - result = rewriter.create(loc, resType, *m, result, pos); + result = rewriter.create(loc, resType, *m, result, d); } rewriter.replaceOp(rootOp, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -77,9 +77,7 @@ Value val = rewriter.create(loc, arith::CmpIPredicate::slt, bnd, idx); Value sel = rewriter.create(loc, val, trueVal, falseVal); - auto pos = rewriter.getI64ArrayAttr(d); - result = - rewriter.create(loc, dstType, sel, result, pos); + result = rewriter.create(loc, dstType, sel, result, d); } rewriter.replaceOp(op, result); return success(); @@ -151,11 +149,9 @@ loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); Value result = rewriter.create( loc, dstType, rewriter.getZeroAttr(dstType)); - for (int64_t d = 0; d < trueDim; d++) { - auto pos = rewriter.getI64ArrayAttr(d); + for (int64_t d = 0; d < trueDim; d++) result = - rewriter.create(loc, dstType, trueVal, result, pos); - } + rewriter.create(loc, dstType, trueVal, result, d); rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -944,7 +944,7 @@ // Rewrite vector.extract with 1d source to vector.extractelement. if (extractSrcType.getRank() == 1) { assert(extractOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = cast(extractOp.getPosition()[0]).getInt(); + int64_t pos = extractOp.getPosition()[0]; rewriter.setInsertionPoint(extractOp); rewriter.replaceOpWithNewOp( extractOp, extractOp.getVector(), @@ -1201,7 +1201,7 @@ // Rewrite vector.insert with 1d dest to vector.insertelement. if (insertOp.getDestVectorType().getRank() == 1) { assert(insertOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = cast(insertOp.getPosition()[0]).getInt(); + int64_t pos = insertOp.getPosition()[0]; rewriter.setInsertionPoint(insertOp); rewriter.replaceOpWithNewOp( insertOp, insertOp.getSource(), insertOp.getDest(), @@ -1276,10 +1276,7 @@ } else { // One lane inserts the entire source vector. int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); - SmallVector newPos = llvm::to_vector( - llvm::map_range(insertOp.getPosition(), [](Attribute attr) { - return cast(attr).getInt(); - })); + SmallVector newPos(insertOp.getPosition()); // tid of inserting lane: pos / elementsPerLane Value insertingLane = rewriter.create( loc, newPos[distrDestDim] / elementsPerLane); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -165,16 +165,14 @@ // type has leading unit dims, we also trim the position array accordingly, // then (2) if source type also has leading unit dims, we need to append // zeroes to the position array accordingly. - unsigned oldPosRank = insertOp.getPosition().getValue().size(); + unsigned oldPosRank = insertOp.getPosition().size(); unsigned newPosRank = std::max(0, oldPosRank - dstDropCount); - SmallVector newPositions = llvm::to_vector( - insertOp.getPosition().getValue().take_back(newPosRank)); - newPositions.resize(newDstType.getRank() - newSrcRank, - rewriter.getI64IntegerAttr(0)); + SmallVector newPositions = + llvm::to_vector(insertOp.getPosition().take_back(newPosRank)); + newPositions.resize(newDstType.getRank() - newSrcRank, 0); auto newInsertOp = rewriter.create( - loc, newDstType, newSrcVector, newDstVector, - rewriter.getArrayAttr(newPositions)); + loc, newDstType, newSrcVector, newDstVector, newPositions); rewriter.replaceOpWithNewOp(insertOp, oldDstType, newInsertOp); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -704,7 +704,7 @@ SmallVector newIndices(xferOp.getIndices().begin(), xferOp.getIndices().end()); for (const auto &it : llvm::enumerate(extractOp.getPosition())) { - int64_t offset = cast(it.value()).getInt(); + int64_t offset = it.value(); int64_t idx = newIndices.size() - extractOp.getPosition().size() + it.index(); OpFoldResult ofr = affine::makeComposedFoldedAffineApply( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -598,11 +598,7 @@ unsigned expandRatio = castDstType.getNumElements() / castSrcType.getNumElements(); - auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t { - return (*attr.getAsValueRange().begin()).getZExtValue(); - }; - - uint64_t index = getFirstIntValue(extractOp.getPosition()); + uint64_t index = extractOp.getPosition()[0]; // Get the single scalar (as a vector) in the source value that packs the // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> @@ -610,7 +606,7 @@ VectorType::get({1}, castSrcType.getElementType()); Value packedValue = rewriter.create( extractOp.getLoc(), oneScalarType, castOp.getSource(), - rewriter.getI64ArrayAttr(index / expandRatio)); + index / expandRatio); // Cast it to a vector with the desired scalar's type. // E.g. f32 -> vector<2xf16> @@ -621,8 +617,7 @@ // Finally extract the desired scalar. rewriter.replaceOpWithNewOp( - extractOp, extractOp.getType(), castedValue, - rewriter.getI64ArrayAttr(index % expandRatio)); + extractOp, extractOp.getType(), castedValue, index % expandRatio); return success(); } diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -155,8 +155,8 @@ // CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32> // CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32> func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) { - %0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32> - %1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32 + %0 = "vector.extract"(%arg0) <{position = array}> : (vector<2xf32>) -> vector<1xf32> + %1 = "vector.extract"(%arg0) <{position = array}> : (vector<2xf32>) -> f32 return %0, %1: vector<1xf32>, f32 } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -133,7 +133,7 @@ func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank no greater than vector rank}} - %1 = "vector.extract" (%arg0) { position = [0, 0, 0, 0] } : (vector<4x8x16xf32>) -> (vector<16xf32>) + %1 = "vector.extract" (%arg0) <{position = array}> : (vector<4x8x16xf32>) -> (vector<16xf32>) } // -----