diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1496,6 +1496,7 @@ Operation *defOp = extractOp.getVector().getDefiningOp(); if (!defOp || !isa(defOp)) return failure(); + Value source = defOp->getOperand(0); if (extractOp.getType() == source.getType()) return failure(); @@ -1504,10 +1505,10 @@ }; unsigned broadcastSrcRank = getRank(source.getType()); unsigned extractResultRank = getRank(extractOp.getType()); - // We only consider the case where the rank of the source is smaller than - // the rank of the extract dst. The other cases are handled in the folding - // patterns. - if (extractResultRank <= broadcastSrcRank) + // We only consider the case where the rank of the source is less than or + // equal to the rank of the extract result (identical types were rejected + // above). A lower result rank is handled in the folding patterns. + if (extractResultRank < broadcastSrcRank) return failure(); rewriter.replaceOpWithNewOp( extractOp, extractOp.getType(), source); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -566,6 +566,18 @@ // ----- +// CHECK-LABEL: fold_extract_broadcast +// CHECK-SAME: %[[A:.*]]: vector<1xf32> +// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32> +// CHECK: return %[[R]] : vector<8xf32> +func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> { + %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32> + %r = vector.extract %b[0] : vector<1x8xf32> + return %r : vector<8xf32> +} + +// ----- + // CHECK-LABEL: func @fold_extract_shapecast // CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32> // CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>