diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -131,6 +131,24 @@ return cast(attr).getValue() == IteratorType::reduction; } +/// Returns the integer numbers in `values`. `values` are expected to be +/// constant operations. +SmallVector getAsIntegers(ArrayRef values); + +/// Returns the integer numbers in `foldResults`. `foldResults` are expected to +/// be constant operations. +SmallVector getAsIntegers(ArrayRef foldResults); + +/// Convert `foldResults` into Values. Integer attributes are converted to +/// constant op. +SmallVector getAsValues(OpBuilder &builder, Location loc, + ArrayRef foldResults); + +/// Returns the constant index ops in `values`. `values` are expected to be +/// constant operations. +SmallVector +getAsConstantIndexOps(ArrayRef values); + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -523,9 +523,7 @@ Vector_Op<"extract", [Pure, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, - InferTypeOpAdaptorWithIsCompatible]>, - Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>, - Results<(outs AnyType)> { + InferTypeOpAdaptorWithIsCompatible]> { let summary = "extract operation"; let description = [{ Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at @@ -535,21 +533,55 @@ ```mlir %1 = vector.extract %0[3]: vector<4x8x16xf32> - %2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32> + %2 = vector.extract %0[2, 1, 3]: vector<4x8x16xf32> %3 
= vector.extract %1[]: vector + %4 = vector.extract %0[%a, %b, %c]: vector<4x8x16xf32> + %5 = vector.extract %0[2, %b]: vector<4x8x16xf32> ``` }]; + + let arguments = (ins + AnyVectorOfAnyRank:$vector, + Variadic:$dynamic_position, + DenseI64ArrayAttr:$static_position + ); + let results = (outs AnyType:$result); + let builders = [ - // Convenience builder which assumes the values in `position` are defined by - // ConstantIndexOp. - OpBuilder<(ins "Value":$source, "ValueRange":$position)> + OpBuilder<(ins "Value":$source, "int64_t":$position)>, + OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>, + OpBuilder<(ins "Value":$source, "ArrayRef":$position)>, + OpBuilder<(ins "Value":$source, "ArrayRef":$position)>, ]; + let extraClassDeclaration = [{ VectorType getSourceVectorType() { return ::llvm::cast(getVector().getType()); } + + /// Return a vector with all the static and dynamic position indices. + SmallVector getMixedPosition() { + OpBuilder builder(getContext()); + return getMixedValues(getStaticPosition(), getDynamicPosition(), builder); + } + + unsigned getNumIndices() { + return getStaticPosition().size(); + } + + bool hasDynamicPosition() { + auto dynPos = getDynamicPosition(); + return std::any_of(dynPos.begin(), dynPos.end(), + [](Value operand) { return operand != nullptr; }); + } }]; - let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)"; + + let assemblyFormat = [{ + $vector `` + custom($dynamic_position, $static_position) + attr-dict `:` type($vector) + }]; + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; @@ -638,9 +670,7 @@ Vector_Op<"insert", [Pure, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, - AllTypesMatch<["dest", "res"]>]>, - Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>, - Results<(outs AnyVectorOfAnyRank:$res)> { + AllTypesMatch<["dest", "result"]>]> { let summary = "insert operation"; let description = [{ 
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position @@ -651,24 +681,53 @@ ```mlir %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32> - %5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32> + %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32> %8 = vector.insert %6, %7[] : f32 into vector - %11 = vector.insert %9, %10[3, 3, 3] : vector into vector<4x8x16xf32> + %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> + %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> ``` }]; - let assemblyFormat = [{ - $source `,` $dest $position attr-dict `:` type($source) `into` type($dest) - }]; + + let arguments = (ins + AnyType:$source, + AnyVectorOfAnyRank:$dest, + Variadic:$dynamic_position, + DenseI64ArrayAttr:$static_position + ); + let results = (outs AnyVectorOfAnyRank:$result); let builders = [ - // Convenience builder which assumes all values are constant indices. - OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)> + OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>, + OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>, + OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$position)>, + OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$position)>, ]; + let extraClassDeclaration = [{ Type getSourceType() { return getSource().getType(); } VectorType getDestVectorType() { return ::llvm::cast(getDest().getType()); } + + /// Return a vector with all the static and dynamic position indices. 
+ SmallVector getMixedPosition() { + OpBuilder builder(getContext()); + return getMixedValues(getStaticPosition(), getDynamicPosition(), builder); + } + + unsigned getNumIndices() { + return getStaticPosition().size(); + } + + bool hasDynamicPosition() { + return llvm::any_of(getDynamicPosition(), + [](Value operand) { return operand != nullptr; }); + } + }]; + + let assemblyFormat = [{ + $source `,` $dest custom($dynamic_position, $static_position) + attr-dict `:` type($source) `into` type($dest) }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -126,6 +126,18 @@ return rewriter.create(loc, pType, ptr); } +/// Convert `foldResult` into a Value. Integer attribute is converted to +/// an LLVM constant op. +static Value getAsLLVMValue(OpBuilder &builder, Location loc, + OpFoldResult foldResult) { + if (auto attr = foldResult.dyn_cast()) { + auto intAttr = cast(attr); + return builder.create(loc, intAttr).getResult(); + } + + return foldResult.get(); +} + namespace { /// Trivial Vector to LLVM conversions @@ -1079,41 +1091,53 @@ auto loc = extractOp->getLoc(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter->convertType(resultType); - ArrayRef positionArray = extractOp.getPosition(); - // Bail if result type cannot be lowered. if (!llvmResultType) return failure(); + SmallVector positionVec; + for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) { + if (pos.is()) + // Make sure we use the value that has been already converted to LLVM. + positionVec.push_back(adaptor.getDynamicPosition()[idx]); + else + positionVec.push_back(pos); + } + // Extract entire vector. Should be handled by folder, but just to be safe. 
- if (positionArray.empty()) { + ArrayRef position(positionVec); + if (position.empty()) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } // One-shot extraction of vector from array (only requires extractvalue). if (isa(resultType)) { + if (extractOp.hasDynamicPosition()) + return failure(); + Value extracted = rewriter.create( - loc, adaptor.getVector(), positionArray); + loc, adaptor.getVector(), getAsIntegers(position)); rewriter.replaceOp(extractOp, extracted); return success(); } // Potential extraction of 1-D vector from array. Value extracted = adaptor.getVector(); - if (positionArray.size() > 1) { - extracted = rewriter.create( - loc, extracted, positionArray.drop_back()); - } + if (position.size() > 1) { + if (extractOp.hasDynamicPosition()) + return failure(); - // Remaining extraction of element from 1-D LLVM vector - auto i64Type = IntegerType::get(rewriter.getContext(), 64); - auto constant = - rewriter.create(loc, i64Type, positionArray.back()); - extracted = - rewriter.create(loc, extracted, constant); - rewriter.replaceOp(extractOp, extracted); + SmallVector nMinusOnePosition = + getAsIntegers(position.drop_back()); + extracted = rewriter.create(loc, extracted, + nMinusOnePosition); + } + Value lastPosition = getAsLLVMValue(rewriter, loc, position.back()); + // Remaining extraction of element from 1-D LLVM vector. + rewriter.replaceOpWithNewOp(extractOp, extracted, + lastPosition); return success(); } }; @@ -1194,23 +1218,34 @@ auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); - ArrayRef positionArray = insertOp.getPosition(); - // Bail if result type cannot be lowered. if (!llvmResultType) return failure(); + SmallVector positionVec; + for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) { + if (pos.is()) + // Make sure we use the value that has been already converted to LLVM. 
+ positionVec.push_back(adaptor.getDynamicPosition()[idx]); + else + positionVec.push_back(pos); + } + // Overwrite entire vector with value. Should be handled by folder, but // just to be safe. - if (positionArray.empty()) { + ArrayRef position(positionVec); + if (position.empty()) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } // One-shot insertion of a vector into an array (only requires insertvalue). if (isa(sourceType)) { + if (insertOp.hasDynamicPosition()) + return failure(); + Value inserted = rewriter.create( - loc, adaptor.getDest(), adaptor.getSource(), positionArray); + loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position)); rewriter.replaceOp(insertOp, inserted); return success(); } @@ -1218,24 +1253,28 @@ // Potential extraction of 1-D vector from array. Value extracted = adaptor.getDest(); auto oneDVectorType = destVectorType; - if (positionArray.size() > 1) { + if (position.size() > 1) { + if (insertOp.hasDynamicPosition()) + return failure(); + oneDVectorType = reducedVectorTypeBack(destVectorType); extracted = rewriter.create( - loc, extracted, positionArray.drop_back()); + loc, extracted, getAsIntegers(position.drop_back())); } // Insertion of an element into a 1-D LLVM vector. - auto i64Type = IntegerType::get(rewriter.getContext(), 64); - auto constant = - rewriter.create(loc, i64Type, positionArray.back()); Value inserted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, - adaptor.getSource(), constant); + adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back())); // Potential insertion of resulting 1-D vector into array. 
- if (positionArray.size() > 1) { + if (position.size() > 1) { + if (insertOp.hasDynamicPosition()) + return failure(); + inserted = rewriter.create( - loc, adaptor.getDest(), inserted, positionArray.drop_back()); + loc, adaptor.getDest(), inserted, + getAsIntegers(position.drop_back())); } rewriter.replaceOp(insertOp, inserted); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1063,10 +1063,11 @@ /// If the result of the TransferReadOp has exactly one user, which is a /// vector::InsertOp, return that operation's indices. void getInsertionIndices(TransferReadOp xferOp, - SmallVector &indices) const { - if (auto insertOp = getInsertOp(xferOp)) - indices.assign(insertOp.getPosition().begin(), - insertOp.getPosition().end()); + SmallVectorImpl &indices) const { + if (auto insertOp = getInsertOp(xferOp)) { + auto pos = insertOp.getMixedPosition(); + indices.append(pos.begin(), pos.end()); + } } /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds @@ -1110,9 +1111,9 @@ getXferIndices(b, xferOp, iv, xferIndices); // Indices for the new vector.insert op. - SmallVector insertionIndices; + SmallVector insertionIndices; getInsertionIndices(xferOp, insertionIndices); - insertionIndices.push_back(i); + insertionIndices.push_back(rewriter.getIndexAttr(i)); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto newXferOp = b.create( @@ -1195,10 +1196,11 @@ /// If the input of the given TransferWriteOp is an ExtractOp, return its /// indices. 
void getExtractionIndices(TransferWriteOp xferOp, - SmallVector &indices) const { - if (auto extractOp = getExtractOp(xferOp)) - indices.assign(extractOp.getPosition().begin(), - extractOp.getPosition().end()); + SmallVectorImpl &indices) const { + if (auto extractOp = getExtractOp(xferOp)) { + auto pos = extractOp.getMixedPosition(); + indices.append(pos.begin(), pos.end()); + } } /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds @@ -1235,9 +1237,9 @@ getXferIndices(b, xferOp, iv, xferIndices); // Indices for the new vector.extract op. - SmallVector extractionIndices; + SmallVector extractionIndices; getExtractionIndices(xferOp, extractionIndices); - extractionIndices.push_back(i); + extractionIndices.push_back(b.getI64IntegerAttr(i)); auto extracted = b.create(loc, vec, extractionIndices); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -35,11 +35,25 @@ using namespace mlir; -/// Gets the first integer value from `attr`, assuming it is an integer array -/// attribute. +/// Returns the integer value from the first valid input element, assuming Value +/// inputs are defined by a constant index ops and Attribute inputs are integer +/// attributes. +static uint64_t getFirstIntValue(ValueRange values) { + return values[0].getDefiningOp().value(); +} +static uint64_t getFirstIntValue(ArrayRef attr) { + return cast(attr[0]).getInt(); +} static uint64_t getFirstIntValue(ArrayAttr attr) { return (*attr.getAsValueRange().begin()).getZExtValue(); } +static uint64_t getFirstIntValue(ArrayRef foldResults) { + auto attr = foldResults[0].dyn_cast(); + if (attr) + return getFirstIntValue(attr); + + return getFirstIntValue(ValueRange{foldResults[0].get()}); +} /// Returns the number of bits for the given scalar/vector type. 
static int getNumBits(Type type) { @@ -141,9 +155,7 @@ LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Only support extracting a scalar value now. - VectorType resultVectorType = dyn_cast(extractOp.getType()); - if (resultVectorType && resultVectorType.getNumElements() > 1) + if (extractOp.hasDynamicPosition()) return failure(); Type dstType = getTypeConverter()->convertType(extractOp.getType()); @@ -155,7 +167,7 @@ return success(); } - int32_t id = extractOp.getPosition()[0]; + int32_t id = getFirstIntValue(extractOp.getMixedPosition()); rewriter.replaceOpWithNewOp( extractOp, adaptor.getVector(), id); return success(); @@ -235,7 +247,7 @@ return success(); } - int32_t id = insertOp.getPosition()[0]; + int32_t id = getFirstIntValue(insertOp.getMixedPosition()); rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), id); return success(); diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -516,7 +516,7 @@ return failure(); Value newExtract = rewriter.create( - op.getLoc(), ext->getIn(), op.getPosition()); + op.getLoc(), ext->getIn(), op.getMixedPosition()); ext->recreateAndReplace(rewriter, op, newExtract); return success(); } @@ -645,8 +645,9 @@ vector::InsertOp origInsert, Value narrowValue, Value narrowDest) const override { - return rewriter.create( - origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition()); + return rewriter.create(origInsert.getLoc(), narrowValue, + narrowDest, + origInsert.getMixedPosition()); } }; diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp --- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp +++ 
b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp @@ -74,7 +74,7 @@ if (auto maskOp = extractOp.getVector().getDefiningOp()) return TransferMask{maskOp, - SmallVector(extractOp.getPosition())}; + SmallVector(extractOp.getStaticPosition())}; // All other cases: not supported. return failure(); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -223,6 +223,48 @@ return failure(); } +/// Returns the integer numbers in `values`. `values` are expected to be +/// constant operations. +SmallVector vector::getAsIntegers(ArrayRef values) { + SmallVector ints; + llvm::transform(values, std::back_inserter(ints), [](Value value) { + auto constOp = value.getDefiningOp(); + assert(constOp && "Unexpected non-constant index"); + return constOp.value(); + }); + return ints; +} + +/// Returns the integer numbers in `foldResults`. `foldResults` are expected to +/// be constant operations. +SmallVector vector::getAsIntegers(ArrayRef foldResults) { + SmallVector ints; + llvm::transform( + foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) { + assert(foldResult.is() && "Unexpected non-constant index"); + return cast(foldResult.get()).getInt(); + }); + return ints; +} + +/// Convert `foldResults` into Values. Integer attributes are converted to +/// constant op. 
+SmallVector vector::getAsValues(OpBuilder &builder, Location loc, + ArrayRef foldResults) { + SmallVector values; + llvm::transform(foldResults, std::back_inserter(values), + [&](OpFoldResult foldResult) { + if (auto attr = foldResult.dyn_cast()) + return builder + .create( + loc, cast(attr).getInt()) + .getResult(); + + return foldResult.get(); + }); + return values; +} + //===----------------------------------------------------------------------===// // CombiningKindAttr //===----------------------------------------------------------------------===// @@ -389,12 +431,11 @@ } else { // This means we are reducing all the dimensions, and all reduction // dimensions are of size 1. So a simple extraction would do. - SmallVector zeroAttr(shape.size(), 0); + SmallVector zeroIdx(shape.size(), 0); if (mask) - mask = rewriter.create(loc, rewriter.getI1Type(), - mask, zeroAttr); - cast = rewriter.create( - loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr); + mask = rewriter.create(loc, mask, zeroIdx); + cast = rewriter.create(loc, reductionOp.getSource(), + zeroIdx); } Value result = vector::makeArithReduction( @@ -574,11 +615,9 @@ mask = rewriter.create(loc, mask); result = rewriter.create(loc, reductionOp.getVector()); } else { - if (mask) { - mask = rewriter.create(loc, rewriter.getI1Type(), mask, 0); - } - result = rewriter.create(loc, reductionOp.getType(), - reductionOp.getVector(), 0); + if (mask) + mask = rewriter.create(loc, mask, 0); + result = rewriter.create(loc, reductionOp.getVector(), 0); } if (Value acc = reductionOp.getAcc()) @@ -1148,12 +1187,29 @@ // ExtractOp //===----------------------------------------------------------------------===// -// Convenience builder which assumes the values are constant indices. 
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, - Value source, ValueRange position) { - SmallVector positionConstants = llvm::to_vector(llvm::map_range( - position, [](Value pos) { return getConstantIntValue(pos).value(); })); - build(builder, result, source, positionConstants); + Value source, int64_t position) { + build(builder, result, source, ArrayRef{position}); +} + +void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, + Value source, OpFoldResult position) { + build(builder, result, source, ArrayRef{position}); +} + +void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, + Value source, ArrayRef position) { + build(builder, result, source, /*dynamic_position=*/ArrayRef(), + builder.getDenseI64ArrayAttr(position)); +} + +void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, + Value source, ArrayRef position) { + SmallVector staticPos; + SmallVector dynamicPos; + dispatchIndexOpFoldResults(position, dynamicPos, staticPos); + build(builder, result, source, dynamicPos, + builder.getDenseI64ArrayAttr(staticPos)); } LogicalResult @@ -1161,12 +1217,12 @@ ExtractOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { auto vectorType = llvm::cast(adaptor.getVector().getType()); - if (static_cast(adaptor.getPosition().size()) == + if (static_cast(adaptor.getStaticPosition().size()) == vectorType.getRank()) { inferredReturnTypes.push_back(vectorType.getElementType()); } else { - auto n = - std::min(adaptor.getPosition().size(), vectorType.getRank()); + auto n = std::min(adaptor.getStaticPosition().size(), + vectorType.getRank()); inferredReturnTypes.push_back(VectorType::get( vectorType.getShape().drop_front(n), vectorType.getElementType(), vectorType.getScalableDims().drop_front(n))); @@ -1188,17 +1244,20 @@ } LogicalResult vector::ExtractOp::verify() { - ArrayRef position = getPosition(); + auto position = getMixedPosition(); if (position.size() > 
static_cast(getSourceVectorType().getRank())) return emitOpError( "expected position attribute of rank no greater than vector rank"); - for (const auto &en : llvm::enumerate(position)) { - if (en.value() < 0 || - en.value() >= getSourceVectorType().getDimSize(en.index())) - return emitOpError("expected position attribute #") - << (en.index() + 1) - << " to be a non-negative integer smaller than the corresponding " - "vector dimension"; + for (auto [idx, pos] : llvm::enumerate(position)) { + if (pos.is()) { + int64_t constIdx = cast(pos.get()).getInt(); + if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) { + return emitOpError("expected position attribute #") + << (idx + 1) + << " to be a non-negative integer smaller than the " + "corresponding vector dimension"; + } + } } return success(); } @@ -1216,20 +1275,24 @@ if (!extractOp.getVector().getDefiningOp()) return failure(); - SmallVector globalPosition; + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return failure(); + + SmallVector globalPosition; ExtractOp currentOp = extractOp; - ArrayRef extrPos = currentOp.getPosition(); + ArrayRef extrPos = currentOp.getStaticPosition(); globalPosition.append(extrPos.rbegin(), extrPos.rend()); while (ExtractOp nextOp = currentOp.getVector().getDefiningOp()) { currentOp = nextOp; - ArrayRef extrPos = currentOp.getPosition(); + ArrayRef extrPos = currentOp.getStaticPosition(); globalPosition.append(extrPos.rbegin(), extrPos.rend()); } - extractOp.setOperand(currentOp.getVector()); + extractOp.setOperand(0, currentOp.getVector()); // OpBuilder is only used as a helper to build an I64ArrayAttr. 
OpBuilder b(extractOp.getContext()); std::reverse(globalPosition.begin(), globalPosition.end()); - extractOp.setPosition(globalPosition); + extractOp.setStaticPosition(globalPosition); return success(); } @@ -1335,19 +1398,23 @@ ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState( ExtractOp e) : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()), - extractedRank(extractOp.getPosition().size()) { - assert(vectorRank >= extractedRank && "extracted pos overflow"); + extractedRank(extractOp.getNumIndices()) { + assert(vectorRank >= extractedRank && "Extracted position overflow"); sentinels.reserve(vectorRank - extractedRank); for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i) sentinels.push_back(-(i + 1)); - extractPosition.assign(extractOp.getPosition().begin(), - extractOp.getPosition().end()); + extractPosition.assign(extractOp.getStaticPosition().begin(), + extractOp.getStaticPosition().end()); llvm::append_range(extractPosition, sentinels); } // Case 1. If we hit a transpose, just compose the map and iterate. // Invariant: insert + transpose do not change rank, we can always compose. LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() { + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return failure(); + if (!nextTransposeOp) return failure(); auto permutation = extractVector(nextTransposeOp.getTransp()); @@ -1361,7 +1428,11 @@ LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos( Value &res) { - ArrayRef insertedPos = nextInsertOp.getPosition(); + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition()) + return failure(); + + ArrayRef insertedPos = nextInsertOp.getStaticPosition(); if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank)) return failure(); // Case 2.a. early-exit fold. 
@@ -1375,7 +1446,11 @@ /// This method updates the internal state. LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { - ArrayRef insertedPos = nextInsertOp.getPosition(); + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition()) + return failure(); + + ArrayRef insertedPos = nextInsertOp.getStaticPosition(); if (!isContainedWithin(insertedPos, extractPosition)) return failure(); // Set leading dims to zero. @@ -1395,19 +1470,29 @@ /// internal tranposition in the result). Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace( Value source) { + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return Value(); + // If we can't fold (either internal transposition, or nothing to fold), bail. bool nothingToFold = (source == extractOp.getVector()); if (nothingToFold || !canFold()) return Value(); + // Otherwise, fold by updating the op inplace and return its result. OpBuilder b(extractOp.getContext()); - extractOp.setPosition(ArrayRef(extractPosition).take_front(extractedRank)); + extractOp.setStaticPosition( + ArrayRef(extractPosition).take_front(extractedRank)); extractOp.getVectorMutable().assign(source); return extractOp.getResult(); } /// Iterate over producing insert and transpose ops until we find a fold. Value ExtractFromInsertTransposeChainState::fold() { + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return Value(); + Value valueToExtractFrom = extractOp.getVector(); updateStateForNextIteration(valueToExtractFrom); while (nextInsertOp || nextTransposeOp) { @@ -1431,7 +1516,7 @@ // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel // values. This is a more difficult case and we bail. 
- ArrayRef insertedPos = nextInsertOp.getPosition(); + ArrayRef insertedPos = nextInsertOp.getStaticPosition(); if (isContainedWithin(extractPosition, insertedPos) || intersectsWhereNonNegative(extractPosition, insertedPos)) return Value(); @@ -1457,6 +1542,10 @@ /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. static Value foldExtractFromBroadcast(ExtractOp extractOp) { + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return Value(); + Operation *defOp = extractOp.getVector().getDefiningOp(); if (!defOp || !isa(defOp)) return Value(); @@ -1497,7 +1586,7 @@ // extract position to `0` when extracting from the source operand. llvm::SetVector broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims(); - SmallVector extractPos(extractOp.getPosition()); + SmallVector extractPos(extractOp.getStaticPosition()); int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank; for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i) if (broadcastedUnitDims.contains(i)) @@ -1509,13 +1598,17 @@ std::next(extractPos.begin(), extractPos.size() - rankDiff)); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp.setOperand(source); - extractOp.setPosition(extractPos); + extractOp.setOperand(0, source); + extractOp.setStaticPosition(extractPos); return extractOp.getResult(); } // Fold extractOp with source coming from ShapeCast op. static Value foldExtractFromShapeCast(ExtractOp extractOp) { + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return Value(); + auto shapeCastOp = extractOp.getVector().getDefiningOp(); if (!shapeCastOp) return Value(); @@ -1549,7 +1642,7 @@ } // Extract the strides associated with the extract op vector source. Then use // this to calculate a linearized position for the extract. 
- SmallVector extractedPos(extractOp.getPosition()); + SmallVector extractedPos(extractOp.getStaticPosition()); std::reverse(extractedPos.begin(), extractedPos.end()); SmallVector strides; int64_t stride = 1; @@ -1575,13 +1668,17 @@ SmallVector newPosition = delinearize(position, newStrides); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp.setPosition(newPosition); - extractOp.setOperand(shapeCastOp.getSource()); + extractOp.setStaticPosition(newPosition); + extractOp.setOperand(0, shapeCastOp.getSource()); return extractOp.getResult(); } /// Fold an ExtractOp from ExtractStridedSliceOp. static Value foldExtractFromExtractStrided(ExtractOp extractOp) { + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return Value(); + auto extractStridedSliceOp = extractOp.getVector().getDefiningOp(); if (!extractStridedSliceOp) @@ -1615,19 +1712,25 @@ if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() - sliceOffsets.size()) return Value(); - SmallVector extractedPos(extractOp.getPosition()); + + SmallVector extractedPos(extractOp.getStaticPosition()); assert(extractedPos.size() >= sliceOffsets.size()); for (size_t i = 0, e = sliceOffsets.size(); i < e; i++) extractedPos[i] = extractedPos[i] + sliceOffsets[i]; extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector()); + // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp.setPosition(extractedPos); + extractOp.setStaticPosition(extractedPos); return extractOp.getResult(); } /// Fold extract_op fed from a chain of insertStridedSlice ops. static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) { + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return Value(); + int64_t destinationRank = llvm::isa(extractOp.getType()) ? 
llvm::cast(extractOp.getType()).getRank() @@ -1647,7 +1750,7 @@ if (destinationRank > insertOp.getSourceVectorType().getRank()) return Value(); auto insertOffsets = extractVector(insertOp.getOffsets()); - ArrayRef extractOffsets = extractOp.getPosition(); + ArrayRef extractOffsets = extractOp.getStaticPosition(); if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) { return llvm::cast(attr).getInt() != 1; @@ -1687,7 +1790,7 @@ extractOp.getVectorMutable().assign(insertOp.getSource()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp.setPosition(offsetDiffs); + extractOp.setStaticPosition(offsetDiffs); return extractOp.getResult(); } // If the chunk extracted is disjoint from the chunk inserted, keep @@ -1698,7 +1801,7 @@ } OpFoldResult ExtractOp::fold(FoldAdaptor) { - if (getPosition().empty()) + if (getNumIndices() == 0) return getVector(); if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); @@ -1788,6 +1891,10 @@ LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { + // TODO: Canonicalization for dynamic position not implemented yet. + if (extractOp.hasDynamicPosition()) + return failure(); + // Return if 'ExtractOp' operand is not defined by a compatible vector // ConstantOp. Value sourceVector = extractOp.getVector(); @@ -1807,7 +1914,7 @@ // Calculate the linearized position of the continuous chunk of elements to // extract. 
llvm::SmallVector completePositions(vecTy.getRank(), 0); - copy(extractOp.getPosition(), completePositions.begin()); + copy(extractOp.getStaticPosition(), completePositions.begin()); int64_t elemBeginPosition = linearize(completePositions, computeStrides(vecTy.getShape())); auto denseValuesBegin = dense.value_begin() + elemBeginPosition; @@ -2322,18 +2429,38 @@ // InsertOp //===----------------------------------------------------------------------===// -// Convenience builder which assumes the values are constant indices. -void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, - Value dest, ValueRange position) { - SmallVector positionConstants = - llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { - return getConstantIntValue(pos).value(); - })); - build(builder, result, source, dest, positionConstants); +void vector::InsertOp::build(OpBuilder &builder, OperationState &result, + Value source, Value dest, int64_t position) { + build(builder, result, source, dest, ArrayRef{position}); +} + +void vector::InsertOp::build(OpBuilder &builder, OperationState &result, + Value source, Value dest, OpFoldResult position) { + build(builder, result, source, dest, ArrayRef{position}); +} + +void vector::InsertOp::build(OpBuilder &builder, OperationState &result, + Value source, Value dest, + ArrayRef position) { + SmallVector posVals; + posVals.reserve(position.size()); + llvm::transform(position, std::back_inserter(posVals), + [&](int64_t pos) { return builder.getI64IntegerAttr(pos); }); + build(builder, result, source, dest, posVals); +} + +void vector::InsertOp::build(OpBuilder &builder, OperationState &result, + Value source, Value dest, + ArrayRef position) { + SmallVector staticPos; + SmallVector dynamicPos; + dispatchIndexOpFoldResults(position, dynamicPos, staticPos); + build(builder, result, source, dest, dynamicPos, + builder.getDenseI64ArrayAttr(staticPos)); } LogicalResult InsertOp::verify() { - ArrayRef position = getPosition(); + 
SmallVector position = getMixedPosition(); auto destVectorType = getDestVectorType(); if (position.size() > static_cast(destVectorType.getRank())) return emitOpError( @@ -2348,13 +2475,17 @@ (position.size() != static_cast(destVectorType.getRank()))) return emitOpError( "expected position attribute rank to match the dest vector rank"); - for (const auto &en : llvm::enumerate(position)) { - int64_t attr = en.value(); - if (attr < 0 || attr >= destVectorType.getDimSize(en.index())) - return emitOpError("expected position attribute #") - << (en.index() + 1) - << " to be a non-negative integer smaller than the corresponding " - "dest vector dimension"; + for (auto [idx, pos] : llvm::enumerate(position)) { + if (auto attr = pos.dyn_cast()) { + int64_t constIdx = cast(attr).getInt(); + if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) { + return emitOpError("expected position attribute #") + << (idx + 1) + << " to be a non-negative integer smaller than the " + "corresponding " + "dest vector dimension"; + } + } } return success(); } @@ -2411,6 +2542,10 @@ LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { + // TODO: Canonicalization for dynamic position not implemented yet. + if (op.hasDynamicPosition()) + return failure(); + // Return if 'InsertOp' operand is not defined by a compatible vector // ConstantOp. TypedValue destVector = op.getDest(); @@ -2437,7 +2572,7 @@ // Calculate the linearized position of the continuous chunk of elements to // insert. llvm::SmallVector completePositions(destTy.getRank(), 0); - copy(op.getPosition(), completePositions.begin()); + copy(op.getStaticPosition(), completePositions.begin()); int64_t insertBeginPosition = linearize(completePositions, computeStrides(destTy.getShape())); @@ -2468,7 +2603,7 @@ // value. This happens when the source and destination vectors have identical // sizes. 
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { - if (getPosition().empty()) + if (getNumIndices() == 0) return getSource(); return {}; } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -89,20 +89,20 @@ PatternRewriter &rewriter) { if (index == -1) return val; - Type lowType = type.getRank() > 1 ? VectorType::Builder(type).dropDim(0) - : type.getElementType(); + // At extraction dimension? if (index == 0) - return rewriter.create(loc, lowType, val, pos); + return rewriter.create(loc, val, pos); + // Unroll leading dimensions. - VectorType vType = cast(lowType); + VectorType vType = VectorType::Builder(type).dropDim(0); VectorType resType = VectorType::Builder(type).dropDim(index); Value result = rewriter.create( loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { - Value ext = rewriter.create(loc, vType, val, d); + Value ext = rewriter.create(loc, val, d); Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, resType, load, result, d); + result = rewriter.create(loc, load, result, d); } return result; } @@ -117,16 +117,15 @@ return val; // At insertion dimension? if (index == 0) - return rewriter.create(loc, type, val, result, pos); + return rewriter.create(loc, val, result, pos); + // Unroll leading dimensions. - VectorType lowType = VectorType::Builder(type).dropDim(0); - Type insType = lowType.getRank() > 1 ? 
VectorType::Builder(lowType).dropDim(0) - : lowType.getElementType(); + VectorType vType = VectorType::Builder(type).dropDim(0); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { - Value ext = rewriter.create(loc, lowType, result, d); - Value ins = rewriter.create(loc, insType, val, d); - Value sto = reshapeStore(loc, ins, ext, lowType, index - 1, pos, rewriter); - result = rewriter.create(loc, type, sto, result, d); + Value ext = rewriter.create(loc, result, d); + Value ins = rewriter.create(loc, val, d); + Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); + result = rewriter.create(loc, sto, result, d); } return result; } @@ -1175,7 +1174,7 @@ loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); if (!m.has_value()) return failure(); - result = rewriter.create(loc, resType, *m, result, d); + result = rewriter.create(loc, *m, result, d); } rewriter.replaceOp(rootOp, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -79,7 +79,7 @@ Value val = rewriter.create(loc, arith::CmpIPredicate::slt, bnd, idx); Value sel = rewriter.create(loc, val, trueVal, falseVal); - result = rewriter.create(loc, dstType, sel, result, d); + result = rewriter.create(loc, sel, result, d); } rewriter.replaceOp(op, result); return success(); @@ -151,8 +151,8 @@ Value result = rewriter.create( loc, dstType, rewriter.getZeroAttr(dstType)); for (int64_t d = 0; d < trueDimSize; d++) - result = - rewriter.create(loc, dstType, trueVal, result, d); + result = rewriter.create(loc, trueVal, result, d); + rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ 
b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1040,13 +1040,17 @@ "vector.extract does not support rank 0 sources"); // "vector.extract %v[] : vector<...xf32>" can be canonicalized to %v. - if (extractOp.getPosition().empty()) + if (extractOp.getNumIndices() == 0) return failure(); // Rewrite vector.extract with 1d source to vector.extractelement. if (extractSrcType.getRank() == 1) { - assert(extractOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = extractOp.getPosition()[0]; + if (extractOp.hasDynamicPosition()) + // TODO: Dynamic position not supported yet. + return failure(); + + assert(extractOp.getNumIndices() == 1 && "expected 1 index"); + int64_t pos = extractOp.getStaticPosition()[0]; rewriter.setInsertionPoint(extractOp); rewriter.replaceOpWithNewOp( extractOp, extractOp.getVector(), @@ -1070,7 +1074,7 @@ Value distributedVec = newWarpOp->getResult(newRetIndices[0]); // Extract from distributed vector. Value newExtract = rewriter.create( - loc, distributedVec, extractOp.getPosition()); + loc, distributedVec, extractOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); return success(); @@ -1096,7 +1100,7 @@ SmallVector newDistributedShape(extractSrcType.getShape().begin(), extractSrcType.getShape().end()); for (int i = 0; i < distributedType.getRank(); ++i) - newDistributedShape[i + extractOp.getPosition().size()] = + newDistributedShape[i + extractOp.getNumIndices()] = distributedType.getDimSize(i); auto newDistributedType = VectorType::get(newDistributedShape, distributedType.getElementType()); @@ -1108,7 +1112,7 @@ Value distributedVec = newWarpOp->getResult(newRetIndices[0]); // Extract from distributed vector. 
Value newExtract = rewriter.create( - loc, distributedVec, extractOp.getPosition()); + loc, distributedVec, extractOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); return success(); @@ -1297,13 +1301,17 @@ Location loc = insertOp.getLoc(); // "vector.insert %v, %v[] : ..." can be canonicalized to %v. - if (insertOp.getPosition().empty()) + if (insertOp.getNumIndices() == 0) return failure(); // Rewrite vector.insert with 1d dest to vector.insertelement. if (insertOp.getDestVectorType().getRank() == 1) { - assert(insertOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = insertOp.getPosition()[0]; + if (insertOp.hasDynamicPosition()) + // TODO: Dynamic position not supported yet. + return failure(); + + assert(insertOp.getNumIndices() == 1 && "expected 1 index"); + int64_t pos = insertOp.getStaticPosition()[0]; rewriter.setInsertionPoint(insertOp); rewriter.replaceOpWithNewOp( insertOp, insertOp.getSource(), insertOp.getDest(), @@ -1323,7 +1331,7 @@ Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); Value distributedDest = newWarpOp->getResult(newRetIndices[1]); Value newResult = rewriter.create( - loc, distributedSrc, distributedDest, insertOp.getPosition()); + loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); return success(); @@ -1354,7 +1362,7 @@ // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that // case, one lane will insert the source vector<96xf32>. The other // lanes will not do anything. - int64_t distrSrcDim = distrDestDim - insertOp.getPosition().size(); + int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices(); if (distrSrcDim >= 0) distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim); auto distrSrcType = @@ -1374,11 +1382,12 @@ if (distrSrcDim >= 0) { // Every lane inserts a small piece. 
newResult = rewriter.create( - loc, distributedSrc, distributedDest, insertOp.getPosition()); + loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); } else { // One lane inserts the entire source vector. int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); - SmallVector newPos(insertOp.getPosition()); + SmallVector pos = insertOp.getMixedPosition(); + SmallVector newPos = getAsIntegers(pos); // tid of inserting lane: pos / elementsPerLane Value insertingLane = rewriter.create( loc, newPos[distrDestDim] / elementsPerLane); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -176,14 +177,16 @@ // type has leading unit dims, we also trim the position array accordingly, // then (2) if source type also has leading unit dims, we need to append // zeroes to the position array accordingly. 
- unsigned oldPosRank = insertOp.getPosition().size(); + unsigned oldPosRank = insertOp.getNumIndices(); unsigned newPosRank = std::max(0, oldPosRank - dstDropCount); - SmallVector newPositions = - llvm::to_vector(insertOp.getPosition().take_back(newPosRank)); - newPositions.resize(newDstType.getRank() - newSrcRank, 0); + SmallVector oldPosition = insertOp.getMixedPosition(); + SmallVector newPosition = + llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank)); + newPosition.resize(newDstType.getRank() - newSrcRank, + rewriter.getI64IntegerAttr(0)); auto newInsertOp = rewriter.create( - loc, newDstType, newSrcVector, newDstVector, newPositions); + loc, newSrcVector, newDstVector, newPosition); rewriter.replaceOpWithNewOp(insertOp, oldDstType, newInsertOp); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -707,10 +707,10 @@ auto xferOp = extractOp.getVector().getDefiningOp(); SmallVector newIndices(xferOp.getIndices().begin(), xferOp.getIndices().end()); - for (const auto &it : llvm::enumerate(extractOp.getPosition())) { - int64_t offset = it.value(); - int64_t idx = - newIndices.size() - extractOp.getPosition().size() + it.index(); + for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) { + assert(pos.is() && "Unexpected non-constant index"); + int64_t offset = cast(pos.get()).getInt(); + int64_t idx = newIndices.size() - extractOp.getNumIndices() + i; OpFoldResult ofr = affine::makeComposedFoldedAffineApply( rewriter, extractOp.getLoc(), rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]}); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ 
b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -598,27 +598,34 @@ unsigned expandRatio = castDstType.getNumElements() / castSrcType.getNumElements(); - uint64_t index = extractOp.getPosition()[0]; + auto getFirstIntValue = [](ArrayRef values) -> uint64_t { + assert(values[0].is() && "Unexpected non-constant index"); + return cast(values[0].get()).getInt(); + }; + + uint64_t index = getFirstIntValue(extractOp.getMixedPosition()); // Get the single scalar (as a vector) in the source value that packs the // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> - VectorType oneScalarType = - VectorType::get({1}, castSrcType.getElementType()); + Location loc = extractOp.getLoc(); Value packedValue = rewriter.create( - extractOp.getLoc(), oneScalarType, castOp.getSource(), - index / expandRatio); + loc, castOp.getSource(), index / expandRatio); + Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType()); + Value zero = rewriter.create( + loc, packedVecType, rewriter.getZeroAttr(packedVecType)); + packedValue = rewriter.create(loc, packedValue, zero, + /*position=*/0); // Cast it to a vector with the desired scalar's type. // E.g. f32 -> vector<2xf16> VectorType packedType = VectorType::get({expandRatio}, castDstType.getElementType()); - Value castedValue = rewriter.create( - extractOp.getLoc(), packedType, packedValue); + Value castedValue = + rewriter.create(loc, packedType, packedValue); // Finally extract the desired scalar. 
- rewriter.replaceOpWithNewOp( - extractOp, extractOp.getType(), castedValue, index % expandRatio); - + rewriter.replaceOpWithNewOp(extractOp, castedValue, + index % expandRatio); return success(); } }; diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -728,6 +728,17 @@ // ----- +func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) -> f32 { + %0 = vector.extract %arg0[%arg1]: vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: @extract_element_with_value_1d +// CHECK-SAME: %[[VEC:.+]]: vector<16xf32>, %[[INDEX:.+]]: index +// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64 +// CHECK: llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32> + +// ----- + // CHECK-LABEL: @insert_element_0d // CHECK-SAME: %[[A:.*]]: f32, func.func @insert_element_0d(%a: f32, %b: vector) -> vector { @@ -830,6 +841,19 @@ // ----- +func.func @insert_element_with_value_1d(%arg0: vector<16xf32>, %arg1: f32, %arg2: index) + -> vector<16xf32> { + %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: @insert_element_with_value_1d +// CHECK-SAME: %[[DST:.+]]: vector<16xf32>, %[[SRC:.+]]: f32, %[[INDEX:.+]]: index +// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64 +// CHECK: llvm.insertelement %[[SRC]], %[[DST]][%[[UC]] : i64] : vector<16xf32> + +// ----- + func.func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref> { %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref> return %0 : memref> diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir 
@@ -155,8 +155,8 @@ // CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32> // CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32> func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) { - %0 = "vector.extract"(%arg0) <{position = array}> : (vector<2xf32>) -> vector<1xf32> - %1 = "vector.extract"(%arg0) <{position = array}> : (vector<2xf32>) -> f32 + %0 = "vector.extract"(%arg0) <{static_position = array}> : (vector<2xf32>) -> vector<1xf32> + %1 = "vector.extract"(%arg0) <{static_position = array}> : (vector<2xf32>) -> f32 return %0, %1: vector<1xf32>, f32 } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -133,7 +133,7 @@ func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank no greater than vector rank}} - %1 = "vector.extract" (%arg0) <{position = array}> : (vector<4x8x16xf32>) -> (vector<16xf32>) + %1 = "vector.extract" (%arg0) <{static_position = array}> : (vector<4x8x16xf32>) -> (vector<16xf32>) } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -206,8 +206,9 @@ return %1 : f32 } -// CHECK-LABEL: @extract -func.func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) { +// CHECK-LABEL: @extract_const_idx +func.func @extract_const_idx(%arg0: vector<4x8x16xf32>) + -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) { // CHECK: vector.extract {{.*}}[] : vector<4x8x16xf32> %0 = vector.extract %arg0[] : vector<4x8x16xf32> // CHECK: vector.extract {{.*}}[3] : vector<4x8x16xf32> @@ -219,6 +220,19 @@ return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32 } +// CHECK-LABEL: @extract_val_idx +// CHECK-SAME: 
%[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index +func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index) + -> (vector<8x16xf32>, vector<16xf32>, f32) { + // CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<4x8x16xf32> + %0 = vector.extract %arg0[%idx] : vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]]] : vector<4x8x16xf32> + %1 = vector.extract %arg0[%idx, %idx] : vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], 5, %[[IDX]]] : vector<4x8x16xf32> + %2 = vector.extract %arg0[%idx, 5, %idx] : vector<4x8x16xf32> + return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32 +} + // CHECK-LABEL: @extract_0d func.func @extract_0d(%a: vector) -> f32 { // CHECK-NEXT: vector.extract %{{.*}}[] : vector @@ -242,8 +256,9 @@ return %1 : vector<16xf32> } -// CHECK-LABEL: @insert -func.func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { +// CHECK-LABEL: @insert_const_idx +func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, + %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { // CHECK: vector.insert %{{.*}}, %{{.*}}[3] : vector<8x16xf32> into vector<4x8x16xf32> %1 = vector.insert %c, %res[3] : vector<8x16xf32> into vector<4x8x16xf32> // CHECK: vector.insert %{{.*}}, %{{.*}}[3, 3] : vector<16xf32> into vector<4x8x16xf32> @@ -255,6 +270,19 @@ return %4 : vector<4x8x16xf32> } +// CHECK-LABEL: @insert_val_idx +// CHECK-SAME: %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[IDX:.+]]: index +func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, + %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { + // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]]] : vector<8x16xf32> into vector<4x8x16xf32> + %0 = vector.insert %c, %res[%idx] : vector<8x16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]]] : vector<16xf32> into vector<4x8x16xf32> + %1 = 
vector.insert %b, %res[%idx, %idx] : vector<16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[A]], %{{.*}}[%[[IDX]], 5, %[[IDX]]] : f32 into vector<4x8x16xf32> + %2 = vector.insert %a, %res[%idx, 5, %idx] : f32 into vector<4x8x16xf32> + return %2 : vector<4x8x16xf32> +} + // CHECK-LABEL: @insert_0d func.func @insert_0d(%a: f32, %b: vector, %c: vector<2x3xf32>) -> (vector, vector<2x3xf32>) { // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector @@ -1007,7 +1035,7 @@ %C: vector<3x[8]xf32>, %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> { // CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> - %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } + %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> return %0 : vector<3x[8]xf32> } diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -286,11 +286,13 @@ func.func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) { %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : vector<4xf32> - // CHECK: %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<1xf32> to vector<2xf16> + // CHECK: %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [0] : f32 into vector<1xf32> + // CHECK: %[[CAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<1xf32> to vector<2xf16> // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : vector<2xf16> %1 = 
vector.extract %0[3] : vector<8xf16> // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : vector<4xf32> - // CHECK: %[[CAST2:.+]] = vector.bitcast %[[EXTRACT3]] : vector<1xf32> to vector<2xf16> + // CHECK: %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [0] : f32 into vector<1xf32> + // CHECK: %[[CAST2:.+]] = vector.bitcast %[[INSERT3]] : vector<1xf32> to vector<2xf16> // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : vector<2xf16> %2 = vector.extract %0[4] : vector<8xf16> // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]