diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1292,20 +1292,38 @@
   };
   unsigned broadcastSrcRank = getRank(source.getType());
   unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank < broadcastSrcRank) {
-    auto extractPos = extractVector<int64_t>(extractOp.getPosition());
-    unsigned rankDiff = broadcastSrcRank - extractResultRank;
-    extractPos.erase(
-        extractPos.begin(),
-        std::next(extractPos.begin(), extractPos.size() - rankDiff));
-    extractOp.setOperand(source);
-    // OpBuilder is only used as a helper to build an I64ArrayAttr.
-    OpBuilder b(extractOp.getContext());
-    extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
-                       b.getI64ArrayAttr(extractPos));
-    return extractOp.getResult();
-  }
-  return Value();
+  // The extract can only be redirected to the broadcast source when its result
+  // has strictly lower rank than that source.
+  if (extractResultRank >= broadcastSrcRank)
+    return Value();
+  // Check that the dimensions of the result haven't been broadcasted: the
+  // result shape must equal the trailing dimensions of the source. E.g. for
+  // vector<1x1xf32> broadcast to vector<1x1x4xf32>, a vector<4xf32> result
+  // does not exist in the source, so the fold must not apply.
+  auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
+  auto broadcastVecType = source.getType().dyn_cast<VectorType>();
+  if (extractVecType && broadcastVecType &&
+      extractVecType.getShape() !=
+          broadcastVecType.getShape().take_back(extractResultRank))
+    return Value();
+  // Keep only the trailing `rankDiff` positions; they address the leading
+  // dimensions of the broadcast source.
+  auto extractPos = extractVector<int64_t>(extractOp.getPosition());
+  unsigned rankDiff = broadcastSrcRank - extractResultRank;
+  extractPos.erase(extractPos.begin(),
+                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
+  // A unit dimension of the source may have been stretched by the broadcast,
+  // in which case the position refers to the broadcast result and would be out
+  // of bounds on the source. The only valid index into a unit dimension is 0.
+  if (broadcastVecType) {
+    ArrayRef<int64_t> srcShape = broadcastVecType.getShape();
+    for (unsigned i = 0; i < rankDiff; ++i)
+      if (srcShape[i] == 1)
+        extractPos[i] = 0;
+  }
+  extractOp.setOperand(source);
+  // OpBuilder is only used as a helper to build an I64ArrayAttr.
+  OpBuilder b(extractOp.getContext());
+  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
+                     b.getI64ArrayAttr(extractPos));
+  return extractOp.getResult();
 }
 
 // Fold extractOp with source coming from ShapeCast op.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -521,6 +521,19 @@
 
 // -----
 
+// Negative test: the extracted vector<4xf32> only exists after the broadcast,
+// so both the broadcast and the extract are expected to remain.
+// CHECK-LABEL: fold_extract_broadcast_negative
+//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
+//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+func @fold_extract_broadcast_negative(%src : vector<1x1xf32>) -> vector<4xf32> {
+  %bcast = vector.broadcast %src : vector<1x1xf32> to vector<1x1x4xf32>
+  %slice = vector.extract %bcast[0, 0] : vector<1x1x4xf32>
+  return %slice : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32