Restrict the rsqrt polynomial approximation to rank-1 vector operands with
exactly 8 elements: the match now fails unless the shape has a single
dimension equal to 8. The added tests check that math.rsqrt is still present
for vector<8x2xf32> and vector<2x8xf32> under both the CHECK and AVX2 prefixes.

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -943,7 +943,7 @@
                                     PatternRewriter &rewriter) const {
   auto shape = vectorShape(op.operand().getType(), isF32);
   // Only support already-vectorized rsqrt's.
-  if (!shape.hasValue() || (*shape)[0] != 8)
+  if (!shape.hasValue() || shape->size() != 1 || shape->back() != 8)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -425,3 +425,21 @@
   %0 = math.rsqrt %arg0 : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @rsqrt_vector_8x2xf32
+// CHECK: math.rsqrt
+// AVX2-LABEL: func @rsqrt_vector_8x2xf32
+// AVX2: math.rsqrt
+func @rsqrt_vector_8x2xf32(%arg0: vector<8x2xf32>) -> vector<8x2xf32> {
+  %0 = math.rsqrt %arg0 : vector<8x2xf32>
+  return %0 : vector<8x2xf32>
+}
+
+// CHECK-LABEL: func @rsqrt_vector_2x8xf32
+// CHECK: math.rsqrt
+// AVX2-LABEL: func @rsqrt_vector_2x8xf32
+// AVX2: math.rsqrt
+func @rsqrt_vector_2x8xf32(%arg0: vector<2x8xf32>) -> vector<2x8xf32> {
+  %0 = math.rsqrt %arg0 : vector<2x8xf32>
+  return %0 : vector<2x8xf32>
+}