
Let’s go into bits and bytes

For nearly 50 years, since the time of Kernighan, Ritchie, and the 1st edition of their C language book, it was common knowledge that a single-precision “float” type is 32 bits wide and a double-precision type is 64 bits wide. On x86 there was also an 80-bit “long double” type with extended precision, and these types covered nearly all the needs of floating-point data processing. However, over the last few years, the arrival of huge neural network models has forced developers to move to the other end of the spectrum and shrink floating-point types as much as possible.
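A quick way to check these sizes from Python is the standard struct module, which reports the size in bytes of each floating-point format it supports. The sketch below also includes the 16-bit half-precision format (struct code 'e') as an example of a smaller type; the 80-bit long double has no struct format code, so it is left out.

```python
import struct

# '=' selects standard sizes, independent of the platform's C compiler.
# 'e' is IEEE 754 half precision, 'f' is single, 'd' is double.
formats = {"float16 (half)": "=e", "float32 (float)": "=f", "float64 (double)": "=d"}

# struct.calcsize returns bytes; multiply by 8 to get bits
sizes = {name: struct.calcsize(fmt) * 8 for name, fmt in formats.items()}
for name, bits in sizes.items():
    print(f"{name}: {bits} bits")
```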
Honestly, I was surprised when I found out that a 4-bit floating-point format exists. How on Earth is that possible? The best way to find out is to test it ourselves. In this article, we will explore the most popular floating-point formats, build a simple neural network, and see how it works.
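With only 16 bit patterns, a 4-bit float can hold just a handful of values. The minimal sketch below enumerates all of them, assuming the E2M1 layout (1 sign bit, 2 exponent bits, 1 mantissa bit, exponent bias 1, subnormals when the exponent field is 0, no infinities or NaNs) used by the OCP Microscaling FP4 format; other 4-bit layouts exist, and the function name decode_e2m1 is hypothetical.

```python
def decode_e2m1(code: int) -> float:
    """Decode a 4-bit code (0..15) assuming the E2M1 layout.

    Assumed bit layout: [sign:1][exponent:2][mantissa:1], exponent bias = 1.
    """
    sign = (code >> 3) & 0b1
    exponent = (code >> 1) & 0b11
    mantissa = code & 0b1
    sign_mult = -1.0 if sign else 1.0
    if exponent == 0:
        # Subnormal: no implicit leading 1, scale is 2 ** (1 - bias) = 1
        return sign_mult * (mantissa / 2)
    # Normal: implicit leading 1, scale is 2 ** (exponent - bias)
    return sign_mult * (1 + mantissa / 2) * 2.0 ** (exponent - 1)

values = [decode_e2m1(c) for c in range(16)]
print(sorted(set(values)))
```

The set contains 15 distinct numbers rather than 16, because +0.0 and -0.0 compare equal; the positive half is 0, 0.5, 1, 1.5, 2, 3, 4, and 6.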
Let’s start.
A “Standard” 32-bit Floating point
Before going into “extreme” formats, let’s recall a standard one. The IEEE 754 standard for floating-point arithmetic was established in 1985 by the Institute of Electrical and Electronics Engineers (IEEE). A number in the 32-bit float type is laid out like this: [ sign: 1 bit | exponent: 8 bits | mantissa: 23 bits ]
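Here, the first bit is the sign, the next 8 bits represent the exponent, and the remaining 23 bits represent the mantissa. For normal numbers, the final value is calculated using the formula:
$$v = (-1)^{s} \times 2^{e - 127} \times \left(1 + \sum_{i=1}^{23} b_{23-i} \, 2^{-i}\right)$$
where s is the sign bit, e is the raw exponent, and b_{22}, ..., b_0 are the mantissa bits.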
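To connect the layout with the formula, the sketch below splits the 32 bits of a float into its three fields using bit masks and then applies the formula directly. The sum over mantissa bits is equivalent to M / 2^23, where M is the mantissa read as an integer. Only normal numbers (raw exponent 1 to 254) are handled; zeros, subnormals, infinities, and NaNs use special exponent values. The names split_float32 and from_fields are illustrative, not from the original code.

```python
import struct

def split_float32(val: float) -> tuple[int, int, int]:
    """Return the (sign, raw exponent, mantissa) fields of val as float32."""
    # '>f' / '>I' use big-endian for both steps, so the bits are reinterpreted as-is
    bits = struct.unpack('>I', struct.pack('>f', val))[0]
    sign = bits >> 31                   # top bit
    exponent_raw = (bits >> 23) & 0xFF  # next 8 bits
    mantissa = bits & 0x7FFFFF          # lowest 23 bits
    return sign, exponent_raw, mantissa

def from_fields(sign: int, exponent_raw: int, mantissa: int) -> float:
    """Apply (-1)^s * 2^(e - 127) * (1 + M / 2^23) for normal numbers only."""
    if not 1 <= exponent_raw <= 254:
        raise ValueError("only normal numbers are handled in this sketch")
    return (-1) ** sign * 2.0 ** (exponent_raw - 127) * (1 + mantissa / 2 ** 23)

fields = split_float32(0.15625)
print(fields)                # (0, 124, 2097152)
print(from_fields(*fields))  # 0.15625
```

For 0.15625, the raw exponent 124 gives a scale of 2^-3, and the mantissa contributes 1 + 0.25 = 1.25, so the value is 0.125 × 1.25.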
This simple helper function lets us print a floating-point value in binary form:
import struct

def print_float32(val: float):
    """ Print Float32 in a binary form """
    m = struct.unpack('I', struct.pack('f', val))[0]
    return format(m, 'b').zfill(32)

print_float32(0.15625)
# > 00111110001000000000000000000000
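The 32-character string is hard to read at a glance. A small variation (the name format_float32_fields is hypothetical, not part of the original code) inserts spaces between the sign, exponent, and mantissa fields. It uses explicit big-endian format codes; the resulting integer matches the native-order version above, because packing and unpacking use the same byte order either way.

```python
import struct

def format_float32_fields(val: float) -> str:
    """Return the float32 bits of val as 'sign exponent mantissa'."""
    bits = struct.unpack('>I', struct.pack('>f', val))[0]
    s = format(bits, '032b')  # zero-padded 32-character binary string
    return f"{s[0]} {s[1:9]} {s[9:]}"

print(format_float32_fields(0.15625))
# 0 01111100 01000000000000000000000
print(format_float32_fields(-2.0))
```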
Let’s also make another helper for the backward conversion, which will be useful later:
def ieee_754_conversion(sign, exponent_raw, mantissa, exp_len=8, mant_len=23):
    """ Convert binary data into the floating point value """
    sign_mult = -1 if sign == 1 else 1
    exponent = exponent_raw - (2 ** (exp_len - 1) - 1)
    mant_mult = 1
    for b in range(mant_len - 1, -1, -1):
        if mantissa & (2 **…