Skip to content

Commit 544a16d

Browse files
authored
Support for floating point types (#1307)
* Support for floating point types * Use C++23 types for bfloat16 and half * Implement << op * Use floating-point headers only if the types require them * Print a warning if new C++ types (half and bfloat16) are used * Fix typo
1 parent 1545563 commit 544a16d

File tree

13 files changed

+577
-12
lines changed

13 files changed

+577
-12
lines changed

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@
4040
from hls4ml.model.types import (
4141
ExponentPrecisionType,
4242
FixedPrecisionType,
43+
FloatPrecisionType,
4344
IntegerPrecisionType,
4445
PrecisionType,
4546
RoundingMode,
4647
SaturationMode,
48+
StandardFloatPrecisionType,
4749
UnspecifiedPrecisionType,
4850
XnorPrecisionType,
4951
)
@@ -362,11 +364,22 @@ def convert_precision_string(cls, precision):
362364
if precision.lower() == 'auto':
363365
return cls._convert_auto_type(precision)
364366

367+
if precision in ['float', 'double', 'half', 'bfloat16'] or precision.startswith(
368+
('ap_float', 'ac_std_float', 'std_float')
369+
):
370+
return cls._convert_standard_float_type(precision)
371+
372+
if precision.startswith('ac_float'):
373+
return cls._convert_ac_float_type(precision)
374+
365375
if precision.startswith('ac_'):
366376
return cls._convert_ac_type(precision)
367-
else:
377+
378+
if precision.startswith(('ap_', 'fixed', 'ufixed', 'int', 'uint')): # We parse AP notation even without 'ap_' prefix
368379
return cls._convert_ap_type(precision)
369380

381+
raise ValueError(f'Unsupported precision type: {precision}')
382+
370383
@classmethod
371384
def _convert_ap_type(cls, precision):
372385
'''
@@ -435,6 +448,44 @@ def _convert_ac_type(cls, precision):
435448
elif 'int' in precision:
436449
return IntegerPrecisionType(width, signed)
437450

451+
@classmethod
452+
def _convert_standard_float_type(cls, precision):
453+
# Some default values
454+
if precision == 'float':
455+
return StandardFloatPrecisionType(width=32, exponent=8, use_cpp_type=True)
456+
if precision == 'double':
457+
return StandardFloatPrecisionType(width=64, exponent=11, use_cpp_type=True)
458+
if precision == 'half':
459+
return StandardFloatPrecisionType(width=16, exponent=5, use_cpp_type=True)
460+
if precision == 'bfloat16':
461+
return StandardFloatPrecisionType(width=16, exponent=8, use_cpp_type=True)
462+
463+
# If it is a float type, parse the width and exponent
464+
bits = re.search('.+<(.+?)>', precision).group(1).split(',')
465+
if len(bits) == 2:
466+
width = int(bits[0])
467+
exponent = int(bits[1])
468+
return StandardFloatPrecisionType(width=width, exponent=exponent, use_cpp_type=False)
469+
else:
470+
raise ValueError(f'Invalid standard float precision format: {precision}')
471+
472+
@classmethod
473+
def _convert_ac_float_type(cls, precision):
474+
# If it is a float type, parse the width and exponent
475+
bits = re.search('.+<(.+?)>', precision).group(1).split(',')
476+
if len(bits) == 3 or len(bits) == 4:
477+
mantissa = int(bits[0])
478+
integer = int(bits[1])
479+
exponent = int(bits[2])
480+
width = mantissa + exponent
481+
if len(bits) == 4:
482+
round_mode = RoundingMode.from_string(bits[3])
483+
else:
484+
round_mode = None
485+
return FloatPrecisionType(width=width, integer=integer, exponent=exponent, rounding_mode=round_mode)
486+
else:
487+
raise ValueError(f'Invalid ac_float precision format: {precision}')
488+
438489
@classmethod
439490
def _convert_auto_type(cls, precision):
440491
'''

hls4ml/backends/fpga/fpga_types.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
ExponentPrecisionType,
66
ExponentType,
77
FixedPrecisionType,
8+
FloatPrecisionType,
89
IntegerPrecisionType,
910
NamedType,
1011
PackedType,
12+
StandardFloatPrecisionType,
1113
XnorPrecisionType,
1214
)
1315

@@ -51,6 +53,25 @@ def definition_cpp(self):
5153
return typestring
5254

5355

56+
class APFloatPrecisionDefinition(PrecisionDefinition):
    """Placeholder: ac_float-style types have no AP (Vivado/Vitis HLS) equivalent."""

    def definition_cpp(self):
        # Always rejected — callers should use StandardFloatPrecisionType with
        # the AP converter instead.
        msg = (
            'FloatPrecisionType is not supported in AP type precision definitions. '
            'Use StandardFloatPrecisionType instead.'
        )
        raise NotImplementedError(msg)
61+
62+
63+
class APStandardFloatPrecisionDefinition(PrecisionDefinition):
    """Renders a StandardFloatPrecisionType as its AP/C++ type spelling."""

    def definition_cpp(self):
        # The string form of the precision drives the C++ spelling.
        name = str(self)
        if name.startswith('std_float'):
            return name.replace('std_float', 'ap_float')
        if name == 'half':
            return 'std::float16_t'
        if name == 'bfloat16':
            return 'std::bfloat16_t'
        # 'float'/'double' (and anything else) pass through unchanged.
        return name
73+
74+
5475
class ACIntegerPrecisionDefinition(PrecisionDefinition):
5576
def definition_cpp(self):
5677
typestring = f'ac_int<{self.width}, {str(self.signed).lower()}>'
@@ -94,12 +115,40 @@ def definition_cpp(self):
94115
return typestring
95116

96117

118+
class ACFloatPrecisionDefinition(PrecisionDefinition):
    """Renders a FloatPrecisionType as an ac_float<...> C++ type definition."""

    def _rounding_mode_cpp(self, mode):
        # Map a rounding mode to its ac_datatypes enum name; None stays None.
        return 'AC_' + str(mode) if mode is not None else None

    def definition_cpp(self):
        rounding = self._rounding_mode_cpp(self.rounding_mode)
        if rounding == 'AC_TRN':
            # AC_TRN is the ac_float default, so it is omitted for brevity.
            rounding = None
        template_args = [self.width, self.integer, self.exponent, rounding]
        arg_list = ','.join(str(arg) for arg in template_args if arg is not None)
        return f'ac_float<{arg_list}>'
136+
137+
138+
class ACStandardFloatPrecisionDefinition(PrecisionDefinition):
    """Renders a StandardFloatPrecisionType as its AC C++ type spelling."""

    def definition_cpp(self):
        # Prefix the generic 'std_float<...>' spelling with 'ac_' for the AC
        # libraries; any other spelling (e.g. 'float', 'double') is kept as-is.
        name = str(self)
        return 'ac_' + name if name.startswith('std_float') else name
144+
145+
97146
class PrecisionConverter:
98147
def convert(self, precision_type):
99148
raise NotImplementedError
100149

101150

102-
class FixedPrecisionConverter(PrecisionConverter):
151+
class FPGAPrecisionConverter(PrecisionConverter):
103152
def __init__(self, type_map, prefix):
104153
self.type_map = type_map
105154
self.prefix = prefix
@@ -124,25 +173,29 @@ def convert(self, precision_type):
124173
raise Exception(f'Cannot convert precision type to {self.prefix}: {precision_type.__class__.__name__}')
125174

126175

127-
class APTypeConverter(FPGAPrecisionConverter):
    """Precision converter producing AP (ap_*) C++ type definitions."""

    def __init__(self):
        # NOTE(review): entry order is preserved as-is in case the base
        # converter resolves types by iteration order.
        definition_map = {
            FixedPrecisionType: APFixedPrecisionDefinition,
            IntegerPrecisionType: APIntegerPrecisionDefinition,
            FloatPrecisionType: APFloatPrecisionDefinition,
            StandardFloatPrecisionType: APStandardFloatPrecisionDefinition,
            ExponentPrecisionType: APIntegerPrecisionDefinition,
            XnorPrecisionType: APIntegerPrecisionDefinition,
        }
        super().__init__(type_map=definition_map, prefix='AP')
138189

139190

140-
class ACTypeConverter(FixedPrecisionConverter):
191+
class ACTypeConverter(FPGAPrecisionConverter):
141192
def __init__(self):
142193
super().__init__(
143194
type_map={
144195
FixedPrecisionType: ACFixedPrecisionDefinition,
145196
IntegerPrecisionType: ACIntegerPrecisionDefinition,
197+
FloatPrecisionType: ACFloatPrecisionDefinition,
198+
StandardFloatPrecisionType: ACStandardFloatPrecisionDefinition,
146199
ExponentPrecisionType: ACIntegerPrecisionDefinition,
147200
XnorPrecisionType: ACIntegerPrecisionDefinition,
148201
},

hls4ml/backends/oneapi/oneapi_types.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66

77
from hls4ml.backends.fpga.fpga_types import (
88
ACFixedPrecisionDefinition,
9+
ACFloatPrecisionDefinition,
910
ACIntegerPrecisionDefinition,
10-
FixedPrecisionConverter,
11+
ACStandardFloatPrecisionDefinition,
12+
FloatPrecisionType,
13+
FPGAPrecisionConverter,
1114
HLSTypeConverter,
1215
NamedTypeConverter,
1316
PrecisionDefinition,
17+
StandardFloatPrecisionType,
1418
TypeDefinition,
1519
TypePrecisionConverter,
1620
VariableDefinition,
@@ -35,12 +39,14 @@ def definition_cpp(self):
3539
return typestring
3640

3741

38-
class OneAPIACTypeConverter(FixedPrecisionConverter):
42+
class OneAPIACTypeConverter(FPGAPrecisionConverter):
3943
def __init__(self):
4044
super().__init__(
4145
type_map={
4246
FixedPrecisionType: ACFixedPrecisionDefinition,
4347
IntegerPrecisionType: ACIntegerPrecisionDefinition,
48+
FloatPrecisionType: ACFloatPrecisionDefinition,
49+
StandardFloatPrecisionType: ACStandardFloatPrecisionDefinition,
4450
ExponentPrecisionType: ACExponentPrecisionDefinition,
4551
XnorPrecisionType: ACIntegerPrecisionDefinition,
4652
},

hls4ml/backends/vitis/passes/feature_check.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from hls4ml.model.optimizer import OptimizerPass
2+
from hls4ml.model.types import StandardFloatPrecisionType
23

34

45
class ValidateConvImplementation(OptimizerPass):
@@ -87,3 +88,23 @@ def transform(self, model, node):
8788
f'WARNING: "{node.model.config.config["IOType"]}" IO Type is not supported in Vitis backend '
8889
f'for "{node.name}" ({node.class_name}). Please use "io_parallel".'
8990
)
91+
92+
93+
class ValidateStdCppTypes(OptimizerPass):
    """Warn when a layer relies on C++ standard float types (half/bfloat16).

    These map to std::float16_t / std::bfloat16_t, which work for emulation
    but are not synthesizable with the Vitis backend.
    """

    def match(self, node):
        # Inspect every node; the actual filtering happens in transform().
        return True

    def transform(self, model, node):
        def _uses_cpp_type(prec):
            # Plain 'float'/'double' are fine; only the newer C++ types
            # (std::float16_t / std::bfloat16_t) trigger the warning.
            return (
                isinstance(prec, StandardFloatPrecisionType)
                and prec.use_cpp_type
                and str(prec) not in ('float', 'double')
            )

        layer_precisions = (t.precision for t in node.get_layer_precision().values())
        if any(_uses_cpp_type(prec) for prec in layer_precisions):
            print(
                f'WARNING: Layer "{node.name}" uses C++ types that are not synthesizable with Vitis backend. '
                'Use only for testing purposes.'
            )

hls4ml/backends/vitis/vitis_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def _register_flows(self):
3030
'vitis:validate_resource_unrolled_strategy',
3131
'vitis:validate_bidirectional_merge_mode',
3232
'vitis:validate_bidirectional_io_type',
33+
'vitis:validate_std_cpp_types',
3334
]
3435
validation_flow = register_flow('validation', validation_passes, requires=['vivado:init_layers'], backend=self.name)
3536

0 commit comments

Comments
 (0)