# Review notes. Patch tools (git apply, GNU patch) skip text that precedes
# the first "diff --git" header, so these notes do not change the patch.
#
# What the patch does: turns the `position` of vector.extract/vector.insert
# from an I64ArrayAttr into variadic SSA operands printed as `[...]`, adds
# int64_t / ArrayRef / Value convenience builders (the integer forms create
# constant index ops with the builder), and ports the Vector folders,
# verifiers and the VectorToGPU/LLVM/SCF/SPIRV conversions to the new form.
#
# This copy is damaged: text inside angle brackets was stripped (e.g.
# "SmallVector getAsIntegers(ArrayRef values)", "getDefiningOp()",
# "Variadic:$position") and most line breaks were collapsed, so it will not
# apply or compile as stored. Regenerate it from the source branch instead of
# hand-editing it.
#
# Issues seen in the hunks (confirm against the full files):
#  1. Unit-dim multi-reduction elision in VectorOps.cpp ("all reduction
#     dimensions are of size 1"): the old code extracted at shape.size()
#     zeros, but the new calls pass a single 0. For rank > 1 that extracts
#     only the outermost dim and yields a vector, not the reduced scalar.
#     Pass SmallVector<int64_t>(shape.size(), 0) instead.
#  2. In-place fold after handleInsertOpWithPrefixPos ("fold by updating the
#     op inplace"): the unchanged `OpBuilder b(extractOp.getContext())` has no
#     insertion point, but it is now used to create constant index ops. Those
#     ops are never inserted into a block, so the rewritten vector.extract gets
#     detached operands. Use `OpBuilder b(extractOp)` as the other folds do.
#  3. The folds through broadcast / shape_cast / extract_strided_slice /
#     insert_strided_slice now create constant index ops. If these helpers run
#     from ExtractOp::fold (not shown here), creating new ops breaks the fold
#     contract. Move them to canonicalization patterns or avoid new constants.
#  4. ConvertVectorToLLVM ExtractOp lowering: `position` comes from
#     extractOp.getPosition(), i.e. the original operands, and position.back()
#     is passed straight to the 1-D element extraction. That is the
#     unconverted `index` value, not the adaptor's converted operand. Use
#     adaptor.getPosition() instead.
#  5. ConvertVectorToLLVM InsertOp lowering: the removed line
#     `auto position = cast(positionAttrs.back())` fed the unchanged lines that
#     build `constant` (they sit outside the hunks). If those lines still read
#     `position`, they now see the ArrayRef of Values and will not compile.
#  6. The verifiers only range-check positions that are defined by a constant
#     index op, so non-constant positions pass verification. However:
#       - getAsIntegers and convertConstantsToInts assert on such positions.
#       - The SPIRV getFirstIntValue(ValueRange) calls .value() on
#         getDefiningOp() with no null check.
#     These helpers back the folds and lowerings above. Either reject
#     non-constant positions in the verifiers, or bail out before calling
#     these helpers.
#  7. LowerVectorContract reshapeLoad: `Type lowType = ...` is deleted. If the
#     "Unroll leading dimensions" code below the visible hunk still uses it,
#     the file will not compile.
#  8. The verifier diagnostics still say "position attribute", although
#     position is now an operand list.
#  9. VectorToSPIRV: getFirstIntValue(ValueRange) was inserted between the
#     existing "from `attr` ... integer array attribute" doc comment and the
#     ArrayAttr overload. That comment now documents the wrong function, and
#     the new overload has no comment of its own.
# 10. Duplicate and unused helpers:
#       - convertConstantsToInts (ConvertVectorToLLVM.cpp) duplicates
#         vector::getAsIntegers, which the same file's extract lowering
#         already uses.
#       - vector::getAsConstantIndexOps has no visible use in this patch.
#       - The arith::ConstantOp forward declaration in LLVMDialect.h has no
#         visible use either.
#     Drop them if nothing else needs them, and re-check whether VectorOps.h
#     still needs the Arith.h include.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -47,6 +47,9 @@ } // namespace llvm namespace mlir { +namespace arith { +class ConstantOp; +} namespace LLVM { class LLVMDialect; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" #include "mlir/IR/AffineMap.h" @@ -134,6 +135,15 @@ return cast(attr).getValue() == IteratorType::reduction; } +/// Returns the integer numbers in `values`. `values` are expected to be +/// constant operations. +SmallVector getAsIntegers(ArrayRef values); + +/// Returns the constant index ops in `values`. `values` are expected to be +/// constant operations. 
+SmallVector +getAsConstantIndexOps(ArrayRef values); + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -577,7 +577,7 @@ PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, DeclareOpInterfaceMethods]>, - Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>, + Arguments<(ins AnyVectorOfAnyRank:$vector, Variadic:$position)>, Results<(outs AnyType)> { let summary = "extract operation"; let description = [{ @@ -593,19 +593,20 @@ ``` }]; let builders = [ + // Convenience builder for a single constant index. + OpBuilder<(ins "Value":$source, "int64_t":$position)>, + // Convenience builder for constant indices. OpBuilder<(ins "Value":$source, "ArrayRef":$position)>, - // Convenience builder which assumes the values in `position` are defined by - // ConstantIndexOp. - OpBuilder<(ins "Value":$source, "ValueRange":$position)> + // Convenience builder for single value index. 
+ OpBuilder<(ins "Value":$source, "Value":$position)>, ]; let extraClassDeclaration = [{ - static StringRef getPositionAttrStrName() { return "position"; } VectorType getSourceVectorType() { return ::llvm::cast(getVector().getType()); } static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); }]; - let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)"; + let assemblyFormat = "$vector `[` $position `]` attr-dict `:` type($vector)"; let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; @@ -695,7 +696,8 @@ PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, AllTypesMatch<["dest", "res"]>]>, - Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>, + Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, + Variadic:$position)>, Results<(outs AnyVectorOfAnyRank:$res)> { let summary = "insert operation"; let description = [{ @@ -706,24 +708,26 @@ Example: ```mlir - %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32> - %5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32> + %2 = vector.insert %0, %1[%c3] : vector<8x16xf32> into vector<4x8x16xf32> + %5 = vector.insert %3, %4[%c3, %c3, %c3] : f32 into vector<4x8x16xf32> %8 = vector.insert %6, %7[] : f32 into vector - %11 = vector.insert %9, %10[3, 3, 3] : vector into vector<4x8x16xf32> + %11 = vector.insert %9, %10[%c3, %c3, %c3] : vector into vector<4x8x16xf32> ``` }]; let assemblyFormat = [{ - $source `,` $dest $position attr-dict `:` type($source) `into` type($dest) + $source `,` $dest `[` $position `]` attr-dict `:` type($source) `into` type($dest) }]; let builders = [ + // Convenience builder for a single constant index. + OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>, + // Convenience builder for constant indices. OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$position)>, - // Convenience builder which assumes all values are constant indices. 
- OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)> + // Convenience builder for single value index. + OpBuilder<(ins "Value":$source, "Value":$dest, "Value":$position)>, ]; let extraClassDeclaration = [{ - static StringRef getPositionAttrStrName() { return "position"; } Type getSourceType() { return getSource().getType(); } VectorType getDestVectorType() { return ::llvm::cast(getDest().getType()); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -807,8 +807,7 @@ Value el = rewriter.create(loc, loadedElType, op.getSource(), newIndices); - result = rewriter.create(loc, el, result, - rewriter.getI64ArrayAttr(i)); + result = rewriter.create(loc, el, result, i); } } else { if (auto vecType = dyn_cast(loadedElType)) { @@ -832,7 +831,7 @@ Value el = rewriter.create(op.getLoc(), loadedElType, op.getSource(), newIndices); result = rewriter.create( - op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx})); + op.getLoc(), el, result, ArrayRef({i, innerIdx})); } } } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -130,6 +130,20 @@ return rewriter.create(loc, pType, ptr); } +/// Convert an array of arith constant ops to a vector of integers that can be +/// used as indices in LLVM operations. 
+template +static SmallVector convertConstantsToInts(ArrayRef values) { + SmallVector indices; + indices.reserve(values.size()); + llvm::transform(values, std::back_inserter(indices), [](Value value) { + auto constantOp = value.getDefiningOp(); + assert(constantOp && "Unexpected non-constant index"); + return cast(constantOp.getValue()).getInt(); + }); + return indices; +} + namespace { /// Trivial Vector to LLVM conversions @@ -1027,46 +1041,39 @@ auto loc = extractOp->getLoc(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter->convertType(resultType); - auto positionArrayAttr = extractOp.getPosition(); + SmallVector positionVec = extractOp.getPosition(); + ArrayRef position(positionVec); // Bail if result type cannot be lowered. if (!llvmResultType) return failure(); // Extract entire vector. Should be handled by folder, but just to be safe. - if (positionArrayAttr.empty()) { + if (position.empty()) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } // One-shot extraction of vector from array (only requires extractvalue). if (isa(resultType)) { - SmallVector indices; - for (auto idx : positionArrayAttr.getAsRange()) - indices.push_back(idx.getInt()); Value extracted = rewriter.create( - loc, adaptor.getVector(), indices); + loc, adaptor.getVector(), getAsIntegers(position)); rewriter.replaceOp(extractOp, extracted); return success(); } // Potential extraction of 1-D vector from array. 
Value extracted = adaptor.getVector(); - auto positionAttrs = positionArrayAttr.getValue(); - if (positionAttrs.size() > 1) { - SmallVector nMinusOnePosition; - for (auto idx : positionAttrs.drop_back()) - nMinusOnePosition.push_back(cast(idx).getInt()); + if (position.size() > 1) { + SmallVector nMinusOnePosition = + getAsIntegers(position.drop_back()); extracted = rewriter.create(loc, extracted, nMinusOnePosition); } // Remaining extraction of element from 1-D LLVM vector - auto position = cast(positionAttrs.back()); - auto i64Type = IntegerType::get(rewriter.getContext(), 64); - auto constant = rewriter.create(loc, i64Type, position); - extracted = - rewriter.create(loc, extracted, constant); + extracted = rewriter.create(loc, extracted, + position.back()); rewriter.replaceOp(extractOp, extracted); return success(); @@ -1149,7 +1156,8 @@ auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); - auto positionArrayAttr = insertOp.getPosition(); + SmallVector positionVec = insertOp.getPosition(); + ArrayRef position(positionVec); // Bail if result type cannot be lowered. if (!llvmResultType) @@ -1157,7 +1165,7 @@ // Overwrite entire vector with value. Should be handled by folder, but // just to be safe. - if (positionArrayAttr.empty()) { + if (position.empty()) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } @@ -1166,21 +1174,18 @@ if (isa(sourceType)) { Value inserted = rewriter.create( loc, adaptor.getDest(), adaptor.getSource(), - LLVM::convertArrayToIndices(positionArrayAttr)); + convertConstantsToInts(position)); rewriter.replaceOp(insertOp, inserted); return success(); } // Potential extraction of 1-D vector from array. 
Value extracted = adaptor.getDest(); - auto positionAttrs = positionArrayAttr.getValue(); - auto position = cast(positionAttrs.back()); auto oneDVectorType = destVectorType; - if (positionAttrs.size() > 1) { + if (position.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); extracted = rewriter.create( - loc, extracted, - LLVM::convertArrayToIndices(positionAttrs.drop_back())); + loc, extracted, convertConstantsToInts(position.drop_back())); } // Insertion of an element into a 1-D LLVM vector. @@ -1191,10 +1196,10 @@ adaptor.getSource(), constant); // Potential insertion of resulting 1-D vector into array. - if (positionAttrs.size() > 1) { + if (position.size() > 1) { inserted = rewriter.create( loc, adaptor.getDest(), inserted, - LLVM::convertArrayToIndices(positionAttrs.drop_back())); + convertConstantsToInts(position.drop_back())); } rewriter.replaceOp(insertOp, inserted); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -885,10 +885,10 @@ /// If the result of the TransferReadOp has exactly one user, which is a /// vector::InsertOp, return that operation's indices. void getInsertionIndices(TransferReadOp xferOp, - SmallVector &indices) const { + SmallVectorImpl &indices) const { if (auto insertOp = getInsertOp(xferOp)) { - for (Attribute attr : insertOp.getPosition()) - indices.push_back(dyn_cast(attr).getInt()); + auto pos = insertOp.getPosition(); + indices.append(pos.begin(), pos.end()); } } @@ -927,9 +927,10 @@ getXferIndices(b, xferOp, iv, xferIndices); // Indices for the new vector.insert op. 
- SmallVector insertionIndices; + SmallVector insertionIndices; getInsertionIndices(xferOp, insertionIndices); - insertionIndices.push_back(i); + insertionIndices.push_back( + b.create(loc, i)); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto newXferOp = b.create( @@ -1012,10 +1013,10 @@ /// If the input of the given TransferWriteOp is an ExtractOp, return its /// indices. void getExtractionIndices(TransferWriteOp xferOp, - SmallVector &indices) const { + SmallVectorImpl &indices) const { if (auto extractOp = getExtractOp(xferOp)) { - for (Attribute attr : extractOp.getPosition()) - indices.push_back(dyn_cast(attr).getInt()); + auto pos = extractOp.getPosition(); + indices.append(pos.begin(), pos.end()); } } @@ -1053,9 +1054,10 @@ getXferIndices(b, xferOp, iv, xferIndices); // Indices for the new vector.extract op. - SmallVector extractionIndices; + SmallVector extractionIndices; getExtractionIndices(xferOp, extractionIndices); - extractionIndices.push_back(i); + extractionIndices.push_back( + b.create(loc, i)); auto extracted = b.create(loc, vec, extractionIndices); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -34,6 +34,9 @@ /// Gets the first integer value from `attr`, assuming it is an integer array /// attribute. +static uint64_t getFirstIntValue(ValueRange values) { + return values[0].getDefiningOp().value(); +} static uint64_t getFirstIntValue(ArrayAttr attr) { return (*attr.getAsValueRange().begin()).getZExtValue(); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -233,6 +233,31 @@ return failure(); } +/// Returns the integer numbers in `values`. `values` are expected to be +/// constant operations. 
+SmallVector vector::getAsIntegers(ArrayRef values) { + SmallVector ints; + llvm::transform(values, std::back_inserter(ints), [](Value value) { + auto constOp = value.getDefiningOp(); + assert(constOp && "Unexpected non-constant index"); + return constOp.value(); + }); + return ints; +} + +/// Returns the constant ops in `values`. `values` are expected to be constant +/// operations. +SmallVector +vector::getAsConstantIndexOps(ArrayRef values) { + SmallVector constIdxs; + llvm::transform(values, std::back_inserter(constIdxs), [](Value value) { + auto constOp = value.getDefiningOp(); + assert(constOp && "Unexpected non-constant index"); + return constOp; + }); + return constIdxs; +} + //===----------------------------------------------------------------------===// // CombiningKindAttr //===----------------------------------------------------------------------===// @@ -394,13 +419,10 @@ } else { // This means we are reducing all the dimensions, and all reduction // dimensions are of size 1. So a simple extraction would do. 
- auto zeroAttr = - rewriter.getI64ArrayAttr(SmallVector(shape.size(), 0)); if (mask) - mask = rewriter.create(loc, rewriter.getI1Type(), - mask, zeroAttr); - cast = rewriter.create( - loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr); + mask = rewriter.create(loc, mask, 0); + cast = + rewriter.create(loc, reductionOp.getSource(), 0); } Value result = vector::makeArithReduction( @@ -568,13 +590,9 @@ mask = rewriter.create(loc, mask); result = rewriter.create(loc, reductionOp.getVector()); } else { - if (mask) { - mask = rewriter.create(loc, rewriter.getI1Type(), mask, - rewriter.getI64ArrayAttr(0)); - } - result = rewriter.create(loc, reductionOp.getType(), - reductionOp.getVector(), - rewriter.getI64ArrayAttr(0)); + if (mask) + mask = rewriter.create(loc, mask, 0); + result = rewriter.create(loc, reductionOp.getVector(), 0); } if (Value acc = reductionOp.getAcc()) @@ -1138,19 +1156,24 @@ // ExtractOp //===----------------------------------------------------------------------===// +void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, + Value source, int64_t position) { + build(builder, result, source, ArrayRef(position)); +} + void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, Value source, ArrayRef position) { - build(builder, result, source, getVectorSubscriptAttr(builder, position)); + SmallVector posVals; + posVals.reserve(position.size()); + llvm::transform(position, std::back_inserter(posVals), [&](int64_t pos) { + return builder.create(result.location, pos); + }); + build(builder, result, source, posVals); } -// Convenience builder which assumes the values are constant indices. 
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, - Value source, ValueRange position) { - SmallVector positionConstants = - llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { - return getConstantIntValue(pos).value(); - })); - build(builder, result, source, positionConstants); + Value source, Value position) { + build(builder, result, source, ValueRange(position)); } LogicalResult @@ -1183,20 +1206,23 @@ return l == r; } -LogicalResult vector::ExtractOp::verify() { - auto positionAttr = getPosition().getValue(); - if (positionAttr.size() > - static_cast(getSourceVectorType().getRank())) +LogicalResult ExtractOp::verify() { + auto position = getPosition(); + if (position.size() > static_cast(getSourceVectorType().getRank())) return emitOpError( "expected position attribute of rank no greater than vector rank"); - for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = llvm::dyn_cast(en.value()); - if (!attr || attr.getInt() < 0 || - attr.getInt() >= getSourceVectorType().getDimSize(en.index())) - return emitOpError("expected position attribute #") - << (en.index() + 1) - << " to be a non-negative integer smaller than the corresponding " - "vector dimension"; + for (auto [idx, pos] : llvm::enumerate(position)) { + auto constOp = pos.getDefiningOp(); + if (constOp) { + if (constOp.value() < 0 || + constOp.value() >= getSourceVectorType().getDimSize(idx)) { + return emitOpError("expected position attribute #") + << (idx + 1) + << " to be a non-negative integer smaller than the " + "corresponding " + "vector dimension"; + } + } } return success(); } @@ -1214,21 +1240,19 @@ if (!extractOp.getVector().getDefiningOp()) return failure(); - SmallVector globalPosition; + SmallVector globalPosition; ExtractOp currentOp = extractOp; - auto extrPos = extractVector(currentOp.getPosition()); + SmallVector extrPos = currentOp.getPosition(); globalPosition.append(extrPos.rbegin(), extrPos.rend()); while (ExtractOp nextOp = 
currentOp.getVector().getDefiningOp()) { currentOp = nextOp; - auto extrPos = extractVector(currentOp.getPosition()); + extrPos = currentOp.getPosition(); globalPosition.append(extrPos.rbegin(), extrPos.rend()); } - extractOp.setOperand(currentOp.getVector()); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - OpBuilder b(extractOp.getContext()); - std::reverse(globalPosition.begin(), globalPosition.end()); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(globalPosition)); + SmallVector newOperands; + newOperands.push_back(currentOp.getVector()); + newOperands.append(globalPosition.rbegin(), globalPosition.rend()); + extractOp->setOperands(newOperands); return success(); } @@ -1339,7 +1363,8 @@ sentinels.reserve(vectorRank - extractedRank); for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i) sentinels.push_back(-(i + 1)); - extractPosition = extractVector(extractOp.getPosition()); + SmallVector pos = extractOp.getPosition(); + extractPosition = getAsIntegers(pos); llvm::append_range(extractPosition, sentinels); } @@ -1359,9 +1384,10 @@ LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos( Value &res) { - auto insertedPos = extractVector(nextInsertOp.getPosition()); + SmallVector pos = nextInsertOp.getPosition(); + auto insertedPos = getAsIntegers(pos); if (ArrayRef(insertedPos) != - llvm::ArrayRef(extractPosition).take_front(extractedRank)) + ArrayRef(extractPosition).take_front(extractedRank)) return failure(); // Case 2.a. early-exit fold. res = nextInsertOp.getSource(); @@ -1374,7 +1400,8 @@ /// This method updates the internal state. LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { - auto insertedPos = extractVector(nextInsertOp.getPosition()); + SmallVector pos = nextInsertOp.getPosition(); + auto insertedPos = getAsIntegers(pos); if (!isContainedWithin(insertedPos, extractPosition)) return failure(); // Set leading dims to zero. 
@@ -1400,10 +1427,12 @@ return Value(); // Otherwise, fold by updating the op inplace and return its result. OpBuilder b(extractOp.getContext()); - extractOp->setAttr( - extractOp.getPositionAttrName(), - b.getI64ArrayAttr(ArrayRef(extractPosition).take_front(extractedRank))); - extractOp.getVectorMutable().assign(source); + SmallVector newOperands; + newOperands.push_back(source); + for (int64_t index : ArrayRef(extractPosition).take_front(extractedRank)) + newOperands.push_back( + b.create(extractOp.getLoc(), index)); + extractOp->setOperands(newOperands); return extractOp.getResult(); } @@ -1432,7 +1461,8 @@ // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel // values. This is a more difficult case and we bail. - auto insertedPos = extractVector(nextInsertOp.getPosition()); + SmallVector pos = nextInsertOp.getPosition(); + auto insertedPos = getAsIntegers(pos); if (isContainedWithin(extractPosition, insertedPos) || intersectsWhereNonNegative(extractPosition, insertedPos)) return Value(); @@ -1497,19 +1527,19 @@ // extract position to `0` when extracting from the source operand. llvm::SetVector broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims(); - auto extractPos = extractVector(extractOp.getPosition()); + SmallVector extractPos = extractOp.getPosition(); + OpBuilder b(extractOp); for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i) if (broadcastedUnitDims.contains(i)) - extractPos[i] = 0; + extractPos[i] = b.create(extractOp.getLoc(), 0); // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the // matching extract position when extracting from the source operand. - extractPos.erase(extractPos.begin(), - std::next(extractPos.begin(), extractPos.size() - rankDiff)); - // OpBuilder is only used as a helper to build an I64ArrayAttr. 
- OpBuilder b(extractOp.getContext()); - extractOp.setOperand(source); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(extractPos)); + SmallVector newOperands; + newOperands.push_back(source); + newOperands.append( + std::next(extractPos.begin(), extractPos.size() - rankDiff), + extractPos.end()); + extractOp->setOperands(newOperands); return extractOp.getResult(); } @@ -1548,7 +1578,8 @@ } // Extract the strides associated with the extract op vector source. Then use // this to calculate a linearized position for the extract. - auto extractedPos = extractVector(extractOp.getPosition()); + SmallVector pos = extractOp.getPosition(); + auto extractedPos = getAsIntegers(pos); std::reverse(extractedPos.begin(), extractedPos.end()); SmallVector strides; int64_t stride = 1; @@ -1572,11 +1603,14 @@ } std::reverse(newStrides.begin(), newStrides.end()); SmallVector newPosition = delinearize(position, newStrides); - // OpBuilder is only used as a helper to build an I64ArrayAttr. 
- OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(newPosition)); - extractOp.setOperand(shapeCastOp.getSource()); + OpBuilder b(extractOp); + SmallVector newOperands; + newOperands.push_back(shapeCastOp.getSource()); + for (int64_t pos : newPosition) { + newOperands.push_back( + b.create(extractOp.getLoc(), pos)); + } + extractOp->setOperands(newOperands); return extractOp.getResult(); } @@ -1615,15 +1649,20 @@ if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() - sliceOffsets.size()) return Value(); - auto extractedPos = extractVector(extractOp.getPosition()); + SmallVector valuePos = extractOp.getPosition(); + auto extractedPos = getAsIntegers(valuePos); assert(extractedPos.size() >= sliceOffsets.size()); for (size_t i = 0, e = sliceOffsets.size(); i < e; i++) extractedPos[i] = extractedPos[i] + sliceOffsets[i]; extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector()); - // OpBuilder is only used as a helper to build an I64ArrayAttr. 
- OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(extractedPos)); + OpBuilder b(extractOp); + SmallVector newOperands; + newOperands.push_back(extractOp.getVector()); + for (int64_t pos : extractedPos) { + newOperands.push_back( + b.create(extractOp.getLoc(), pos)); + } + extractOp->setOperands(newOperands); return extractOp.getResult(); } @@ -1648,7 +1687,8 @@ if (destinationRank > insertOp.getSourceVectorType().getRank()) return Value(); auto insertOffsets = extractVector(insertOp.getOffsets()); - auto extractOffsets = extractVector(extractOp.getPosition()); + SmallVector pos = extractOp.getPosition(); + auto extractOffsets = getAsIntegers(pos); if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) { return llvm::cast(attr).getInt() != 1; @@ -1685,11 +1725,13 @@ insertRankDiff)) return Value(); } - extractOp.getVectorMutable().assign(insertOp.getSource()); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(offsetDiffs)); + SmallVector newOperands; + newOperands.push_back(insertOp.getSource()); + OpBuilder b(extractOp); + for (int64_t offset : offsetDiffs) + newOperands.push_back( + b.create(extractOp.getLoc(), offset)); + extractOp->setOperands(newOperands); return extractOp.getResult(); } // If the chunk extracted is disjoint from the chunk inserted, keep @@ -1808,8 +1850,9 @@ // Calculate the linearized position of the continuous chunk of elements to // extract. 
- llvm::SmallVector completePositions(vecTy.getRank(), 0); - copy(getI64SubArray(extractOp.getPosition()), completePositions.begin()); + SmallVector completePositions(vecTy.getRank(), 0); + SmallVector pos = extractOp.getPosition(); + llvm::copy(getAsIntegers(pos), completePositions.begin()); int64_t elemBeginPosition = linearize(completePositions, computeStrides(vecTy.getShape())); auto denseValuesBegin = dense.value_begin() + elemBeginPosition; @@ -2304,48 +2347,54 @@ // InsertOp //===----------------------------------------------------------------------===// +void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, + Value dest, int64_t position) { + build(builder, result, source, dest, ArrayRef(position)); +} + void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, Value dest, ArrayRef position) { - result.addOperands({source, dest}); - auto positionAttr = getVectorSubscriptAttr(builder, position); - result.addTypes(dest.getType()); - result.addAttribute(getPositionAttrStrName(), positionAttr); + SmallVector posVals; + posVals.reserve(position.size()); + llvm::transform(position, std::back_inserter(posVals), [&](int64_t pos) { + return builder.create(result.location, pos); + }); + build(builder, result, source, dest, posVals); } -// Convenience builder which assumes the values are constant indices. 
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, - Value dest, ValueRange position) { - SmallVector positionConstants = - llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { - return getConstantIntValue(pos).value(); - })); - build(builder, result, source, dest, positionConstants); + Value dest, Value position) { + build(builder, result, source, dest, ValueRange(position)); } LogicalResult InsertOp::verify() { - auto positionAttr = getPosition().getValue(); + SmallVector position = getPosition(); auto destVectorType = getDestVectorType(); - if (positionAttr.size() > static_cast(destVectorType.getRank())) + if (position.size() > static_cast(destVectorType.getRank())) return emitOpError( "expected position attribute of rank no greater than dest vector rank"); auto srcVectorType = llvm::dyn_cast(getSourceType()); if (srcVectorType && - (static_cast(srcVectorType.getRank()) + positionAttr.size() != + (static_cast(srcVectorType.getRank()) + position.size() != static_cast(destVectorType.getRank()))) return emitOpError("expected position attribute rank + source rank to " "match dest vector rank"); if (!srcVectorType && - (positionAttr.size() != static_cast(destVectorType.getRank()))) + (position.size() != static_cast(destVectorType.getRank()))) return emitOpError( "expected position attribute rank to match the dest vector rank"); - for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = llvm::dyn_cast(en.value()); - if (!attr || attr.getInt() < 0 || - attr.getInt() >= destVectorType.getDimSize(en.index())) - return emitOpError("expected position attribute #") - << (en.index() + 1) - << " to be a non-negative integer smaller than the corresponding " - "dest vector dimension"; + for (auto [idx, pos] : llvm::enumerate(position)) { + auto constOp = pos.getDefiningOp(); + if (constOp) { + if (constOp.value() < 0 || + constOp.value() >= destVectorType.getDimSize(idx)) { + return emitOpError("expected position attribute #") + 
<< (idx + 1) + << " to be a non-negative integer smaller than the " + "corresponding " + "dest vector dimension"; + } + } } return success(); } @@ -2427,8 +2476,9 @@ // Calculate the linearized position of the continuous chunk of elements to // insert. - llvm::SmallVector completePositions(destTy.getRank(), 0); - copy(getI64SubArray(op.getPosition()), completePositions.begin()); + SmallVector completePositions(destTy.getRank(), 0); + SmallVector pos = op.getPosition(); + llvm::copy(getAsIntegers(pos), completePositions.begin()); int64_t insertBeginPosition = linearize(completePositions, computeStrides(destTy.getShape())); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -89,24 +89,22 @@ PatternRewriter &rewriter) { if (index == -1) return val; - Type lowType = VectorType::Builder(type).dropDim(0); + // At extraction dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, lowType, val, posAttr); - } + if (index == 0) + return rewriter.create(loc, val, pos); + // Unroll leading dimensions. 
+ Type lowType = VectorType::Builder(type).dropDim(0); VectorType vType = cast(lowType); Type resType = VectorType::Builder(type).dropDim(index); auto resVectorType = cast(resType); Value result = rewriter.create( loc, resVectorType, rewriter.getZeroAttr(resVectorType)); for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = rewriter.create(loc, vType, val, posAttr); + Value ext = rewriter.create(loc, val, d); Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, resVectorType, load, result, - posAttr); + result = rewriter.create(loc, load, result, d); } return result; } @@ -120,20 +118,17 @@ if (index == -1) return val; // At insertion dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, type, val, result, posAttr); - } + if (index == 0) + return rewriter.create(loc, val, result, pos); + // Unroll leading dimensions. 
Type lowType = VectorType::Builder(type).dropDim(0); - VectorType vType = cast(lowType); - Type insType = VectorType::Builder(vType).dropDim(0); + auto vType = cast(lowType); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = rewriter.create(loc, vType, result, posAttr); - Value ins = rewriter.create(loc, insType, val, posAttr); + Value ext = rewriter.create(loc, result, d); + Value ins = rewriter.create(loc, val, d); Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, type, sto, result, posAttr); + result = rewriter.create(loc, sto, result, d); } return result; } @@ -823,10 +818,8 @@ newRhs = rewriter.create(loc, newRhs, rhsTranspose); SmallVector lhsOffsets(lhsReductionDims.size(), 0); SmallVector rhsOffsets(rhsReductionDims.size(), 0); - newLhs = rewriter.create( - loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets)); - newRhs = rewriter.create( - loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets)); + newLhs = rewriter.create(loc, newLhs, lhsOffsets); + newRhs = rewriter.create(loc, newRhs, rhsOffsets); std::optional result = createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), contractOp.getKind(), rewriter, isInt); @@ -1167,21 +1160,20 @@ Value result = rewriter.create( loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { - auto pos = rewriter.getI64ArrayAttr(d); - Value x = rewriter.create(loc, op.getLhs(), pos); + Value x = rewriter.create(loc, op.getLhs(), d); Value a = rewriter.create(loc, rhsType, x); Value r = nullptr; if (acc) - r = rewriter.create(loc, acc, pos); + r = rewriter.create(loc, acc, d); Value extrMask; if (mask) - extrMask = rewriter.create(loc, mask, pos); + extrMask = rewriter.create(loc, mask, d); std::optional m = createContractArithOp( loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); if (!m.has_value()) return failure(); - result = 
rewriter.create(loc, resType, *m, result, pos); + result = rewriter.create(loc, *m, result, d); } rewriter.replaceOp(rootOp, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -77,9 +77,7 @@ Value val = rewriter.create(loc, arith::CmpIPredicate::slt, bnd, idx); Value sel = rewriter.create(loc, val, trueVal, falseVal); - auto pos = rewriter.getI64ArrayAttr(d); - result = - rewriter.create(loc, dstType, sel, result, pos); + result = rewriter.create(loc, sel, result, d); } rewriter.replaceOp(op, result); return success(); @@ -151,11 +149,9 @@ loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); Value result = rewriter.create( loc, dstType, rewriter.getZeroAttr(dstType)); - for (int64_t d = 0; d < trueDim; d++) { - auto pos = rewriter.getI64ArrayAttr(d); - result = - rewriter.create(loc, dstType, trueVal, result, pos); - } + for (int64_t d = 0; d < trueDim; d++) + result = rewriter.create(loc, trueVal, result, d); + rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -944,11 +944,10 @@ // Rewrite vector.extract with 1d source to vector.extractelement. 
if (extractSrcType.getRank() == 1) { assert(extractOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = cast(extractOp.getPosition()[0]).getInt(); + Value pos = extractOp.getPosition()[0]; rewriter.setInsertionPoint(extractOp); rewriter.replaceOpWithNewOp( - extractOp, extractOp.getVector(), - rewriter.create(loc, pos)); + extractOp, extractOp.getVector(), pos); return success(); } @@ -1201,11 +1200,10 @@ // Rewrite vector.insert with 1d dest to vector.insertelement. if (insertOp.getDestVectorType().getRank() == 1) { assert(insertOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = cast(insertOp.getPosition()[0]).getInt(); + Value pos = insertOp.getPosition()[0]; rewriter.setInsertionPoint(insertOp); rewriter.replaceOpWithNewOp( - insertOp, insertOp.getSource(), insertOp.getDest(), - rewriter.create(loc, pos)); + insertOp, insertOp.getSource(), insertOp.getDest(), pos); return success(); } @@ -1276,10 +1274,8 @@ } else { // One lane inserts the entire source vector. 
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); - SmallVector newPos = llvm::to_vector( - llvm::map_range(insertOp.getPosition(), [](Attribute attr) { - return cast(attr).getInt(); - })); + SmallVector pos = insertOp.getPosition(); + SmallVector newPos = getAsIntegers(pos); // tid of inserting lane: pos / elementsPerLane Value insertingLane = rewriter.create( loc, newPos[distrDestDim] / elementsPerLane); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -165,16 +166,15 @@ // type has leading unit dims, we also trim the position array accordingly, // then (2) if source type also has leading unit dims, we need to append // zeroes to the position array accordingly. 
- unsigned oldPosRank = insertOp.getPosition().getValue().size(); + unsigned oldPosRank = insertOp.getPosition().size(); unsigned newPosRank = std::max(0, oldPosRank - dstDropCount); - SmallVector newPositions = llvm::to_vector( - insertOp.getPosition().getValue().take_back(newPosRank)); + SmallVector newPositions = + llvm::to_vector(insertOp.getPosition().take_back(newPosRank)); newPositions.resize(newDstType.getRank() - newSrcRank, - rewriter.getI64IntegerAttr(0)); + rewriter.create(loc, 0)); auto newInsertOp = rewriter.create( - loc, newDstType, newSrcVector, newDstVector, - rewriter.getArrayAttr(newPositions)); + loc, newDstType, newSrcVector, newDstVector, newPositions); rewriter.replaceOpWithNewOp(insertOp, oldDstType, newInsertOp); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -703,10 +703,9 @@ auto xferOp = extractOp.getVector().getDefiningOp(); SmallVector newIndices(xferOp.getIndices().begin(), xferOp.getIndices().end()); - for (const auto &it : llvm::enumerate(extractOp.getPosition())) { - int64_t offset = cast(it.value()).getInt(); - int64_t idx = - newIndices.size() - extractOp.getPosition().size() + it.index(); + for (auto [i, pos] : llvm::enumerate(extractOp.getPosition())) { + int64_t offset = pos.getDefiningOp().value(); + int64_t idx = newIndices.size() - extractOp.getPosition().size() + i; OpFoldResult ofr = affine::makeComposedFoldedAffineApply( rewriter, extractOp.getLoc(), rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]}); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -598,19 +598,16 
@@ unsigned expandRatio = castDstType.getNumElements() / castSrcType.getNumElements(); - auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t { - return (*attr.getAsValueRange().begin()).getZExtValue(); + auto getFirstIntValue = [](ValueRange values) -> uint64_t { + return values[0].getDefiningOp().value(); }; uint64_t index = getFirstIntValue(extractOp.getPosition()); // Get the single scalar (as a vector) in the source value that packs the // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> - VectorType oneScalarType = - VectorType::get({1}, castSrcType.getElementType()); Value packedValue = rewriter.create( - extractOp.getLoc(), oneScalarType, castOp.getSource(), - rewriter.getI64ArrayAttr(index / expandRatio)); + extractOp.getLoc(), castOp.getSource(), index / expandRatio); // Cast it to a vector with the desired scalar's type. // E.g. f32 -> vector<2xf16> @@ -620,9 +617,8 @@ extractOp.getLoc(), packedType, packedValue); // Finally extract the desired scalar. - rewriter.replaceOpWithNewOp( - extractOp, extractOp.getType(), castedValue, - rewriter.getI64ArrayAttr(index % expandRatio)); + rewriter.replaceOpWithNewOp(extractOp, castedValue, + index % expandRatio); return success(); }