// Patch overview: unified diff in "Index:" style touching four files under
// libclc/. Everything ahead of the first "Index:" header is commentary only
// and is not part of any hunk.
//
// NOTE(review): in this copy the diff lines are joined by single spaces, so
// the physical lines do not correspond to hunk lines. The hunk text below is
// left untouched; any change to an added or removed line would also require
// updating the matching "@@" line counts.
//
// 1. libclc/CMakeLists.txt (hunk -271,11 +271,11)
//    For ARCH "clspv" and "clspv64", build_flags changes from empty to
//    "-Wno-unknown-assumption". Presumably this silences the diagnostic for
//    the assume("clspv_libclc_builtin") attribute introduced in item 4;
//    TODO confirm against the compiler documentation.
//
// 2. libclc/clspv/lib/SOURCES
//    Adds shared/vstore_half.cl and the generic half_exp, half_exp10,
//    half_exp2, half_log, half_log10 and half_log2 sources to the list.
//
// 3. libclc/clspv/lib/math/fma.cl
//    Adds file-static helpers that treat a uint2 as a (hi, lo) pair of
//    32-bit words:
//      u2_set, u2_set_u     build a value from (hi, lo), or from lo with hi = 0
//      u2_mul               full product of two uints: mul_hi() fills hi,
//                           a * b fills lo
//      u2_sll, u2_srl       shifts across both halves; explicit cases for
//                           shift == 0 and shift < 32, otherwise one half is
//                           moved into the other and shifted by (shift - 32)
//      u2_or                ORs a uint into the lo half only
//      u2_and               ANDs both halves
//      u2_add, u2_add_u     addition; the carry into hi is bit 31 of
//                           hadd(a.lo, b.lo)
//      u2_inv               complements both halves, then adds 1
//      u2_clz               clz(hi), plus clz(lo) when hi has 32 leading zeroes
//      u2_eq, u2_zero       equality of both halves / equality with zero
//      u2_gt                greater-than, comparing hi first and lo on a tie
//    The visible parts of fma() are rewritten to call these helpers in place
//    of the open-coded .lo/.hi arithmetic (mantissa setup, product, alignment
//    of c, add/subtract, sign flip, normalisation, rounding, final shift).
//    fma() also gains an early "return c" when the product exponent is 0 and
//    the product mantissa is zero; the FIXME inside the hunk gives the reason
//    the check sits at that point.
//    NOTE(review): u2_sll and u2_srl have no branch for shift >= 64. fma()
//    selects around that case with "exp_diff >= 64" and "-exp_diff >= 64",
//    but cutoff_mask is still computed from abs(exp_diff) without a guard;
//    its value is not read on those paths. Confirm no other caller can pass
//    64 or more.
//
// 4. libclc/generic/include/clc/clcfunc.h
//    _CLC_DEF stays empty for CLC_SPIRV and CLC_SPIRV64. For CLC_CLSPV and
//    CLC_CLSPV64 it becomes noinline plus assume("clspv_libclc_builtin").
//    All remaining targets keep always_inline.
Index: libclc/CMakeLists.txt =================================================================== --- libclc/CMakeLists.txt +++ libclc/CMakeLists.txt @@ -271,11 +271,11 @@ set( spvflags --spirv-max-version=1.1 ) elseif( ${ARCH} STREQUAL "clspv" ) set( t "spir--" ) - set( build_flags ) + set( build_flags "-Wno-unknown-assumption") set( opt_flags -O3 ) elseif( ${ARCH} STREQUAL "clspv64" ) set( t "spir64--" ) - set( build_flags ) + set( build_flags "-Wno-unknown-assumption") set( opt_flags -O3 ) else() set( build_flags ) Index: libclc/clspv/lib/SOURCES =================================================================== --- libclc/clspv/lib/SOURCES +++ libclc/clspv/lib/SOURCES @@ -1,5 +1,6 @@ math/fma.cl math/nextafter.cl +shared/vstore_half.cl subnormal_config.cl ../../generic/lib/geometric/distance.cl ../../generic/lib/geometric/length.cl @@ -45,6 +46,12 @@ ../../generic/lib/math/frexp.cl ../../generic/lib/math/half_cos.cl ../../generic/lib/math/half_divide.cl +../../generic/lib/math/half_exp.cl +../../generic/lib/math/half_exp10.cl +../../generic/lib/math/half_exp2.cl +../../generic/lib/math/half_log.cl +../../generic/lib/math/half_log10.cl +../../generic/lib/math/half_log2.cl ../../generic/lib/math/half_powr.cl ../../generic/lib/math/half_recip.cl ../../generic/lib/math/half_sin.cl Index: libclc/clspv/lib/math/fma.cl =================================================================== --- libclc/clspv/lib/math/fma.cl +++ libclc/clspv/lib/math/fma.cl @@ -34,6 +34,92 @@ uint sign; }; +static uint2 u2_set(uint hi, uint lo) { + uint2 res; + res.lo = lo; + res.hi = hi; + return res; +} + +static uint2 u2_set_u(uint val) { return u2_set(0, val); } + +static uint2 u2_mul(uint a, uint b) { + uint2 res; + res.hi = mul_hi(a, b); + res.lo = a * b; + return res; +} + +static uint2 u2_sll(uint2 val, uint shift) { + if (shift == 0) + return val; + if (shift < 32) { + val.hi <<= shift; + val.hi |= val.lo >> (32 - shift); + val.lo <<= shift; + } else { + val.hi = val.lo << (shift - 
32); + val.lo = 0; + } + return val; +} + +static uint2 u2_srl(uint2 val, uint shift) { + if (shift == 0) + return val; + if (shift < 32) { + val.lo >>= shift; + val.lo |= val.hi << (32 - shift); + val.hi >>= shift; + } else { + val.lo = val.hi >> (shift - 32); + val.hi = 0; + } + return val; +} + +static uint2 u2_or(uint2 a, uint b) { + a.lo |= b; + return a; +} + +static uint2 u2_and(uint2 a, uint2 b) { + a.lo &= b.lo; + a.hi &= b.hi; + return a; +} + +static uint2 u2_add(uint2 a, uint2 b) { + uint carry = (hadd(a.lo, b.lo) >> 31) & 0x1; + a.lo += b.lo; + a.hi += b.hi + carry; + return a; +} + +static uint2 u2_add_u(uint2 a, uint b) { return u2_add(a, u2_set_u(b)); } + +static uint2 u2_inv(uint2 a) { + a.lo = ~a.lo; + a.hi = ~a.hi; + return u2_add_u(a, 1); +} + +static uint u2_clz(uint2 a) { + uint leading_zeroes = clz(a.hi); + if (leading_zeroes == 32) { + leading_zeroes += clz(a.lo); + } + return leading_zeroes; +} + +static bool u2_eq(uint2 a, uint2 b) { return a.lo == b.lo && a.hi == b.hi; } + +static bool u2_zero(uint2 a) { return u2_eq(a, u2_set_u(0)); } + +static bool u2_gt(uint2 a, uint2 b) { + return a.hi > b.hi || (a.hi == b.hi && a.lo > b.lo); +} + _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) { /* special cases */ if (isnan(a) || isnan(b) || isnan(c) || isinf(a) || isinf(b)) { @@ -63,12 +149,9 @@ st_b.exponent = b == .0f ? 0 : ((as_uint(b) & 0x7f800000) >> 23) - 127; st_c.exponent = c == .0f ? 0 : ((as_uint(c) & 0x7f800000) >> 23) - 127; - st_a.mantissa.lo = a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000; - st_b.mantissa.lo = b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000; - st_c.mantissa.lo = c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000; - st_a.mantissa.hi = 0; - st_b.mantissa.hi = 0; - st_c.mantissa.hi = 0; + st_a.mantissa = u2_set_u(a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000); + st_b.mantissa = u2_set_u(b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000); + st_c.mantissa = u2_set_u(c == .0f ? 
0 : (as_uint(c) & 0x7fffff) | 0x800000); st_a.sign = as_uint(a) & 0x80000000; st_b.sign = as_uint(b) & 0x80000000; @@ -81,15 +164,13 @@ // add another bit to detect subtraction underflow struct fp st_mul; st_mul.sign = st_a.sign ^ st_b.sign; - st_mul.mantissa.hi = mul_hi(st_a.mantissa.lo, st_b.mantissa.lo); - st_mul.mantissa.lo = st_a.mantissa.lo * st_b.mantissa.lo; - uint upper_14bits = (st_mul.mantissa.lo >> 18) & 0x3fff; - st_mul.mantissa.lo <<= 14; - st_mul.mantissa.hi <<= 14; - st_mul.mantissa.hi |= upper_14bits; - st_mul.exponent = (st_mul.mantissa.lo != 0 || st_mul.mantissa.hi != 0) - ? st_a.exponent + st_b.exponent - : 0; + st_mul.mantissa = u2_sll(u2_mul(st_a.mantissa.lo, st_b.mantissa.lo), 14); + st_mul.exponent = + !u2_zero(st_mul.mantissa) ? st_a.exponent + st_b.exponent : 0; + + // FIXME: Detecting a == 0 || b == 0 above crashed GCN isel + if (st_mul.exponent == 0 && u2_zero(st_mul.mantissa)) + return c; // Mantissa is 23 fractional bits, shift it the same way as product mantissa #define C_ADJUST 37ul @@ -97,146 +178,80 @@ // both exponents are bias adjusted int exp_diff = st_mul.exponent - st_c.exponent; - uint abs_exp_diff = abs(exp_diff); - st_c.mantissa.hi = (st_c.mantissa.lo << 5); - st_c.mantissa.lo = 0; - uint2 cutoff_bits = (uint2)(0, 0); - uint2 cutoff_mask = (uint2)(0, 0); - if (abs_exp_diff < 32) { - cutoff_mask.lo = (1u << abs(exp_diff)) - 1u; - } else if (abs_exp_diff < 64) { - cutoff_mask.lo = 0xffffffff; - uint remaining = abs_exp_diff - 32; - cutoff_mask.hi = (1u << remaining) - 1u; + st_c.mantissa = u2_sll(st_c.mantissa, C_ADJUST); + uint2 cutoff_bits = u2_set_u(0); + uint2 cutoff_mask = u2_add(u2_sll(u2_set_u(1), abs(exp_diff)), + u2_set(0xffffffff, 0xffffffff)); + if (exp_diff > 0) { + cutoff_bits = + exp_diff >= 64 ? st_c.mantissa : u2_and(st_c.mantissa, cutoff_mask); + st_c.mantissa = + exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_c.mantissa, exp_diff); } else { - cutoff_mask = (uint2)(0, 0); - } - uint2 tmp = (exp_diff > 0) ? 
st_c.mantissa : st_mul.mantissa; - if (abs_exp_diff > 0) { - cutoff_bits = abs_exp_diff >= 64 ? tmp : (tmp & cutoff_mask); - if (abs_exp_diff < 32) { - // shift some of the hi bits into the shifted lo bits. - uint shift_mask = (1u << abs_exp_diff) - 1; - uint upper_saved_bits = tmp.hi & shift_mask; - upper_saved_bits = upper_saved_bits << (32 - abs_exp_diff); - tmp.hi >>= abs_exp_diff; - tmp.lo >>= abs_exp_diff; - tmp.lo |= upper_saved_bits; - } else if (abs_exp_diff < 64) { - tmp.lo = (tmp.hi >> (abs_exp_diff - 32)); - tmp.hi = 0; - } else { - tmp = (uint2)(0, 0); - } + cutoff_bits = -exp_diff >= 64 ? st_mul.mantissa + : u2_and(st_mul.mantissa, cutoff_mask); + st_mul.mantissa = + -exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_mul.mantissa, -exp_diff); } - if (exp_diff > 0) - st_c.mantissa = tmp; - else - st_mul.mantissa = tmp; struct fp st_fma; st_fma.sign = st_mul.sign; st_fma.exponent = max(st_mul.exponent, st_c.exponent); - st_fma.mantissa = (uint2)(0, 0); if (st_c.sign == st_mul.sign) { - uint carry = (hadd(st_mul.mantissa.lo, st_c.mantissa.lo) >> 31) & 0x1; - st_fma.mantissa = st_mul.mantissa + st_c.mantissa; - st_fma.mantissa.hi += carry; + st_fma.mantissa = u2_add(st_mul.mantissa, st_c.mantissa); } else { // cutoff bits borrow one - uint cutoff_borrow = ((cutoff_bits.lo != 0 || cutoff_bits.hi != 0) && - (st_mul.exponent > st_c.exponent)) - ? 1 - : 0; - uint borrow = 0; - if (st_c.mantissa.lo > st_mul.mantissa.lo) { - borrow = 1; - } else if (st_c.mantissa.lo == UINT_MAX && cutoff_borrow == 1) { - borrow = 1; - } else if ((st_c.mantissa.lo + cutoff_borrow) > st_mul.mantissa.lo) { - borrow = 1; - } - - st_fma.mantissa.lo = st_mul.mantissa.lo - st_c.mantissa.lo - cutoff_borrow; - st_fma.mantissa.hi = st_mul.mantissa.hi - st_c.mantissa.hi - borrow; + st_fma.mantissa = + u2_add(u2_add(st_mul.mantissa, u2_inv(st_c.mantissa)), + (!u2_zero(cutoff_bits) && (st_mul.exponent > st_c.exponent) + ? 
u2_set(0xffffffff, 0xffffffff) + : u2_set_u(0))); } // underflow: st_c.sign != st_mul.sign, and magnitude switches the sign - if (st_fma.mantissa.hi > INT_MAX) { - st_fma.mantissa = ~st_fma.mantissa; - uint carry = (hadd(st_fma.mantissa.lo, 1u) >> 31) & 0x1; - st_fma.mantissa.lo += 1; - st_fma.mantissa.hi += carry; - + if (u2_gt(st_fma.mantissa, u2_set(0x7fffffff, 0xffffffff))) { + st_fma.mantissa = u2_inv(st_fma.mantissa); st_fma.sign = st_mul.sign ^ 0x80000000; } // detect overflow/underflow - uint leading_zeroes = clz(st_fma.mantissa.hi); - if (leading_zeroes == 32) { - leading_zeroes += clz(st_fma.mantissa.lo); - } - int overflow_bits = 3 - leading_zeroes; + int overflow_bits = 3 - u2_clz(st_fma.mantissa); // adjust exponent st_fma.exponent += overflow_bits; // handle underflow if (overflow_bits < 0) { - uint shift = -overflow_bits; - if (shift < 32) { - uint shift_mask = (1u << shift) - 1; - uint saved_lo_bits = (st_fma.mantissa.lo >> (32 - shift)) & shift_mask; - st_fma.mantissa.lo <<= shift; - st_fma.mantissa.hi <<= shift; - st_fma.mantissa.hi |= saved_lo_bits; - } else if (shift < 64) { - st_fma.mantissa.hi = (st_fma.mantissa.lo << (64 - shift)); - st_fma.mantissa.lo = 0; - } else { - st_fma.mantissa = (uint2)(0, 0); - } - + st_fma.mantissa = u2_sll(st_fma.mantissa, -overflow_bits); overflow_bits = 0; } // rounding - // overflow_bits is now in the range of [0, 3] making the shift greater than - // 32 bits. - uint2 trunc_mask; - uint trunc_shift = C_ADJUST + overflow_bits - 32; - trunc_mask.hi = (1u << trunc_shift) - 1; - trunc_mask.lo = UINT_MAX; - uint2 trunc_bits = st_fma.mantissa & trunc_mask; - trunc_bits.lo |= (cutoff_bits.hi != 0 || cutoff_bits.lo != 0) ? 
1 : 0; - uint2 last_bit; - last_bit.lo = 0; - last_bit.hi = st_fma.mantissa.hi & (1u << trunc_shift); - uint grs_shift = C_ADJUST - 3 + overflow_bits - 32; - uint2 grs_bits; - grs_bits.lo = 0; - grs_bits.hi = 0x4u << grs_shift; + uint2 trunc_mask = u2_add(u2_sll(u2_set_u(1), C_ADJUST + overflow_bits), + u2_set(0xffffffff, 0xffffffff)); + uint2 trunc_bits = + u2_or(u2_and(st_fma.mantissa, trunc_mask), !u2_zero(cutoff_bits)); + uint2 last_bit = + u2_and(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits)); + uint2 grs_bits = u2_sll(u2_set_u(4), C_ADJUST - 3 + overflow_bits); // round to nearest even - if ((trunc_bits.hi > grs_bits.hi || - (trunc_bits.hi == grs_bits.hi && trunc_bits.lo > grs_bits.lo)) || - (trunc_bits.hi == grs_bits.hi && trunc_bits.lo == grs_bits.lo && - last_bit.hi != 0)) { - uint shift = C_ADJUST + overflow_bits - 32; - st_fma.mantissa.hi += 1u << shift; + if (u2_gt(trunc_bits, grs_bits) || + (u2_eq(trunc_bits, grs_bits) && !u2_zero(last_bit))) { + st_fma.mantissa = + u2_add(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits)); } - // Shift mantissa back to bit 23 - st_fma.mantissa.lo = (st_fma.mantissa.hi >> (C_ADJUST + overflow_bits - 32)); - st_fma.mantissa.hi = 0; + // Shift mantissa back to bit 23 + st_fma.mantissa = u2_srl(st_fma.mantissa, C_ADJUST + overflow_bits); // Detect rounding overflow - if (st_fma.mantissa.lo > 0xffffff) { + if (u2_gt(st_fma.mantissa, u2_set_u(0xffffff))) { ++st_fma.exponent; - st_fma.mantissa.lo >>= 1; + st_fma.mantissa = u2_srl(st_fma.mantissa, 1); } - if (st_fma.mantissa.lo == 0) { + if (u2_zero(st_fma.mantissa)) { return 0.0f; } Index: libclc/generic/include/clc/clcfunc.h =================================================================== --- libclc/generic/include/clc/clcfunc.h +++ libclc/generic/include/clc/clcfunc.h @@ -4,9 +4,11 @@ // avoid inlines for SPIR-V related targets since we'll optimise later in the // chain -#if defined(CLC_SPIRV) || defined(CLC_SPIRV64) || defined(CLC_CLSPV) 
|| \ - defined(CLC_CLSPV64) +#if defined(CLC_SPIRV) || defined(CLC_SPIRV64) #define _CLC_DEF +#elif defined(CLC_CLSPV) || defined(CLC_CLSPV64) +#define _CLC_DEF \ + __attribute__((noinline)) __attribute__((assume("clspv_libclc_builtin"))) #else #define _CLC_DEF __attribute__((always_inline)) #endif