diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -151,7 +151,7 @@ Value valueMask = rewriter.create( loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); - if (auto vectorType = copySignOp.getType().dyn_cast()) { + if (auto vectorType = type.dyn_cast()) { assert(vectorType.getRank() == 1); int count = vectorType.getNumElements(); intType = VectorType::get(count, intType); diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir @@ -65,3 +65,28 @@ // CHECK-LABEL: func @copy_sign_tensor // CHECK-NEXT: math.copysign {{%.+}}, {{%.+}} : tensor<3x3xf32> // CHECK-NEXT: return +// ----- + +module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { + +func.func @copy_sign_vector_0D(%value: vector<1xf16>, %sign: vector<1xf16>) -> vector<1xf16> { + %0 = math.copysign %value, %sign : vector<1xf16> + return %0: vector<1xf16> +} + +} + +// CHECK-LABEL: func @copy_sign_vector_0D +// CHECK-SAME: (%[[VALUE:.+]]: vector<1xf16>, %[[SIGN:.+]]: vector<1xf16>) +// CHECK: %[[CASTVAL:.+]] = builtin.unrealized_conversion_cast %[[VALUE]] : vector<1xf16> to f16 +// CHECK: %[[CASTSIGN:.+]] = builtin.unrealized_conversion_cast %[[SIGN]] : vector<1xf16> to f16 +// CHECK: %[[SMASK:.+]] = spirv.Constant -32768 : i16 +// CHECK: %[[VMASK:.+]] = spirv.Constant 32767 : i16 +// CHECK: %[[VCAST:.+]] = spirv.Bitcast %[[CASTVAL]] : f16 to i16 +// CHECK: %[[SCAST:.+]] = spirv.Bitcast %[[CASTSIGN]] : f16 to i16 +// CHECK: %[[VAND:.+]] = spirv.BitwiseAnd %[[VCAST]], %[[VMASK]] : i16 +// CHECK: %[[SAND:.+]] = spirv.BitwiseAnd %[[SCAST]], %[[SMASK]] : i16 +// CHECK: %[[OR:.+]] = spirv.BitwiseOr %[[VAND]], %[[SAND]] : i16 +// CHECK: %[[RESULT:.+]] = spirv.Bitcast %[[OR]] : i16 to f16 +// CHECK: %[[CASTRESULT:.+]] = builtin.unrealized_conversion_cast %[[RESULT]] : f16 to vector<1xf16> +// CHECK: return %[[CASTRESULT]]