diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -152,17 +152,21 @@
         loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
 
     if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
-      assert(vectorType.getRank() == 1);
-      int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
-
-      SmallVector<Value> signSplat(count, signMask);
-      signMask =
-          rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
-
-      SmallVector<Value> valueSplat(count, valueMask);
-      valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
-                                                               valueSplat);
+      // The SPIR-V type converter turns single-element vectors (e.g.
+      // vector<1xf16>) into scalars, so the masks must be scalar too.
+      if (vectorType.getNumElements() != 1) {
+        assert(vectorType.getRank() == 1);
+        int count = vectorType.getNumElements();
+        intType = VectorType::get(count, intType);
+
+        SmallVector<Value> signSplat(count, signMask);
+        signMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
+                                                                signSplat);
+
+        SmallVector<Value> valueSplat(count, valueMask);
+        valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
+                                                                 valueSplat);
+      }
     }
 
     Value lhsCast =
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
--- a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
@@ -65,3 +65,29 @@
 // CHECK-LABEL: func @copy_sign_tensor
 // CHECK-NEXT: math.copysign {{%.+}}, {{%.+}} : tensor<3x3xf32>
 // CHECK-NEXT: return
+
+// -----
+
+module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16, Int16], []>, #spirv.resource_limits<>> } {
+
+func.func @copy_sign_single_element_vector(%value: vector<1xf16>, %sign: vector<1xf16>) -> vector<1xf16> {
+  %0 = math.copysign %value, %sign : vector<1xf16>
+  return %0: vector<1xf16>
+}
+
+}
+
+// CHECK-LABEL: func @copy_sign_single_element_vector
+// CHECK-SAME: (%[[VALUE:.+]]: vector<1xf16>, %[[SIGN:.+]]: vector<1xf16>)
+// CHECK: %[[CASTVAL:.+]] = builtin.unrealized_conversion_cast %[[VALUE]] : vector<1xf16> to f16
+// CHECK: %[[CASTSIGN:.+]] = builtin.unrealized_conversion_cast %[[SIGN]] : vector<1xf16> to f16
+// CHECK: %[[SMASK:.+]] = spirv.Constant -32768 : i16
+// CHECK: %[[VMASK:.+]] = spirv.Constant 32767 : i16
+// CHECK: %[[VCAST:.+]] = spirv.Bitcast %[[CASTVAL]] : f16 to i16
+// CHECK: %[[SCAST:.+]] = spirv.Bitcast %[[CASTSIGN]] : f16 to i16
+// CHECK: %[[VAND:.+]] = spirv.BitwiseAnd %[[VCAST]], %[[VMASK]] : i16
+// CHECK: %[[SAND:.+]] = spirv.BitwiseAnd %[[SCAST]], %[[SMASK]] : i16
+// CHECK: %[[OR:.+]] = spirv.BitwiseOr %[[VAND]], %[[SAND]] : i16
+// CHECK: %[[RESULT:.+]] = spirv.Bitcast %[[OR]] : i16 to f16
+// CHECK: %[[CASTRESULT:.+]] = builtin.unrealized_conversion_cast %[[RESULT]] : f16 to vector<1xf16>
+// CHECK: return %[[CASTRESULT]]