diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" #include "mlir/IR/AffineMap.h" @@ -130,6 +131,24 @@ return cast(attr).getValue() == IteratorType::reduction; } +/// Returns the integer numbers in `values`. `values` are expected to be +/// constant operations. +SmallVector getAsIntegers(ArrayRef values); + +/// Returns the integer numbers in `foldResults`. `foldResults` are expected to +/// be constant operations. +SmallVector getAsIntegers(ArrayRef foldResults); + +/// Convert `foldResults` into Values. Integer attributes are converted to +/// constant op. +SmallVector getAsValues(OpBuilder &builder, Location loc, + ArrayRef foldResults); + +/// Returns the constant index ops in `values`. `values` are expected to be +/// constant operations. +SmallVector +getAsConstantIndexOps(ArrayRef values); + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -572,9 +572,7 @@ Vector_Op<"extract", [Pure, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, - InferTypeOpAdaptorWithIsCompatible]>, - Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>, - Results<(outs AnyType)> { + InferTypeOpAdaptorWithIsCompatible]> { let summary = "extract operation"; let description = [{ Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at @@ -588,18 +586,43 @@ %3 = vector.extract %1[]: vector ``` }]; + + let arguments = (ins + AnyVectorOfAnyRank:$vector, + Variadic:$dynamic_position, + DenseI64ArrayAttr:$static_position + ); + let results = (outs AnyType:$result); + let builders = [ + OpBuilder<(ins "Value":$source, "int64_t":$position)>, + OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>, OpBuilder<(ins "Value":$source, "ArrayRef":$position)>, - // Convenience builder which assumes the values in `position` are defined by - // ConstantIndexOp. - OpBuilder<(ins "Value":$source, "ValueRange":$position)> + OpBuilder<(ins "Value":$source, "ArrayRef":$position)>, ]; + let extraClassDeclaration = [{ VectorType getSourceVectorType() { return ::llvm::cast(getVector().getType()); } + + /// Return a vector with all the static and dynamic position indices. + SmallVector getMixedPosition() { + OpBuilder builder(getContext()); + return getMixedValues(getStaticPosition(), getDynamicPosition(), builder); + } + + unsigned getNumIndices() { + return getStaticPosition().size(); + } + }]; + + let assemblyFormat = [{ + $vector `` + custom($dynamic_position, $static_position) + attr-dict `:` type($vector) }]; - let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)"; + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; @@ -688,9 +711,7 @@ Vector_Op<"insert", [Pure, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, - AllTypesMatch<["dest", "res"]>]>, - Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>, - Results<(outs AnyVectorOfAnyRank:$res)> { + AllTypesMatch<["dest", "result"]>]> { let summary = "insert operation"; let description = [{ Takes an n-D source vector, an (n+k)-D destination vector and a k-D position @@ -700,27 +721,48 @@ Example: ```mlir - %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32> - %5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32> + %2 = vector.insert %0, %1[%c3] : vector<8x16xf32> into vector<4x8x16xf32> + %5 = vector.insert %3, %4[%c3, %c3, %c3] : f32 into vector<4x8x16xf32> %8 = vector.insert %6, %7[] : f32 into vector - %11 = vector.insert %9, %10[3, 3, 3] : vector into vector<4x8x16xf32> + %11 = vector.insert %9, %10[%c3, %c3, %c3] : vector into vector<4x8x16xf32> ``` }]; - let assemblyFormat = [{ - $source `,` $dest $position attr-dict `:` type($source) `into` type($dest) - }]; + + let arguments = (ins + AnyType:$source, + AnyVectorOfAnyRank:$dest, + Variadic:$dynamic_position, + DenseI64ArrayAttr:$static_position + ); + let results = (outs AnyVectorOfAnyRank:$result); let builders = [ - OpBuilder<(ins "Value":$source, "Value":$dest, - "ArrayRef":$position)>, - // Convenience builder which assumes all values are constant indices. - OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)> + OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>, + OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>, + OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$position)>, + OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$position)>, ]; + let extraClassDeclaration = [{ Type getSourceType() { return getSource().getType(); } VectorType getDestVectorType() { return ::llvm::cast(getDest().getType()); } + + /// Return a vector with all the static and dynamic position indices. + SmallVector getMixedPosition() { + OpBuilder builder(getContext()); + return getMixedValues(getStaticPosition(), getDynamicPosition(), builder); + } + + unsigned getNumIndices() { + return getStaticPosition().size(); + } + }]; + + let assemblyFormat = [{ + $source `,` $dest custom($dynamic_position, $static_position) + attr-dict `:` type($source) `into` type($dest) }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -807,8 +807,7 @@ Value el = rewriter.create(loc, loadedElType, op.getSource(), newIndices); - result = rewriter.create(loc, el, result, - rewriter.getI64ArrayAttr(i)); + result = rewriter.create(loc, el, result, i); } } else { if (auto vecType = dyn_cast(loadedElType)) { @@ -832,7 +831,7 @@ Value el = rewriter.create(op.getLoc(), loadedElType, op.getSource(), newIndices); result = rewriter.create( - op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx})); + op.getLoc(), el, result, ArrayRef({i, innerIdx})); } } } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -128,6 +128,18 @@ return rewriter.create(loc, pType, ptr); } +/// Convert `foldResult` into a Value. Integer attribute is converted to +/// an LLVM constant op. +Value static getAsLLVMValue(OpBuilder &builder, Location loc, + OpFoldResult foldResult) { + if (auto attr = foldResult.dyn_cast()) { + auto intAttr = cast(attr); + return builder.create(loc, intAttr).getResult(); + } + + return foldResult.get(); +} + namespace { /// Trivial Vector to LLVM conversions @@ -1009,7 +1021,8 @@ } rewriter.replaceOpWithNewOp( - extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); + extractEltOp, llvmType, adaptor.getVector(), + adaptor.getPosition()); return success(); } }; @@ -1025,48 +1038,40 @@ auto loc = extractOp->getLoc(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter->convertType(resultType); - auto positionArrayAttr = extractOp.getPosition(); + SmallVector positionVec = extractOp.getMixedPosition(); + ArrayRef position(positionVec); // Bail if result type cannot be lowered. if (!llvmResultType) return failure(); // Extract entire vector. Should be handled by folder, but just to be safe. - if (positionArrayAttr.empty()) { + if (position.empty()) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } // One-shot extraction of vector from array (only requires extractvalue). if (isa(resultType)) { - SmallVector indices; - for (auto idx : positionArrayAttr.getAsRange()) - indices.push_back(idx.getInt()); Value extracted = rewriter.create( - loc, adaptor.getVector(), indices); + loc, adaptor.getVector(), getAsIntegers(position)); rewriter.replaceOp(extractOp, extracted); return success(); } // Potential extraction of 1-D vector from array. Value extracted = adaptor.getVector(); - auto positionAttrs = positionArrayAttr.getValue(); - if (positionAttrs.size() > 1) { - SmallVector nMinusOnePosition; - for (auto idx : positionAttrs.drop_back()) - nMinusOnePosition.push_back(cast(idx).getInt()); + if (position.size() > 1) { + SmallVector nMinusOnePosition = + getAsIntegers(position.drop_back()); extracted = rewriter.create(loc, extracted, nMinusOnePosition); } + Value lastPosition = getAsLLVMValue(rewriter, loc, position.back()); // Remaining extraction of element from 1-D LLVM vector - auto position = cast(positionAttrs.back()); - auto i64Type = IntegerType::get(rewriter.getContext(), 64); - auto constant = rewriter.create(loc, i64Type, position); - extracted = - rewriter.create(loc, extracted, constant); - rewriter.replaceOp(extractOp, extracted); - + rewriter.replaceOpWithNewOp(extractOp, extracted, + lastPosition); return success(); } }; @@ -1147,7 +1152,8 @@ auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); - auto positionArrayAttr = insertOp.getPosition(); + SmallVector positionVec = insertOp.getMixedPosition(); + ArrayRef position(positionVec); // Bail if result type cannot be lowered. if (!llvmResultType) @@ -1155,7 +1161,7 @@ // Overwrite entire vector with value. Should be handled by folder, but // just to be safe. - if (positionArrayAttr.empty()) { + if (position.empty()) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } @@ -1164,35 +1170,30 @@ if (isa(sourceType)) { Value inserted = rewriter.create( loc, adaptor.getDest(), adaptor.getSource(), - LLVM::convertArrayToIndices(positionArrayAttr)); + getAsIntegers(position)); rewriter.replaceOp(insertOp, inserted); return success(); } // Potential extraction of 1-D vector from array. Value extracted = adaptor.getDest(); - auto positionAttrs = positionArrayAttr.getValue(); - auto position = cast(positionAttrs.back()); auto oneDVectorType = destVectorType; - if (positionAttrs.size() > 1) { + if (position.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); extracted = rewriter.create( - loc, extracted, - LLVM::convertArrayToIndices(positionAttrs.drop_back())); + loc, extracted, getAsIntegers(position.drop_back())); } // Insertion of an element into a 1-D LLVM vector. - auto i64Type = IntegerType::get(rewriter.getContext(), 64); - auto constant = rewriter.create(loc, i64Type, position); Value inserted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, - adaptor.getSource(), constant); + adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back())); // Potential insertion of resulting 1-D vector into array. - if (positionAttrs.size() > 1) { + if (position.size() > 1) { inserted = rewriter.create( loc, adaptor.getDest(), inserted, - LLVM::convertArrayToIndices(positionAttrs.drop_back())); + getAsIntegers(position.drop_back())); } rewriter.replaceOp(insertOp, inserted); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -885,10 +885,10 @@ /// If the result of the TransferReadOp has exactly one user, which is a /// vector::InsertOp, return that operation's indices. void getInsertionIndices(TransferReadOp xferOp, - SmallVector &indices) const { + SmallVectorImpl &indices) const { if (auto insertOp = getInsertOp(xferOp)) { - for (Attribute attr : insertOp.getPosition()) - indices.push_back(dyn_cast(attr).getInt()); + auto pos = insertOp.getMixedPosition(); + indices.append(pos.begin(), pos.end()); } } @@ -927,9 +927,9 @@ getXferIndices(b, xferOp, iv, xferIndices); // Indices for the new vector.insert op. - SmallVector insertionIndices; + SmallVector insertionIndices; getInsertionIndices(xferOp, insertionIndices); - insertionIndices.push_back(i); + insertionIndices.push_back(rewriter.getIndexAttr(i)); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto newXferOp = b.create( @@ -1012,10 +1012,10 @@ /// If the input of the given TransferWriteOp is an ExtractOp, return its /// indices. void getExtractionIndices(TransferWriteOp xferOp, - SmallVector &indices) const { + SmallVectorImpl &indices) const { if (auto extractOp = getExtractOp(xferOp)) { - for (Attribute attr : extractOp.getPosition()) - indices.push_back(dyn_cast(attr).getInt()); + auto pos = extractOp.getMixedPosition(); + indices.append(pos.begin(), pos.end()); } } @@ -1053,9 +1053,9 @@ getXferIndices(b, xferOp, iv, xferIndices); // Indices for the new vector.extract op. - SmallVector extractionIndices; + SmallVector extractionIndices; getExtractionIndices(xferOp, extractionIndices); - extractionIndices.push_back(i); + extractionIndices.push_back(b.getI64IntegerAttr(i)); auto extracted = b.create(loc, vec, extractionIndices); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -34,8 +34,18 @@ /// Gets the first integer value from `attr`, assuming it is an integer array /// attribute. -static uint64_t getFirstIntValue(ArrayAttr attr) { - return (*attr.getAsValueRange().begin()).getZExtValue(); +static uint64_t getFirstIntValue(ValueRange values) { + return values[0].getDefiningOp().value(); +} +static uint64_t getFirstIntValue(ArrayRef attr) { + return cast(attr[0]).getInt(); +} +static uint64_t getFirstIntValue(ArrayRef foldResults) { + auto attr = foldResults[0].dyn_cast(); + if (attr) + return getFirstIntValue(attr); + + return getFirstIntValue(ValueRange{foldResults[0].get()}); } /// Returns the number of bits for the given scalar/vector type. @@ -152,7 +162,7 @@ return success(); } - int32_t id = getFirstIntValue(extractOp.getPosition()); + int32_t id = getFirstIntValue(extractOp.getMixedPosition()); rewriter.replaceOpWithNewOp( extractOp, adaptor.getVector(), id); return success(); @@ -232,7 +242,7 @@ return success(); } - int32_t id = getFirstIntValue(insertOp.getPosition()); + int32_t id = getFirstIntValue(insertOp.getMixedPosition()); rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), id); return success(); diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -516,7 +516,7 @@ return failure(); Value newExtract = rewriter.create( - op.getLoc(), ext->getIn(), op.getPosition()); + op.getLoc(), ext->getIn(), op.getMixedPosition()); ext->recreateAndReplace(rewriter, op, newExtract); return success(); } @@ -645,8 +645,9 @@ vector::InsertOp origInsert, Value narrowValue, Value narrowDest) const override { - return rewriter.create( - origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition()); + return rewriter.create(origInsert.getLoc(), narrowValue, + narrowDest, + origInsert.getMixedPosition()); } }; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -224,6 +224,47 @@ return failure(); } +/// Returns the integer numbers in `values`. `values` are expected to be +/// constant operations. +SmallVector vector::getAsIntegers(ArrayRef values) { + SmallVector ints; + llvm::transform(values, std::back_inserter(ints), [](Value value) { + auto constOp = value.getDefiningOp(); + assert(constOp && "Unexpected non-constant index"); + return constOp.value(); + }); + return ints; +} + +/// Returns the integer numbers in `foldResults`. `foldResults` are expected to +/// be constant operations. +SmallVector vector::getAsIntegers(ArrayRef foldResults) { + SmallVector ints; + llvm::transform(foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) { + assert(foldResult.is() && "Unexpected non-constant index"); + return cast(foldResult.get()).getInt(); + }); + return ints; +} + +/// Convert `foldResults` into Values. Integer attributes are converted to +/// constant op. +SmallVector +vector::getAsValues(OpBuilder &builder, Location loc, ArrayRef foldResults) { + SmallVector values; + llvm::transform( + foldResults, std::back_inserter(values), [&](OpFoldResult foldResult) { + if (auto attr = foldResult.dyn_cast()) + return builder + .create(loc, + cast(attr).getInt()) + .getResult(); + + return foldResult.get(); + }); + return values; +} + //===----------------------------------------------------------------------===// // CombiningKindAttr //===----------------------------------------------------------------------===// @@ -385,13 +426,10 @@ } else { // This means we are reducing all the dimensions, and all reduction // dimensions are of size 1. So a simple extraction would do. - auto zeroAttr = - rewriter.getI64ArrayAttr(SmallVector(shape.size(), 0)); if (mask) - mask = rewriter.create(loc, rewriter.getI1Type(), - mask, zeroAttr); - cast = rewriter.create( - loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr); + mask = rewriter.create(loc, mask, 0); + cast = + rewriter.create(loc, reductionOp.getSource(), 0); } Value result = vector::makeArithReduction( @@ -559,13 +597,9 @@ mask = rewriter.create(loc, mask); result = rewriter.create(loc, reductionOp.getVector()); } else { - if (mask) { - mask = rewriter.create(loc, rewriter.getI1Type(), mask, - rewriter.getI64ArrayAttr(0)); - } - result = rewriter.create(loc, reductionOp.getType(), - reductionOp.getVector(), - rewriter.getI64ArrayAttr(0)); + if (mask) + mask = rewriter.create(loc, mask, 0); + result = rewriter.create(loc, reductionOp.getVector(), 0); } if (Value acc = reductionOp.getAcc()) @@ -1129,19 +1163,33 @@ // ExtractOp //===----------------------------------------------------------------------===// +void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, + Value source, int64_t position) { + build(builder, result, source, ArrayRef{position}); +} + +void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, + Value source, OpFoldResult position) { + build(builder, result, source, ArrayRef{position}); +} + void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, Value source, ArrayRef position) { - build(builder, result, source, getVectorSubscriptAttr(builder, position)); + SmallVector posVals; + posVals.reserve(position.size()); + llvm::transform(position, std::back_inserter(posVals), [&](int64_t pos) { + return builder.getI64IntegerAttr(pos); + }); + build(builder, result, source, posVals); } -// Convenience builder which assumes the values are constant indices. void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, - Value source, ValueRange position) { - SmallVector positionConstants = - llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { - return getConstantIntValue(pos).value(); - })); - build(builder, result, source, positionConstants); + Value source, ArrayRef position) { + SmallVector staticPos; + SmallVector dynamicPos; + dispatchIndexOpFoldResults(position, dynamicPos, staticPos); + build(builder, result, source, dynamicPos, + builder.getDenseI64ArrayAttr(staticPos)); } LogicalResult @@ -1149,12 +1197,12 @@ ExtractOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { auto vectorType = llvm::cast(adaptor.getVector().getType()); - if (static_cast(adaptor.getPosition().size()) == + if (static_cast(adaptor.getStaticPosition().size()) == vectorType.getRank()) { inferredReturnTypes.push_back(vectorType.getElementType()); } else { - auto n = - std::min(adaptor.getPosition().size(), vectorType.getRank()); + auto n = std::min(adaptor.getStaticPosition().size(), + vectorType.getRank()); inferredReturnTypes.push_back(VectorType::get( vectorType.getShape().drop_front(n), vectorType.getElementType())); } @@ -1174,20 +1222,22 @@ return l == r; } -LogicalResult vector::ExtractOp::verify() { - auto positionAttr = getPosition().getValue(); - if (positionAttr.size() > - static_cast(getSourceVectorType().getRank())) +LogicalResult ExtractOp::verify() { + auto position = getMixedPosition(); + if (position.size() > static_cast(getSourceVectorType().getRank())) return emitOpError( "expected position attribute of rank no greater than vector rank"); - for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = llvm::dyn_cast(en.value()); - if (!attr || attr.getInt() < 0 || - attr.getInt() >= getSourceVectorType().getDimSize(en.index())) - return emitOpError("expected position attribute #") - << (en.index() + 1) - << " to be a non-negative integer smaller than the corresponding " - "vector dimension"; + for (auto [idx, pos] : llvm::enumerate(position)) { + if (pos.is()) { + int64_t constIdx = cast(pos.get()).getInt(); + if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) { + return emitOpError("expected position attribute #") + << (idx + 1) + << " to be a non-negative integer smaller than the " + "corresponding " + "vector dimension"; + } + } } return success(); } @@ -1205,20 +1255,25 @@ if (!extractOp.getVector().getDefiningOp()) return failure(); - SmallVector globalPosition; + SmallVector globalStaticPos; + SmallVector globalDynamicPos; ExtractOp currentOp = extractOp; - auto extrPos = extractVector(currentOp.getPosition()); - globalPosition.append(extrPos.rbegin(), extrPos.rend()); + ArrayRef extrStaticPos = currentOp.getStaticPosition(); + SmallVector extrDynamicPos = currentOp.getDynamicPosition(); + globalStaticPos.append(extrStaticPos.rbegin(), extrStaticPos.rend()); + globalDynamicPos.append(extrDynamicPos.rbegin(), extrDynamicPos.rend()); while (ExtractOp nextOp = currentOp.getVector().getDefiningOp()) { currentOp = nextOp; - auto extrPos = extractVector(currentOp.getPosition()); - globalPosition.append(extrPos.rbegin(), extrPos.rend()); + extrStaticPos = currentOp.getStaticPosition(); + extrDynamicPos = currentOp.getDynamicPosition(); + globalStaticPos.append(extrStaticPos.rbegin(), extrStaticPos.rend()); + globalDynamicPos.append(extrDynamicPos.rbegin(), extrDynamicPos.rend()); } - extractOp.setOperand(currentOp.getVector()); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - OpBuilder b(extractOp.getContext()); - std::reverse(globalPosition.begin(), globalPosition.end()); - extractOp.setPositionAttr(b.getI64ArrayAttr(globalPosition)); + SmallVector newOperands; + newOperands.push_back(currentOp.getVector()); + newOperands.append(globalDynamicPos.rbegin(), globalDynamicPos.rend()); + extractOp->setOperands(newOperands); + extractOp.setStaticPosition(globalStaticPos); return success(); } @@ -1324,12 +1379,13 @@ ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState( ExtractOp e) : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()), - extractedRank(extractOp.getPosition().size()) { + extractedRank(extractOp.getNumIndices()) { assert(vectorRank >= extractedRank && "extracted pos overflow"); sentinels.reserve(vectorRank - extractedRank); for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i) sentinels.push_back(-(i + 1)); - extractPosition = extractVector(extractOp.getPosition()); + SmallVector pos = extractOp.getMixedPosition(); + extractPosition = getAsIntegers(pos); llvm::append_range(extractPosition, sentinels); } @@ -1349,9 +1405,10 @@ LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos( Value &res) { - auto insertedPos = extractVector(nextInsertOp.getPosition()); + SmallVector pos = nextInsertOp.getMixedPosition(); + auto insertedPos = getAsIntegers(pos); if (ArrayRef(insertedPos) != - llvm::ArrayRef(extractPosition).take_front(extractedRank)) + ArrayRef(extractPosition).take_front(extractedRank)) return failure(); // Case 2.a. early-exit fold. res = nextInsertOp.getSource(); @@ -1364,7 +1421,8 @@ /// This method updates the internal state. LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { - auto insertedPos = extractVector(nextInsertOp.getPosition()); + SmallVector pos = nextInsertOp.getMixedPosition(); + auto insertedPos = getAsIntegers(pos); if (!isContainedWithin(insertedPos, extractPosition)) return failure(); // Set leading dims to zero. @@ -1390,10 +1448,14 @@ return Value(); // Otherwise, fold by updating the op inplace and return its result. OpBuilder b(extractOp.getContext()); - extractOp->setAttr( - extractOp.getPositionAttrName(), - b.getI64ArrayAttr(ArrayRef(extractPosition).take_front(extractedRank))); - extractOp.getVectorMutable().assign(source); + SmallVector newOperands; + SmallVector newStaticPos; + newOperands.push_back(source); + for (int64_t index : ArrayRef(extractPosition).take_front(extractedRank)) + newStaticPos.push_back(index); + + extractOp->setOperands(newOperands); + extractOp.setStaticPosition(newStaticPos); return extractOp.getResult(); } @@ -1422,7 +1484,8 @@ // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel // values. This is a more difficult case and we bail. - auto insertedPos = extractVector(nextInsertOp.getPosition()); + SmallVector pos = nextInsertOp.getMixedPosition(); + auto insertedPos = getAsIntegers(pos); if (isContainedWithin(extractPosition, insertedPos) || intersectsWhereNonNegative(extractPosition, insertedPos)) return Value(); @@ -1487,18 +1550,24 @@ // extract position to `0` when extracting from the source operand. llvm::SetVector broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims(); - auto extractPos = extractVector(extractOp.getPosition()); - for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i) - if (broadcastedUnitDims.contains(i)) - extractPos[i] = 0; + SmallVector extrStaticPos = + llvm::to_vector(extractOp.getStaticPosition()); + SmallVector extrDynamicPos = extractOp.getDynamicPosition(); + OpBuilder b(extractOp); + for (int64_t i = rankDiff, e = extractOp.getNumIndices(); i < e; ++i) { + if (broadcastedUnitDims.contains(i)) { + extrStaticPos[i] = 0; + } + } // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the // matching extract position when extracting from the source operand. - extractPos.erase(extractPos.begin(), - std::next(extractPos.begin(), extractPos.size() - rankDiff)); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - OpBuilder b(extractOp.getContext()); - extractOp.setOperand(source); - extractOp.setPositionAttr(b.getI64ArrayAttr(extractPos)); + SmallVector newOperands; + newOperands.push_back(source); + newOperands.append( + std::next(extrDynamicPos.begin(), extrDynamicPos.size() - rankDiff), + extrDynamicPos.end()); + extractOp->setOperands(newOperands); + extractOp.setStaticPosition(extrStaticPos); return extractOp.getResult(); } @@ -1537,7 +1606,8 @@ } // Extract the strides associated with the extract op vector source. Then use // this to calculate a linearized position for the extract. - auto extractedPos = extractVector(extractOp.getPosition()); + SmallVector pos = extractOp.getMixedPosition(); + auto extractedPos = getAsIntegers(pos); std::reverse(extractedPos.begin(), extractedPos.end()); SmallVector strides; int64_t stride = 1; @@ -1561,10 +1631,15 @@ } std::reverse(newStrides.begin(), newStrides.end()); SmallVector newPosition = delinearize(position, newStrides); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - OpBuilder b(extractOp.getContext()); - extractOp.setPositionAttr(b.getI64ArrayAttr(newPosition)); - extractOp.setOperand(shapeCastOp.getSource()); + OpBuilder b(extractOp); + SmallVector newOperands; + SmallVector newStaticPos; + newOperands.push_back(shapeCastOp.getSource()); + for (int64_t pos : newPosition) + newStaticPos.push_back(pos); + + extractOp->setOperands(newOperands); + extractOp.setStaticPosition(newStaticPos); return extractOp.getResult(); } @@ -1603,14 +1678,21 @@ if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() - sliceOffsets.size()) return Value(); - auto extractedPos = extractVector(extractOp.getPosition()); + SmallVector valuePos = extractOp.getMixedPosition(); + auto extractedPos = getAsIntegers(valuePos); assert(extractedPos.size() >= sliceOffsets.size()); for (size_t i = 0, e = sliceOffsets.size(); i < e; i++) extractedPos[i] = extractedPos[i] + sliceOffsets[i]; extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector()); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - OpBuilder b(extractOp.getContext()); - extractOp.setPositionAttr(b.getI64ArrayAttr(extractedPos)); + OpBuilder b(extractOp); + SmallVector newOperands; + SmallVector newStaticPos; + newOperands.push_back(extractOp.getVector()); + for (int64_t pos : extractedPos) + newStaticPos.push_back(pos); + + extractOp->setOperands(newOperands); + extractOp.setStaticPosition(newStaticPos); return extractOp.getResult(); } @@ -1635,7 +1717,8 @@ if (destinationRank > insertOp.getSourceVectorType().getRank()) return Value(); auto insertOffsets = extractVector(insertOp.getOffsets()); - auto extractOffsets = extractVector(extractOp.getPosition()); + SmallVector pos = extractOp.getMixedPosition(); + auto extractOffsets = getAsIntegers(pos); if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) { return llvm::cast(attr).getInt() != 1; @@ -1672,10 +1755,15 @@ insertRankDiff)) return Value(); } - extractOp.getVectorMutable().assign(insertOp.getSource()); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - OpBuilder b(extractOp.getContext()); - extractOp.setPositionAttr(b.getI64ArrayAttr(offsetDiffs)); + SmallVector newOperands; + SmallVector newStaticPos; + newOperands.push_back(insertOp.getSource()); + OpBuilder b(extractOp); + for (int64_t offset : offsetDiffs) + newStaticPos.push_back(offset); + + extractOp->setOperands(newOperands); + extractOp.setStaticPosition(newStaticPos); return extractOp.getResult(); } // If the chunk extracted is disjoint from the chunk inserted, keep @@ -1686,7 +1774,7 @@ } OpFoldResult ExtractOp::fold(FoldAdaptor) { - if (getPosition().empty()) + if (getNumIndices() == 0) return getVector(); if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); @@ -1794,8 +1882,9 @@ // Calculate the linearized position of the continuous chunk of elements to // extract. - llvm::SmallVector completePositions(vecTy.getRank(), 0); - copy(getI64SubArray(extractOp.getPosition()), completePositions.begin()); + SmallVector completePositions(vecTy.getRank(), 0); + SmallVector pos = extractOp.getMixedPosition(); + llvm::copy(getAsIntegers(pos), completePositions.begin()); int64_t elemBeginPosition = linearize(completePositions, computeStrides(vecTy.getShape())); auto denseValuesBegin = dense.value_begin() + elemBeginPosition; @@ -2288,48 +2377,64 @@ // InsertOp //===----------------------------------------------------------------------===// -void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, - Value dest, ArrayRef position) { - result.addOperands({source, dest}); - auto positionAttr = getVectorSubscriptAttr(builder, position); - result.addTypes(dest.getType()); - result.addAttribute(InsertOp::getPositionAttrName(result.name), positionAttr); +void vector::InsertOp::build(OpBuilder &builder, OperationState &result, + Value source, Value dest, int64_t position) { + build(builder, result, source, dest, ArrayRef{position}); +} + +void vector::InsertOp::build(OpBuilder &builder, OperationState &result, + Value source, Value dest, OpFoldResult position) { + build(builder, result, source, dest, ArrayRef{position}); +} + +void vector::InsertOp::build(OpBuilder &builder, OperationState &result, + Value source, Value dest, + ArrayRef position) { + SmallVector posVals; + posVals.reserve(position.size()); + llvm::transform(position, std::back_inserter(posVals), [&](int64_t pos) { + return builder.getI64IntegerAttr(pos); + }); + build(builder, result, source, dest, posVals); } -// Convenience builder which assumes the values are constant indices. -void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, - Value dest, ValueRange position) { - SmallVector positionConstants = - llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { - return getConstantIntValue(pos).value(); - })); - build(builder, result, source, dest, positionConstants); +void vector::InsertOp::build(OpBuilder &builder, OperationState &result, + Value source, Value dest, + ArrayRef position) { + SmallVector staticPos; + SmallVector dynamicPos; + dispatchIndexOpFoldResults(position, dynamicPos, staticPos); + build(builder, result, source, dest, dynamicPos, + builder.getDenseI64ArrayAttr(staticPos)); } LogicalResult InsertOp::verify() { - auto positionAttr = getPosition().getValue(); + SmallVector position = getMixedPosition(); auto destVectorType = getDestVectorType(); - if (positionAttr.size() > static_cast(destVectorType.getRank())) + if (position.size() > static_cast(destVectorType.getRank())) return emitOpError( "expected position attribute of rank no greater than dest vector rank"); auto srcVectorType = llvm::dyn_cast(getSourceType()); if (srcVectorType && - (static_cast(srcVectorType.getRank()) + positionAttr.size() != + (static_cast(srcVectorType.getRank()) + position.size() != static_cast(destVectorType.getRank()))) return emitOpError("expected position attribute rank + source rank to " "match dest vector rank"); if (!srcVectorType && - (positionAttr.size() != static_cast(destVectorType.getRank()))) + (position.size() != static_cast(destVectorType.getRank()))) return emitOpError( "expected position attribute rank to match the dest vector rank"); - for (const auto &en : llvm::enumerate(positionAttr)) { - auto attr = llvm::dyn_cast(en.value()); - if (!attr || attr.getInt() < 0 || - attr.getInt() >= destVectorType.getDimSize(en.index())) - return emitOpError("expected position attribute #") - << (en.index() + 1) - << " to be a non-negative integer smaller than the corresponding " - "dest vector dimension"; + for (auto [idx, pos] : llvm::enumerate(position)) { + if (auto attr = pos.dyn_cast()) { + int64_t constIdx = cast(attr).getInt(); + if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) { + return emitOpError("expected position attribute #") + << (idx + 1) + << " to be a non-negative integer smaller than the " + "corresponding " + "dest vector dimension"; + } + } } return success(); } @@ -2411,8 +2516,9 @@ // Calculate the linearized position of the continuous chunk of elements to // insert. - llvm::SmallVector completePositions(destTy.getRank(), 0); - copy(getI64SubArray(op.getPosition()), completePositions.begin()); + SmallVector completePositions(destTy.getRank(), 0); + SmallVector pos = op.getMixedPosition(); + llvm::copy(getAsIntegers(pos), completePositions.begin()); int64_t insertBeginPosition = linearize(completePositions, computeStrides(destTy.getShape())); @@ -2443,7 +2549,7 @@ // value. This happens when the source and destination vectors have identical // sizes. OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { - if (getPosition().empty()) + if (getNumIndices() == 0) return getSource(); return {}; } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -89,24 +89,22 @@ PatternRewriter &rewriter) { if (index == -1) return val; - Type lowType = VectorType::Builder(type).dropDim(0); + // At extraction dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, lowType, val, posAttr); - } + if (index == 0) + return rewriter.create(loc, val, pos); + // Unroll leading dimensions. + Type lowType = VectorType::Builder(type).dropDim(0); VectorType vType = cast(lowType); Type resType = VectorType::Builder(type).dropDim(index); auto resVectorType = cast(resType); Value result = rewriter.create( loc, resVectorType, rewriter.getZeroAttr(resVectorType)); for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = rewriter.create(loc, vType, val, posAttr); + Value ext = rewriter.create(loc, val, d); Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, resVectorType, load, result, - posAttr); + result = rewriter.create(loc, load, result, d); } return result; } @@ -120,20 +118,17 @@ if (index == -1) return val; // At insertion dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, type, val, result, posAttr); - } + if (index == 0) + return rewriter.create(loc, val, result, pos); + // Unroll leading dimensions. Type lowType = VectorType::Builder(type).dropDim(0); - VectorType vType = cast(lowType); - Type insType = VectorType::Builder(vType).dropDim(0); + auto vType = cast(lowType); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = rewriter.create(loc, vType, result, posAttr); - Value ins = rewriter.create(loc, insType, val, posAttr); + Value ext = rewriter.create(loc, result, d); + Value ins = rewriter.create(loc, val, d); Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, type, sto, result, posAttr); + result = rewriter.create(loc, sto, result, d); } return result; } @@ -823,10 +818,8 @@ newRhs = rewriter.create(loc, newRhs, rhsTranspose); SmallVector lhsOffsets(lhsReductionDims.size(), 0); SmallVector rhsOffsets(rhsReductionDims.size(), 0); - newLhs = rewriter.create( - loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets)); - newRhs = rewriter.create( - loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets)); + newLhs = rewriter.create(loc, newLhs, lhsOffsets); + newRhs = rewriter.create(loc, newRhs, rhsOffsets); std::optional result = createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), contractOp.getKind(), rewriter, isInt); @@ -1167,21 +1160,20 @@ Value result = rewriter.create( loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { - auto pos = rewriter.getI64ArrayAttr(d); - Value x = rewriter.create(loc, op.getLhs(), pos); + Value x = rewriter.create(loc, op.getLhs(), d); Value a = rewriter.create(loc, rhsType, x); Value r = nullptr; if (acc) - r = rewriter.create(loc, acc, pos); + r = rewriter.create(loc, acc, d); Value extrMask; if (mask) - extrMask = rewriter.create(loc, mask, pos); + extrMask = rewriter.create(loc, mask, d); std::optional m = createContractArithOp( loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); if (!m.has_value()) return failure(); - result = rewriter.create(loc, resType, *m, result, pos); + result = rewriter.create(loc, *m, result, d); } rewriter.replaceOp(rootOp, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -77,9 +77,7 @@ Value val = rewriter.create(loc, arith::CmpIPredicate::slt, bnd, idx); Value sel = rewriter.create(loc, val, trueVal, falseVal); - auto pos = rewriter.getI64ArrayAttr(d); - result = - rewriter.create(loc, dstType, sel, result, pos); + result = rewriter.create(loc, sel, result, d); } rewriter.replaceOp(op, result); return success(); @@ -151,11 +149,9 @@ loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); Value result = rewriter.create( loc, dstType, rewriter.getZeroAttr(dstType)); - for (int64_t d = 0; d < trueDim; d++) { - auto pos = rewriter.getI64ArrayAttr(d); - result = - rewriter.create(loc, dstType, trueVal, result, pos); - } + for (int64_t d = 0; d < trueDim; d++) + result = rewriter.create(loc, trueVal, result, d); + rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -938,19 +938,12 @@ "vector.extract does not support rank 0 sources"); // "vector.extract %v[] : vector<...xf32>" can be canonicalized to %v. - if (extractOp.getPosition().empty()) + if (extractOp.getNumIndices() == 0) return failure(); - // Rewrite vector.extract with 1d source to vector.extractelement. - if (extractSrcType.getRank() == 1) { - assert(extractOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = cast(extractOp.getPosition()[0]).getInt(); - rewriter.setInsertionPoint(extractOp); - rewriter.replaceOpWithNewOp( - extractOp, extractOp.getVector(), - rewriter.create(loc, pos)); - return success(); - } + // Skip vector.extract already with 1d source. + if (extractSrcType.getRank() == 1) + return failure(); // All following cases are 2d or higher dimensional source vectors. @@ -968,7 +961,7 @@ Value distributedVec = newWarpOp->getResult(newRetIndices[0]); // Extract from distributed vector. Value newExtract = rewriter.create( - loc, distributedVec, extractOp.getPosition()); + loc, distributedVec, extractOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); return success(); @@ -994,7 +987,7 @@ SmallVector newDistributedShape(extractSrcType.getShape().begin(), extractSrcType.getShape().end()); for (int i = 0; i < distributedType.getRank(); ++i) - newDistributedShape[i + extractOp.getPosition().size()] = + newDistributedShape[i + extractOp.getNumIndices()] = distributedType.getDimSize(i); auto newDistributedType = VectorType::get(newDistributedShape, distributedType.getElementType()); @@ -1006,7 +999,7 @@ Value distributedVec = newWarpOp->getResult(newRetIndices[0]); // Extract from distributed vector. Value newExtract = rewriter.create( - loc, distributedVec, extractOp.getPosition()); + loc, distributedVec, extractOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); return success(); @@ -1195,19 +1188,12 @@ Location loc = insertOp.getLoc(); // "vector.insert %v, %v[] : ..." can be canonicalized to %v. - if (insertOp.getPosition().empty()) + if (insertOp.getNumIndices() == 0) return failure(); - // Rewrite vector.insert with 1d dest to vector.insertelement. - if (insertOp.getDestVectorType().getRank() == 1) { - assert(insertOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = cast(insertOp.getPosition()[0]).getInt(); - rewriter.setInsertionPoint(insertOp); - rewriter.replaceOpWithNewOp( - insertOp, insertOp.getSource(), insertOp.getDest(), - rewriter.create(loc, pos)); - return success(); - } + // Skip vector.insert already with 1d dest. + if (insertOp.getDestVectorType().getRank() == 1) + return failure(); if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { // There is no distribution, this is a broadcast. Simply move the insert @@ -1221,7 +1207,7 @@ Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); Value distributedDest = newWarpOp->getResult(newRetIndices[1]); Value newResult = rewriter.create( - loc, distributedSrc, distributedDest, insertOp.getPosition()); + loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); return success(); @@ -1252,7 +1238,7 @@ // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that // case, one lane will insert the source vector<96xf32>. The other // lanes will not do anything. - int64_t distrSrcDim = distrDestDim - insertOp.getPosition().size(); + int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices(); if (distrSrcDim >= 0) distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim); auto distrSrcType = @@ -1272,14 +1258,12 @@ if (distrSrcDim >= 0) { // Every lane inserts a small piece. newResult = rewriter.create( - loc, distributedSrc, distributedDest, insertOp.getPosition()); + loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); } else { // One lane inserts the entire source vector. int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); - SmallVector newPos = llvm::to_vector( - llvm::map_range(insertOp.getPosition(), [](Attribute attr) { - return cast(attr).getInt(); - })); + SmallVector pos = insertOp.getMixedPosition(); + SmallVector newPos = getAsIntegers(pos); // tid of inserting lane: pos / elementsPerLane Value insertingLane = rewriter.create( loc, newPos[distrDestDim] / elementsPerLane); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -165,16 +166,16 @@ // type has leading unit dims, we also trim the position array accordingly, // then (2) if source type also has leading unit dims, we need to append // zeroes to the position array accordingly. - unsigned oldPosRank = insertOp.getPosition().getValue().size(); + unsigned oldPosRank = insertOp.getNumIndices(); unsigned newPosRank = std::max(0, oldPosRank - dstDropCount); - SmallVector newPositions = llvm::to_vector( - insertOp.getPosition().getValue().take_back(newPosRank)); - newPositions.resize(newDstType.getRank() - newSrcRank, - rewriter.getI64IntegerAttr(0)); + SmallVector oldPosition = insertOp.getMixedPosition(); + SmallVector newPosition = + llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank)); + newPosition.resize(newDstType.getRank() - newSrcRank, + rewriter.getI64IntegerAttr(0)); auto newInsertOp = rewriter.create( - loc, newDstType, newSrcVector, newDstVector, - rewriter.getArrayAttr(newPositions)); + loc, newSrcVector, newDstVector, newPosition); rewriter.replaceOpWithNewOp(insertOp, oldDstType, newInsertOp); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -703,10 +703,10 @@ auto xferOp = extractOp.getVector().getDefiningOp(); SmallVector newIndices(xferOp.getIndices().begin(), xferOp.getIndices().end()); - for (const auto &it : llvm::enumerate(extractOp.getPosition())) { - int64_t offset = cast(it.value()).getInt(); - int64_t idx = - newIndices.size() - extractOp.getPosition().size() + it.index(); + for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) { + assert(pos.is() && "Unespected non-constant index"); + int64_t offset = cast(pos.get()).getInt(); + int64_t idx = newIndices.size() - extractOp.getNumIndices() + i; OpFoldResult ofr = affine::makeComposedFoldedAffineApply( rewriter, extractOp.getLoc(), rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]}); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -598,19 +598,17 @@ unsigned expandRatio = castDstType.getNumElements() / castSrcType.getNumElements(); - auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t { - return (*attr.getAsValueRange().begin()).getZExtValue(); + auto getFirstIntValue = [](ArrayRef values) -> uint64_t { + assert(values[0].is() && "Unexpected non-constant index"); + return cast(values[0].get()).getInt(); }; - uint64_t index = getFirstIntValue(extractOp.getPosition()); + uint64_t index = getFirstIntValue(extractOp.getMixedPosition()); // Get the single scalar (as a vector) in the source value that packs the // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> - VectorType oneScalarType = - VectorType::get({1}, castSrcType.getElementType()); Value packedValue = rewriter.create( - extractOp.getLoc(), oneScalarType, castOp.getSource(), - rewriter.getI64ArrayAttr(index / expandRatio)); + extractOp.getLoc(), castOp.getSource(), index / expandRatio); // Cast it to a vector with the desired scalar's type. // E.g. f32 -> vector<2xf16> @@ -620,9 +618,8 @@ extractOp.getLoc(), packedType, packedValue); // Finally extract the desired scalar. - rewriter.replaceOpWithNewOp( - extractOp, extractOp.getType(), castedValue, - rewriter.getI64ArrayAttr(index % expandRatio)); + rewriter.replaceOpWithNewOp(extractOp, castedValue, + index % expandRatio); return success(); }