diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -321,7 +321,7 @@ ``` }]; let extraClassDeclaration = [{ - VectorType getVectorType() { + VectorType getSourceVectorType() { return getVector().getType().cast(); } }]; @@ -449,7 +449,7 @@ }]; let extraClassDeclaration = [{ Type getSourceType() { return getSource().getType(); } - VectorType getVectorType() { + VectorType getResultVectorType() { return getVector().getType().cast(); } @@ -466,7 +466,7 @@ /// `value`, `dstShape` and `broadcastedDims` must be properly specified or /// the helper will assert. This means: /// 1. `dstShape` must not be empty. - /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)] + /// 2. `broadcastedDims` must be confined to [0 .. dstShape.size()) /// 2. `dstShape` trimmed of the dimensions specified in `broadcastedDims` // must match the `value` shape. static Value createOrFoldBroadcastOp( @@ -537,7 +537,7 @@ VectorType getV2VectorType() { return getV2().getType().cast(); } - VectorType getVectorType() { + VectorType getResultVectorType() { return getVector().getType().cast(); } }]; @@ -584,7 +584,7 @@ OpBuilder<(ins "Value":$source)>, ]; let extraClassDeclaration = [{ - VectorType getVectorType() { + VectorType getSourceVectorType() { return getVector().getType().cast(); } }]; @@ -619,7 +619,7 @@ ]; let extraClassDeclaration = [{ static StringRef getPositionAttrStrName() { return "position"; } - VectorType getVectorType() { + VectorType getSourceVectorType() { return getVector().getType().cast(); } static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); @@ -996,7 +996,7 @@ ?
VectorType() : (*getAcc().begin()).getType().cast(); } - VectorType getVectorType() { + VectorType getResultVectorType() { return getResult().getType().cast(); } static constexpr StringRef getKindAttrStrName() { @@ -1172,7 +1172,9 @@ static StringRef getOffsetsAttrStrName() { return "offsets"; } static StringRef getSizesAttrStrName() { return "sizes"; } static StringRef getStridesAttrStrName() { return "strides"; } - VectorType getVectorType(){ return getVector().getType().cast(); } + VectorType getSourceVectorType() { + return getVector().getType().cast(); + } void getOffsets(SmallVectorImpl &results); bool hasNonUnitStrides() { return llvm::any_of(getStrides(), [](Attribute attr) { @@ -2424,10 +2426,10 @@ OpBuilder<(ins "Value":$vector, "ArrayRef":$transp)> ]; let extraClassDeclaration = [{ - VectorType getVectorType() { + VectorType getSourceVectorType() { return getVector().getType().cast(); } - VectorType getResultType() { + VectorType getResultVectorType() { return getResult().getType().cast(); } void getTransp(SmallVectorImpl &results); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -203,7 +203,7 @@ /// Return true if this is a broadcast from scalar to a 2D vector. static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) { - return broadcastOp.getVectorType().getRank() == 2; + return broadcastOp.getResultVectorType().getRank() == 2; } /// Return true if this integer extend op can be folded into a contract op. @@ -949,7 +949,7 @@ SmallVector sizes; populateFromInt64AttrArray(op.getSizes(), sizes); - ArrayRef warpVectorShape = op.getVectorType().getShape(); + ArrayRef warpVectorShape = op.getSourceVectorType().getShape(); // Compute offset in vector registers. Note that the mma.sync vector registers // are shaped as numberOfFragments x numberOfRegistersPerfFragment. 
The vector @@ -1045,7 +1045,7 @@ assert(broadcastSupportsMMAMatrixType(op)); const char *fragType = inferFragType(op); - auto vecType = op.getVectorType(); + auto vecType = op.getResultVectorType(); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); auto matrix = rewriter.create( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -939,7 +939,7 @@ auto loc = shuffleOp->getLoc(); auto v1Type = shuffleOp.getV1VectorType(); auto v2Type = shuffleOp.getV2VectorType(); - auto vectorType = shuffleOp.getVectorType(); + auto vectorType = shuffleOp.getResultVectorType(); Type llvmType = typeConverter->convertType(vectorType); auto maskArrayAttr = shuffleOp.getMask(); @@ -1002,7 +1002,7 @@ LogicalResult matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto vectorType = extractEltOp.getVectorType(); + auto vectorType = extractEltOp.getSourceVectorType(); auto llvmType = typeConverter->convertType(vectorType.getElementType()); // Bail if result type cannot be lowered. 
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -83,7 +83,8 @@ LogicalResult matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type resultType = getTypeConverter()->convertType(castOp.getVectorType()); + Type resultType = + getTypeConverter()->convertType(castOp.getResultVectorType()); if (!resultType) return failure(); @@ -92,10 +93,10 @@ return success(); } - SmallVector source(castOp.getVectorType().getNumElements(), + SmallVector source(castOp.getResultVectorType().getNumElements(), adaptor.getSource()); rewriter.replaceOpWithNewOp( - castOp, castOp.getVectorType(), source); + castOp, castOp.getResultVectorType(), source); return success(); } }; @@ -405,7 +406,7 @@ LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto oldResultType = shuffleOp.getVectorType(); + auto oldResultType = shuffleOp.getResultVectorType(); if (!spirv::CompositeType::isValid(oldResultType)) return failure(); Type newResultType = getTypeConverter()->convertType(oldResultType); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -416,7 +416,7 @@ LogicalResult ReductionOp::verify() { // Verify for 0-D and 1-D vector. - int64_t rank = getVectorType().getRank(); + int64_t rank = getSourceVectorType().getRank(); if (rank > 1) return emitOpError("unsupported reduction rank: ") << rank; @@ -465,7 +465,7 @@ /// Returns the mask type expected by this operation. 
Type ReductionOp::getExpectedMaskType() { - auto vecType = getVectorType(); + auto vecType = getSourceVectorType(); return vecType.cloneWith(std::nullopt, IntegerType::get(vecType.getContext(), /*width=*/1)); } @@ -515,7 +515,7 @@ } std::optional> ReductionOp::getShapeForUnroll() { - return llvm::to_vector<4>(getVectorType().getShape()); + return llvm::to_vector<4>(getSourceVectorType().getShape()); } namespace { @@ -530,7 +530,7 @@ if (maskableOp.isMasked()) return failure(); - auto vectorType = reductionOp.getVectorType(); + auto vectorType = reductionOp.getSourceVectorType(); if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1) return failure(); @@ -1074,7 +1074,7 @@ } LogicalResult vector::ExtractElementOp::verify() { - VectorType vectorType = getVectorType(); + VectorType vectorType = getSourceVectorType(); if (vectorType.getRank() == 0) { if (getPosition()) return emitOpError("expected position to be empty with 0-D vector"); @@ -1167,13 +1167,14 @@ LogicalResult vector::ExtractOp::verify() { auto positionAttr = getPosition().getValue(); - if (positionAttr.size() > static_cast(getVectorType().getRank())) + if (positionAttr.size() > + static_cast(getSourceVectorType().getRank())) return emitOpError( "expected position attribute of rank smaller than vector rank"); for (const auto &en : llvm::enumerate(positionAttr)) { auto attr = en.value().dyn_cast(); if (!attr || attr.getInt() < 0 || - attr.getInt() >= getVectorType().getDimSize(en.index())) + attr.getInt() >= getSourceVectorType().getDimSize(en.index())) return emitOpError("expected position attribute #") << (en.index() + 1) << " to be a non-negative integer smaller than the corresponding " @@ -1314,7 +1315,7 @@ ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState( ExtractOp e) - : extractOp(e), vectorRank(extractOp.getVectorType().getRank()), + : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()), extractedRank(extractOp.getPosition().size()) { 
assert(vectorRank >= extractedRank && "extracted pos overflow"); sentinels.reserve(vectorRank - extractedRank); @@ -1510,7 +1511,8 @@ int64_t stride = 1; for (int64_t i = 0, e = extractedPos.size(); i < e; i++) { strides.push_back(stride); - stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank); + stride *= + getDimReverse(extractOp.getSourceVectorType(), i + destinationRank); } int64_t position = linearize(extractedPos, strides); @@ -1552,7 +1554,7 @@ size_t lastOffset = sliceOffsets.size() - 1; if (sliceOffsets.back() != 0 || extractStridedSliceOp.getType().getDimSize(lastOffset) != - extractStridedSliceOp.getVectorType().getDimSize(lastOffset)) + extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset)) break; sliceOffsets.pop_back(); } @@ -1561,8 +1563,8 @@ destinationRank = vecType.getRank(); // The dimensions of the result need to be untouched by the // extractStridedSlice op. - if (destinationRank > - extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size()) + if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() - + sliceOffsets.size()) return Value(); auto extractedPos = extractVector(extractOp.getPosition()); assert(extractedPos.size() >= sliceOffsets.size()); @@ -1827,7 +1829,7 @@ if (!srcVectorType) return {}; return ::computeBroadcastedUnitDims(srcVectorType.getShape(), - getVectorType().getShape()); + getResultVectorType().getShape()); } /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the @@ -1973,8 +1975,8 @@ LogicalResult BroadcastOp::verify() { std::pair mismatchingDims; - BroadcastableToResult res = - isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims); + BroadcastableToResult res = isBroadcastableTo( + getSourceType(), getResultVectorType(), &mismatchingDims); if (res == BroadcastableToResult::Success) return success(); if (res == BroadcastableToResult::SourceRankHigher) @@ -1988,11 +1990,11 @@ } OpFoldResult BroadcastOp::fold(FoldAdaptor 
adaptor) { - if (getSourceType() == getVectorType()) + if (getSourceType() == getResultVectorType()) return getSource(); if (!adaptor.getSource()) return {}; - auto vectorType = getVectorType(); + auto vectorType = getResultVectorType(); if (adaptor.getSource().isa()) return DenseElementsAttr::get(vectorType, adaptor.getSource()); if (auto attr = adaptor.getSource().dyn_cast()) @@ -2011,8 +2013,9 @@ auto srcBroadcast = broadcastOp.getSource().getDefiningOp(); if (!srcBroadcast) return failure(); - rewriter.replaceOpWithNewOp( - broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource()); + rewriter.replaceOpWithNewOp(broadcastOp, + broadcastOp.getResultVectorType(), + srcBroadcast.getSource()); return success(); } }; @@ -2035,7 +2038,7 @@ } LogicalResult ShuffleOp::verify() { - VectorType resultType = getVectorType(); + VectorType resultType = getResultVectorType(); VectorType v1Type = getV1VectorType(); VectorType v2Type = getV2VectorType(); // Verify ranks. @@ -2143,7 +2146,7 @@ } } - return DenseElementsAttr::get(getVectorType(), results); + return DenseElementsAttr::get(getResultVectorType(), results); } namespace { @@ -2764,7 +2767,7 @@ Type tRHS = getOperandTypeRHS(); VectorType vLHS = getOperandVectorTypeLHS(), vRHS = tRHS.dyn_cast(), - vACC = getOperandVectorTypeACC(), vRES = getVectorType(); + vACC = getOperandVectorTypeACC(), vRES = getResultVectorType(); if (vLHS.getRank() != 1) return emitOpError("expected 1-d vector for operand #1"); @@ -2805,7 +2808,7 @@ /// Returns the mask type expected by this operation. Mostly used for /// verification purposes. It requires the operation to be vectorized." 
Type OuterProductOp::getExpectedMaskType() { - auto vecType = this->getVectorType(); + auto vecType = this->getResultVectorType(); return VectorType::get(vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1)); } @@ -2913,7 +2916,7 @@ } LogicalResult ExtractStridedSliceOp::verify() { - auto type = getVectorType(); + auto type = getSourceVectorType(); auto offsets = getOffsetsAttr(); auto sizes = getSizesAttr(); auto strides = getStridesAttr(); @@ -2944,8 +2947,8 @@ /*halfOpen=*/false))) return failure(); - auto resultType = - inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides); + auto resultType = inferStridedSliceOpResultType(getSourceVectorType(), + offsets, sizes, strides); if (getResult().getType() != resultType) return emitOpError("expected result type to be ") << resultType; @@ -2966,7 +2969,7 @@ ArrayAttr extractSizes = op.getSizes(); auto insertOp = op.getVector().getDefiningOp(); while (insertOp) { - if (op.getVectorType().getRank() != + if (op.getSourceVectorType().getRank() != insertOp.getSourceVectorType().getRank()) return failure(); ArrayAttr insertOffsets = insertOp.getOffsets(); @@ -3020,7 +3023,7 @@ } OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { - if (getVectorType() == getResult().getType()) + if (getSourceVectorType() == getResult().getType()) return getVector(); if (succeeded(foldExtractStridedOpFromInsertChain(*this))) return getResult(); @@ -5113,7 +5116,7 @@ // Eliminate splat constant transpose ops. if (auto attr = adaptor.getVector().dyn_cast_or_null()) if (attr.isSplat()) - return attr.reshape(getResultType()); + return attr.reshape(getResultVectorType()); // Eliminate identity transpose ops. This happens when the dimensions of the // input vector remain in their original order after the transpose operation. 
@@ -5131,8 +5134,8 @@ } LogicalResult vector::TransposeOp::verify() { - VectorType vectorType = getVectorType(); - VectorType resultType = getResultType(); + VectorType vectorType = getSourceVectorType(); + VectorType resultType = getResultVectorType(); int64_t rank = resultType.getRank(); if (vectorType.getRank() != rank) return emitOpError("vector result rank mismatch: ") << rank; @@ -5156,7 +5159,7 @@ } std::optional> TransposeOp::getShapeForUnroll() { - return llvm::to_vector<4>(getResultType().getShape()); + return llvm::to_vector<4>(getResultVectorType().getShape()); } namespace { @@ -5215,7 +5218,7 @@ auto srcVectorType = bcastOp.getSourceType().dyn_cast(); if (!srcVectorType || srcVectorType.getNumElements() == 1) { rewriter.replaceOpWithNewOp( - transposeOp, transposeOp.getResultType(), bcastOp.getSource()); + transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource()); return success(); } @@ -5235,7 +5238,7 @@ return failure(); rewriter.replaceOpWithNewOp( - transposeOp, transposeOp.getResultType(), splatOp.getInput()); + transposeOp, transposeOp.getResultVectorType(), splatOp.getInput()); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -897,7 +897,7 @@ return failure(); unsigned int operandNumber = operand->getOperandNumber(); auto extractOp = operand->get().getDefiningOp(); - VectorType extractSrcType = extractOp.getVectorType(); + VectorType extractSrcType = extractOp.getSourceVectorType(); Location loc = extractOp.getLoc(); // "vector.extract %v[] : vector" is an invalid op. 
@@ -930,7 +930,7 @@ SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, {extractOp.getVector()}, - {extractOp.getVectorType()}, newRetIndices); + {extractOp.getSourceVectorType()}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value distributedVec = newWarpOp->getResult(newRetIndices[0]); // Extract from distributed vector. @@ -994,7 +994,7 @@ return failure(); unsigned int operandNumber = operand->getOperandNumber(); auto extractOp = operand->get().getDefiningOp(); - VectorType extractSrcType = extractOp.getVectorType(); + VectorType extractSrcType = extractOp.getSourceVectorType(); bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1; Type elType = extractSrcType.getElementType(); VectorType distributedVecType; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -48,7 +48,7 @@ // vector.extract_strided_slice requires the input and output vector to have // the same rank. Here we drop leading one dimensions from the input vector // type to make sure we don't cause mismatch. 
- VectorType oldSrcType = extractOp.getVectorType(); + VectorType oldSrcType = extractOp.getSourceVectorType(); VectorType newSrcType = trimLeadingOneDims(oldSrcType); if (newSrcType.getRank() == oldSrcType.getRank()) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -264,7 +264,7 @@ LogicalResult matchAndRewrite(vector::BroadcastOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - VectorType dstType = op.getVectorType(); + VectorType dstType = op.getResultVectorType(); VectorType srcType = op.getSourceType().dyn_cast(); Type eltType = dstType.getElementType(); @@ -402,8 +402,8 @@ auto loc = op.getLoc(); Value input = op.getVector(); - VectorType inputType = op.getVectorType(); - VectorType resType = op.getResultType(); + VectorType inputType = op.getSourceVectorType(); + VectorType resType = op.getResultVectorType(); // Set up convenience transposition table. SmallVector transp; @@ -490,7 +490,7 @@ PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - VectorType srcType = op.getVectorType(); + VectorType srcType = op.getSourceVectorType(); if (srcType.getRank() != 2) return rewriter.notifyMatchFailure(op, "Not a 2D transpose"); @@ -516,8 +516,8 @@ Value shuffled = rewriter.create(loc, casted, casted, mask); - rewriter.replaceOpWithNewOp(op, op.getResultType(), - shuffled); + rewriter.replaceOpWithNewOp( + op, op.getResultVectorType(), shuffled); return success(); } @@ -550,7 +550,7 @@ VectorType lhsType = op.getOperandVectorTypeLHS(); VectorType rhsType = op.getOperandTypeRHS().dyn_cast(); - VectorType resType = op.getVectorType(); + VectorType resType = op.getResultVectorType(); Type eltType = resType.getElementType(); bool isInt = eltType.isa(); Value acc = (op.getAcc().empty()) ? 
nullptr : op.getAcc()[0]; @@ -1206,15 +1206,16 @@ continue; // contractionOp can only take vector as operands. auto srcType = broadcast.getSourceType().dyn_cast(); - if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank()) + if (!srcType || + srcType.getRank() == broadcast.getResultVectorType().getRank()) continue; int64_t rankDiff = - broadcast.getVectorType().getRank() - srcType.getRank(); + broadcast.getResultVectorType().getRank() - srcType.getRank(); bool innerDimBroadcast = false; SmallVector originalDims; for (const auto &dim : llvm::enumerate(srcType.getShape())) { - if (dim.value() != - broadcast.getVectorType().getDimSize(rankDiff + dim.index())) { + if (dim.value() != broadcast.getResultVectorType().getDimSize( + rankDiff + dim.index())) { innerDimBroadcast = true; break; } @@ -1230,7 +1231,7 @@ // of non-unit size. bool nonUnitDimReductionBroadcast = false; for (int64_t i = 0; i < rankDiff; ++i) { - if (broadcast.getVectorType().getDimSize(i) != 1 && + if (broadcast.getResultVectorType().getDimSize(i) != 1 && isReductionIterator(contractOp.getIteratorTypes() .getValue()[map.getDimPosition(i)])) { nonUnitDimReductionBroadcast = true; @@ -1241,8 +1242,8 @@ continue; AffineMap broadcastMap = - AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims, - contractOp.getContext()); + AffineMap::get(broadcast.getResultVectorType().getRank(), 0, + originalDims, contractOp.getContext()); map = broadcastMap.compose(map); *operand = broadcast.getSource(); changed = true; @@ -1361,7 +1362,7 @@ auto transposeOp = operand.getDefiningOp(); if (transposeOp) { transposeMaps.push_back(transposeOp.getTransp()); - srcType = transposeOp.getVectorType(); + srcType = transposeOp.getSourceVectorType(); } else if (!matchPattern(operand, m_Constant())) { return failure(); } @@ -2374,7 +2375,7 @@ LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { // Only support extracting scalars for now. 
- if (extractOp.getVectorType().getRank() != 1) + if (extractOp.getSourceVectorType().getRank() != 1) return failure(); auto castOp = extractOp.getVector().getDefiningOp(); @@ -2466,7 +2467,7 @@ [](const APInt &val) { return !val.isOneValue(); })) return failure(); - unsigned rank = extractOp.getVectorType().getRank(); + unsigned rank = extractOp.getSourceVectorType().getRank(); assert(castDstLastDim % castSrcLastDim == 0); int64_t expandRatio = castDstLastDim / castSrcLastDim; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -602,12 +602,12 @@ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { - if (transposeOp.getResultType().getRank() == 0) + if (transposeOp.getResultVectorType().getRank() == 0) return failure(); auto targetShape = getTargetShape(options, transposeOp); if (!targetShape) return failure(); - auto originalVectorType = transposeOp.getResultType(); + auto originalVectorType = transposeOp.getResultVectorType(); SmallVector strides(targetShape->size(), 1); Location loc = transposeOp.getLoc(); ArrayRef originalSize = originalVectorType.getShape(); diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp @@ -252,7 +252,7 @@ // Check if the source vector type is supported. AVX2 patterns can only be // applied to f32 vector types with two dimensions greater than one. 
- VectorType srcType = op.getVectorType(); + VectorType srcType = op.getSourceVectorType(); if (!srcType.getElementType().isF32()) return rewriter.notifyMatchFailure(op, "Unsupported vector element type"); @@ -287,7 +287,7 @@ // Reshape the n-D input vector with only two dimensions greater than one // to a 2-D vector. auto flattenedType = - VectorType::get({n * m}, op.getVectorType().getElementType()); + VectorType::get({n * m}, op.getSourceVectorType().getElementType()); auto reshInputType = VectorType::get({m, n}, srcType.getElementType()); auto reshInput = ib.create(flattenedType, op.getVector()); @@ -315,7 +315,7 @@ // We have to transpose their dimensions and retrieve its original rank // (e.g., 1x8x1x4x1). res = ib.create(flattenedType, res); - res = ib.create(op.getResultType(), res); + res = ib.create(op.getResultVectorType(), res); rewriter.replaceOp(op, res); return success(); };