[mlir][vector] Bail out of vector.extract folds on 0-D vectors

Adds a hasZeroDimVectors() helper and makes foldExtractFromBroadcast,
foldExtractFromShapeCast, foldExtractFromExtractStrided and
foldExtractStridedOpFromInsertChain return an empty Value() when the
producing op has a 0-D vector operand or result. The ExtractOp itself
is asserted to have no 0-D vectors. foldExtractStridedOpFromInsertChain
also gains an early null check on the defining InsertStridedSliceOp and
renames its `op` parameter to `extractOp`.

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1441,11 +1441,28 @@
   return tryToFoldExtractOpInPlace(valueToExtractFrom);
 }
 
+/// Returns true if the operation has a 0-D vector type operand or result.
+static bool hasZeroDimVectors(Operation *op) {
+  auto hasZeroDimVectorType = [](Type type) -> bool {
+    auto vecType = dyn_cast<VectorType>(type);
+    return vecType && vecType.getRank() == 0;
+  };
+
+  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
+         llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
+}
+
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   Operation *defOp = extractOp.getVector().getDefiningOp();
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
+
+  // 0-D vectors not supported.
+  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+  if (hasZeroDimVectors(defOp))
+    return Value();
+
   Value source = defOp->getOperand(0);
   if (extractOp.getType() == source.getType())
     return source;
@@ -1497,6 +1514,12 @@
   auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
   if (!shapeCastOp)
     return Value();
+
+  // 0-D vectors not supported.
+  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+  if (hasZeroDimVectors(shapeCastOp))
+    return Value();
+
   // Get the nth dimension size starting from lowest dimension.
   auto getDimReverse = [](VectorType type, int64_t n) {
     return type.getShape().take_back(n + 1).front();
@@ -1559,6 +1582,12 @@
       extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
   if (!extractStridedSliceOp)
     return Value();
+
+  // 0-D vectors not supported.
+  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+  if (hasZeroDimVectors(extractStridedSliceOp))
+    return Value();
+
   // Return if 'extractStridedSliceOp' has non-unit strides.
   if (extractStridedSliceOp.hasNonUnitStrides())
     return Value();
@@ -1595,18 +1624,27 @@
 }
 
 /// Fold extract_op fed from a chain of insertStridedSlice ops.
-static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
-  int64_t destinationRank = llvm::isa<VectorType>(op.getType())
-                                ? llvm::cast<VectorType>(op.getType()).getRank()
-                                : 0;
-  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
+static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
+  int64_t destinationRank =
+      llvm::isa<VectorType>(extractOp.getType())
+          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
+          : 0;
+  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
+  if (!insertOp)
+    return Value();
+
+  // 0-D vectors not supported.
+  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+  if (hasZeroDimVectors(insertOp))
+    return Value();
+
   while (insertOp) {
     int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                              insertOp.getSourceVectorType().getRank();
     if (destinationRank > insertOp.getSourceVectorType().getRank())
       return Value();
     auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
-    auto extractOffsets = extractVector<int64_t>(op.getPosition());
+    auto extractOffsets = extractVector<int64_t>(extractOp.getPosition());
 
     if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
           return llvm::cast<IntegerAttr>(attr).getInt() != 1;
@@ -1643,12 +1681,12 @@
                                                     insertRankDiff))
           return Value();
       }
-      op.getVectorMutable().assign(insertOp.getSource());
+      extractOp.getVectorMutable().assign(insertOp.getSource());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
-      OpBuilder b(op.getContext());
-      op->setAttr(ExtractOp::getPositionAttrStrName(),
-                  b.getI64ArrayAttr(offsetDiffs));
-      return op.getResult();
+      OpBuilder b(extractOp.getContext());
+      extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
+                         b.getI64ArrayAttr(offsetDiffs));
+      return extractOp.getResult();
     }
     // If the chunk extracted is disjoint from the chunk inserted, keep
     // looking in the insert chain.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -650,8 +650,7 @@
 // CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
 // CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
 // CHECK: return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
-                                  %arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> {
+func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
   %r = vector.extract %0[1] : vector<2x4x2xf32>
   return %r : vector<4x2xf32>
@@ -659,6 +658,18 @@
 
 // -----
 
+// CHECK-LABEL: dont_fold_0d_extract_shapecast
+// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<f32> to vector<1xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: return %[[R]] : f32
+func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
+  %0 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+  %r = vector.extract %0[0] : vector<1xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: dont_fold_expand_collapse
 // CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
 // CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
@@ -2159,4 +2170,3 @@
   %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
-