diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -216,6 +216,7 @@ if (!vecType) continue; if (maxSize < vecType.getNumElements()) { + maxSize = vecType.getNumElements(); largestShape.assign(vecType.getShape().begin(), vecType.getShape().end()); } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -169,7 +169,7 @@ %11 = mulf %arg5, %8 : f32 %12 = rsqrt %arg5 : f32 %13 = select %7, %arg5, %arg6 : f32 - %14 = subf %arg5, %arg6 : f32 + %14 = subf %arg5, %arg4 : f32 %15 = tanh %arg5 : f32 linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32 @@ -196,7 +196,8 @@ // CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32> // CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32> // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> -// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V1]] : vector<4x256xf32> +// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> +// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> // CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32> // CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> // CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>