diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -141,12 +141,20 @@ return failure(); Location loc = countOp.getLoc(); + Value allOneBits = getScalarOrVectorI32Constant(type, -1, rewriter, loc); + Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); Value msb = rewriter.create(loc, adaptor.getOperand()); // We need to subtract from 31 given that the index is from the least // significant bit. - rewriter.replaceOpWithNewOp(countOp, val31, msb); + Value sub = rewriter.create(loc, val31, msb); + // If the integer has all zero bits, GLSL FindUMsb would return -1. So + // theoretically (31 - FindUMsb) should still give the correct result. + // However, certain Vulkan implementations have driver bugs regarding it. + // So handle the corner case explicitly to work around it. 
+ Value cmp = rewriter.create(loc, msb, allOneBits); + rewriter.replaceOpWithNewOp(countOp, cmp, val32, sub); return success(); } }; diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -96,10 +96,14 @@ // CHECK-LABEL: @ctlz_scalar // CHECK-SAME: (%[[VAL:.+]]: i32) func.func @ctlz_scalar(%val: i32) -> i32 { - // CHECK: %[[V31:.+]] = spv.Constant 31 : i32 + // CHECK-DAG: %[[MAX:.+]] = spv.Constant -1 : i32 + // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32 + // CHECK-DAG: %[[V31:.+]] = spv.Constant 31 : i32 // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32 // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32 - // CHECK: return %[[SUB]] + // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : i32 + // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : i1, i32 + // CHECK: return %[[R]] %0 = math.ctlz %val : i32 return %0 : i32 } @@ -108,6 +112,8 @@ func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> { // CHECK: spv.GLSL.FindUMsb // CHECK: spv.ISub + // CHECK: spv.IEqual + // CHECK: spv.Select %0 = math.ctlz %val : vector<1xi32> return %0 : vector<1xi32> } @@ -115,10 +121,14 @@ // CHECK-LABEL: @ctlz_vector2 // CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>) func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> { + // CHECK-DAG: %[[MAX:.+]] = spv.Constant dense<-1> : vector<2xi32> + // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32> // CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32> // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32> // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32> - // CHECK: return %[[SUB]] + // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : vector<2xi32> + // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : vector<2xi1>, 
vector<2xi32> + // CHECK: return %[[R]] %0 = math.ctlz %val : vector<2xi32> return %0 : vector<2xi32> }