diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1441,11 +1441,27 @@
   return tryToFoldExtractOpInPlace(valueToExtractFrom);
 }
 
+/// Returns true if any operand or result of `op` is a 0-D (rank-0) vector.
+static bool hasZeroDimVectors(Operation *op) {
+  auto hasZeroDimVectorType = [](Type type) -> bool {
+    auto vecType = dyn_cast<VectorType>(type);
+    return vecType && vecType.getRank() == 0;
+  };
+
+  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
+         llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
+}
+
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   Operation *defOp = extractOp.getVector().getDefiningOp();
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
+
+  // 0-D vectors are not supported by this fold; bail out.
+  if (hasZeroDimVectors(extractOp) || hasZeroDimVectors(defOp))
+    return Value();
+
   Value source = defOp->getOperand(0);
   if (extractOp.getType() == source.getType())
     return source;
@@ -1497,6 +1513,11 @@
   auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
   if (!shapeCastOp)
     return Value();
+
+  // 0-D vectors are not supported by this fold; bail out.
+  if (hasZeroDimVectors(extractOp) || hasZeroDimVectors(shapeCastOp))
+    return Value();
+
   // Get the nth dimension size starting from lowest dimension.
   auto getDimReverse = [](VectorType type, int64_t n) {
     return type.getShape().take_back(n + 1).front();
@@ -1559,6 +1580,11 @@
       extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
   if (!extractStridedSliceOp)
     return Value();
+
+  // 0-D vectors are not supported by this fold; bail out.
+  if (hasZeroDimVectors(extractOp) || hasZeroDimVectors(extractStridedSliceOp))
+    return Value();
+
   // Return if 'extractStridedSliceOp' has non-unit strides.
   if (extractStridedSliceOp.hasNonUnitStrides())
     return Value();
@@ -1595,18 +1621,26 @@
 }
 
 /// Fold extract_op fed from a chain of insertStridedSlice ops.
-static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
-  int64_t destinationRank = llvm::isa<VectorType>(op.getType())
-                                ? llvm::cast<VectorType>(op.getType()).getRank()
-                                : 0;
-  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
+static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
+  int64_t destinationRank =
+      llvm::isa<VectorType>(extractOp.getType())
+          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
+          : 0;
+  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
+  if (!insertOp)
+    return Value();
+
+  // 0-D vectors are not supported by this fold; bail out.
+  if (hasZeroDimVectors(extractOp) || hasZeroDimVectors(insertOp))
+    return Value();
+
   while (insertOp) {
     int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                              insertOp.getSourceVectorType().getRank();
     if (destinationRank > insertOp.getSourceVectorType().getRank())
       return Value();
     auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
-    auto extractOffsets = extractVector<int64_t>(op.getPosition());
+    auto extractOffsets = extractVector<int64_t>(extractOp.getPosition());
 
     if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
           return llvm::cast<IntegerAttr>(attr).getInt() != 1;
@@ -1643,12 +1677,12 @@
                                                     insertRankDiff))
           return Value();
       }
-      op.getVectorMutable().assign(insertOp.getSource());
+      extractOp.getVectorMutable().assign(insertOp.getSource());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
-      OpBuilder b(op.getContext());
-      op->setAttr(ExtractOp::getPositionAttrStrName(),
-                  b.getI64ArrayAttr(offsetDiffs));
-      return op.getResult();
+      OpBuilder b(extractOp.getContext());
+      extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
+                         b.getI64ArrayAttr(offsetDiffs));
+      return extractOp.getResult();
     }
     // If the chunk extracted is disjoint from the chunk inserted, keep
     // looking in the insert chain.