[mlir][Vector] Fold unit dimensions into the current group when verifying shape casts

In the shape-cast dimension-matching loop in VectorOps.cpp, the inner loop
used to stop as soon as the running product dimB reached dimA
(`while (dimB < dimA && j < rankB)`). Any unit entries of `b` that follow a
matched group were therefore left unconsumed. The visible test case
vector<8x1xf32> -> vector<8xf32> presumably relies on this change; it assumes
`a` is the lower-rank shape, but the caller and the final return are outside
this hunk (TODO confirm).

The inner loop now keeps multiplying while dimB <= dimA. Multiplying by 1
never makes dimB exceed dimA, so every unit entry of `b` is absorbed into the
group currently being matched. When a multiplication overshoots
(dimB > dimA), the last step is undone with `dimB /= b[--j]`, which also
rewinds j. The overshooting entry of `b` is then left to start the next group.

ops.mlir gains a vector<8x1xf32> -> vector<8xf32> shape_cast round-trip check
(CHECK-NEXT) on a new %arg2 argument, which is also added to the return.

NOTE(review): in this copy of the patch, every line has been collapsed onto a
single line, so hunk structure and indentation are lost. As written, it will
not apply with `git apply` or `patch`. Regenerate it from the source tree.

NOTE(review): in the ops.mlir hunk, the types `tuple, vector<3x4x2xf32>>` and
`tuple, vector<12x2xf32>>` have unbalanced angle brackets. The first tuple
element type appears to be missing, possibly stripped in transit. Restore it
before applying.

NOTE(review): possible remaining gap. Consider a = [8, 1], b = [2, 4, 1]. The
trailing 1 of `b` is absorbed into the 8-group, and j reaches rankB. The outer
`while (i < rankA && j < rankB)` then exits without visiting a[1]. Whether
this physically valid cast is then rejected depends on the return statement,
which is outside this hunk. TODO confirm, and consider adding a test for it.

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1628,8 +1628,13 @@ while (i < rankA && j < rankB) { int64_t dimA = a[i]; int64_t dimB = 1; - while (dimB < dimA && j < rankB) + while (dimB <= dimA && j < rankB) { dimB *= b[j++]; + if (dimB > dimA) { + dimB /= b[--j]; + break; + } + } if (dimA != dimB) break; ++i; diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -266,8 +266,9 @@ // CHECK-LABEL: @shape_cast func @shape_cast(%arg0 : vector<5x1x3x2xf32>, - %arg1 : tuple, vector<3x4x2xf32>>) - -> (vector<15x2xf32>, tuple, vector<12x2xf32>>) { + %arg1 : tuple, vector<3x4x2xf32>>, + %arg2 : vector<8x1xf32>) + -> (vector<15x2xf32>, tuple, vector<12x2xf32>>, vector<8xf32>) { // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<15x2xf32> %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32> @@ -276,7 +277,10 @@ %1 = vector.shape_cast %arg1 : tuple, vector<3x4x2xf32>> to tuple, vector<12x2xf32>> - return %0, %1 : vector<15x2xf32>, tuple, vector<12x2xf32>> + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x1xf32> to vector<8xf32> + %2 = vector.shape_cast %arg2 : vector<8x1xf32> to vector<8xf32> + + return %0, %1, %2 : vector<15x2xf32>, tuple, vector<12x2xf32>>, vector<8xf32> } // CHECK-LABEL: @vector_fma