
Commit 5da636e

support more general case
1 parent 68de690 commit 5da636e

3 files changed: +49 -20 lines changed

hls4ml/converters/keras_v3/merge.py

Lines changed: 2 additions & 1 deletion
@@ -39,7 +39,8 @@ def handle(
         match cls_name:
             case 'Concatenate':
                 rank = len(output_shape)
-                class_name = f'Concatenate{rank}d'
+                class_name = 'Concatenate'
+                op = f'Concatenate{rank}d'
                 config['axis'] = layer.axis
             case 'Dot':
                 msg = (
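
For orientation, a minimal sketch of what this handler change implies for a rank-2 Concatenate, assuming the handler forwards the local `op` into the layer config under an 'op' key (the shape, the axis value, and the dict-only view of the config are invented for illustration; the real handler fills in more keys):

    # Hypothetical example: class_name stays generic, the rank-specific name moves to 'op'.
    output_shape = (4, 16)      # assumed output shape of the Concatenate layer
    rank = len(output_shape)    # -> 2
    config = {
        'class_name': 'Concatenate',    # was f'Concatenate{rank}d' before this commit
        'op': f'Concatenate{rank}d',    # rank-specific operation name kept separately
        'axis': -1,                     # taken from layer.axis in the real handler
    }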

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 45 additions & 17 deletions

@@ -3,7 +3,7 @@

 import re
 import typing
-from collections.abc import Generator, Sequence
+from collections.abc import Sequence
 from copy import copy
 from functools import reduce, singledispatch
 from math import ceil, log2, prod
@@ -220,8 +220,16 @@ def _produce_kif(layer: Layer) -> KIF_t:

 @_produce_kif.register
 def _(layer: Input):
-    k = np.ones(get_output_shape(layer), dtype=np.int8)
-    i = f = np.full(get_output_shape(layer), 126, dtype=np.int8)
+    shape = get_output_shape(layer)
+    if layer.attributes.get('trusted', False):
+        precision: FixedPrecisionType = layer.get_output_variable().type.precision
+        k, i, f = precision.signed, precision.integer - precision.signed, precision.fractional
+        k = np.full(shape, k, dtype=np.int8)
+        i = np.full(shape, i, dtype=np.int8)
+        f = np.full(shape, f, dtype=np.int8)
+    else:
+        k = np.ones(shape, dtype=np.int8)
+        i = f = np.full(shape, 126, dtype=np.int8)
     return k, i, f


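
To make the new trusted branch concrete, here is a self-contained sketch of how a fixed-point input type expands into per-element (k, i, f) maps; the ap_fixed<16,6>-style numbers and the shape are assumptions for the example, and FixedPrecisionType is replaced by plain integers:

    import numpy as np

    width, integer, signed = 16, 6, True   # assumed input precision: 16 bits, 6 integer, signed
    shape = (8,)                           # assumed input shape

    k = np.full(shape, int(signed), dtype=np.int8)             # sign bit present (1) or not (0)
    i = np.full(shape, integer - int(signed), dtype=np.int8)   # integer bits excluding sign -> 5
    f = np.full(shape, width - integer, dtype=np.int8)         # fractional bits -> 10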
@@ -603,8 +611,8 @@ def kif_arrs_to_ints(arr: tuple[np.ndarray, np.ndarray, np.ndarray]):
     return tuple(int(np.max(a)) for a in arr)


-def produce_kif(layer: Layer) -> KIF_t:
-    if layer.attributes.get('_produce_kif'):
+def produce_kif(layer: Layer, force_reset=False) -> KIF_t:
+    if layer.attributes.get('_produce_kif') and not force_reset:
         return layer.attributes['_produce_kif']
     kif = _produce_kif(layer)
     layer.attributes['_produce_kif'] = kif
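
The new force_reset flag simply bypasses the per-layer cache; a condensed, dict-based sketch of that behaviour (layer attributes reduced to a plain dict, and compute standing in for _produce_kif):

    def produce_kif_sketch(attrs: dict, compute, force_reset: bool = False):
        # Return the cached (k, i, f) unless a recomputation is forced.
        if attrs.get('_produce_kif') and not force_reset:
            return attrs['_produce_kif']
        kif = compute()                  # stands in for _produce_kif(layer)
        attrs['_produce_kif'] = kif      # refresh the cache
        return kif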
@@ -849,7 +857,9 @@ def transform(self, model: 'ModelGraph'):
         for node in model.graph.values():
             if node.attributes.get('bit_exact_transformed'):
                 continue
-            produce_kif(node)  # Shrink FixedPointQuantizer bits when possible to be used in backward flow (requested_kif).
+            produce_kif(
+                node, force_reset=True
+            )  # Shrink FixedPointQuantizer bits when possible to be used in backward flow (requested_kif).

         for node in model.graph.values():
             if node.attributes.get('bit_exact_transformed'):
@@ -858,22 +868,29 @@ def transform(self, model: 'ModelGraph'):
             node.attributes['bit_exact_transformed'] = True

         for node in model.graph.values():
-            if node.attributes.get('_produce_kif'):
+            if '_produce_kif' in node.attributes:
                 del node.attributes['_produce_kif']
-            if node.attributes.get('_request_kif'):
+            if '_request_kif' in node.attributes:
                 del node.attributes['_request_kif']

         return True


-def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
+def get_output_layers_and_quantizers(
+    node: Layer, layers: list | None = None, quantizers: list | None = None
+) -> tuple[list[Layer], list[FixedPointQuantizer]]:
+
+    layers = layers if layers is not None else []
+    quantizers = quantizers if quantizers is not None else []
     for _node in get_output_layers(node):
         if isinstance(_node, FixedPointQuantizer):
-            yield _node
-        elif isinstance(_node, (Reshape, Transpose)):
-            yield from get_output_quantizers(_node)
+            quantizers.append(_node)
+        elif isinstance(_node, (Reshape, Transpose, Concatenate)):
+            layers.append(_node)
+            get_output_layers_and_quantizers(_node, layers, quantizers)
         else:
             raise ValueError(f'Layer {node.name} ({node.class_name}) unexpected input layer chain.')
+    return layers, quantizers


 class FixInputPrecision(OptimizerPass):
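
To illustrate the collection logic above, a toy, self-contained re-implementation over a dict-based graph (node names, the successor map, and the 'passthrough'/'quantizer' kinds are all invented; the real function walks hls4ml Layer objects and treats Reshape, Transpose and Concatenate as pass-through):

    # Each node maps to (kind, successors). Purely illustrative graph.
    GRAPH = {
        'input':    ('input',       ['reshape1', 'q1']),
        'reshape1': ('passthrough', ['concat1']),
        'concat1':  ('passthrough', ['q2']),
        'q1':       ('quantizer',   []),
        'q2':       ('quantizer',   []),
    }

    def collect(node, layers=None, quantizers=None):
        # Recurse through shape-only layers, stop at quantizers, reject anything else.
        layers = [] if layers is None else layers
        quantizers = [] if quantizers is None else quantizers
        for succ in GRAPH[node][1]:
            kind = GRAPH[succ][0]
            if kind == 'quantizer':
                quantizers.append(succ)
            elif kind == 'passthrough':
                layers.append(succ)
                collect(succ, layers, quantizers)
            else:
                raise ValueError(f'unexpected layer kind for {succ}')
        return layers, quantizers

    print(collect('input'))  # (['reshape1', 'concat1'], ['q2', 'q1'])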
@@ -885,17 +902,17 @@ def match(self, node: Layer):
         return node.get_output_variable().type.precision.width > 100

     def transform(self, model, node: Layer):
-        out_layers = list(get_output_quantizers(node))
+        layers, out_quantizers = get_output_layers_and_quantizers(node)

-        if len(out_layers) == 0:  # Input connected to nothing
+        if len(out_quantizers) == 0:  # Input connected to nothing
             new_type = to_hls4ml_fixed(0, 0, 1, f'{node.name}_t')
             node.get_output_variable().type = new_type
             node.model.config.layer_name_precision[node.name] = str(new_type)
             return False

-        sat_modes = [l.SAT for l in out_layers]
+        sat_modes = [l.SAT for l in out_quantizers]
         sat_modes_set = set(sat_modes)
-        rnd_modes = [l.RND for l in out_layers]
+        rnd_modes = [l.RND for l in out_quantizers]
         rnd_modes_set = set(rnd_modes)
         illegal_sat_modes = sat_modes_set - {'WRAP', 'SAT', 'SAT_SYM'}
         illegal_rnd_modes = rnd_modes_set - {'TRN', 'RND'}
@@ -906,7 +923,7 @@ def transform(self, model, node: Layer):
         if illegal_rnd_modes:
             warn(f'Saturation mode {illegal_rnd_modes} may compromise bit-exactness. Forcing at maximum 24 fractional bits.')

-        kifs = [_produce_kif(l) for l in out_layers]
+        kifs = [_produce_kif(l) for l in out_quantizers]
         i = np.max([np.max(i) for _, i, _ in kifs])
         k = np.max([np.max(k) for k, _, _ in kifs])
         if illegal_rnd_modes:
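
A small numeric illustration of the two reductions above (the per-quantizer (k, i, f) arrays are invented for the example):

    import numpy as np

    # Two downstream quantizers, each contributing a (k, i, f) triple of arrays.
    kifs = [
        (np.array([1, 1]), np.array([5, 3]), np.array([10, 8])),
        (np.array([0, 0]), np.array([7, 2]), np.array([4, 4])),
    ]

    i = np.max([np.max(i) for _, i, _ in kifs])  # -> 7, widest integer part requested downstream
    k = np.max([np.max(k) for k, _, _ in kifs])  # -> 1, at least one quantizer needs a sign bit
    print(k, i)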
@@ -921,4 +938,15 @@ def transform(self, model, node: Layer):
         new_type.precision.saturation_mode = 'SAT'
         node.get_output_variable().type = new_type
         node.model.config.layer_name_precision[node.name] = str(new_type)
+        node.attributes['trusted'] = True
+
+        for layer in layers:
+            produce_kif(layer, force_reset=True)
+        for layer in layers:
+            register_precision(layer)
+        for layer in layers:
+            if '_produce_kif' in layer.attributes:
+                del layer.attributes['_produce_kif']
+            if '_request_kif' in layer.attributes:
+                del layer.attributes['_request_kif']
         return False

hls4ml/model/optimizer/passes/hgq_proxy_model.py

Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,7 @@
 import numpy as np

 from hls4ml.model.attributes import Attribute, TypeAttribute, WeightAttribute
-from hls4ml.model.layers import Activation, Layer, Reshape, register_layer
+from hls4ml.model.layers import Activation, Layer, Reshape, Transpose, register_layer
 from hls4ml.model.optimizer import OptimizerPass, register_pass
 from hls4ml.model.types import FixedPrecisionType, UnspecifiedPrecisionType

@@ -97,7 +97,7 @@ def propagate(self, node: Layer, precision: FixedPrecisionType):
         node.attributes['result_t'].precision = precision
         node.attributes['_result_t_propagated'] = True

-        if not isinstance(node, Reshape):
+        if not isinstance(node, (Reshape, Transpose)):
             return node

         inp_layer = get_input_layers(node)[0]
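
For intuition on what the widened isinstance check enables, a toy version of the backward walk: the propagated precision is recorded on the node and, if the node only rearranges data (Reshape/Transpose-like), propagation continues to its producer. The graph, node names and precision string below are invented:

    # Invented single-input graph: node -> producer, plus a coarse layer kind.
    PRODUCER = {'reshape1': 'transpose1', 'transpose1': 'dense1'}
    KIND = {'reshape1': 'shape_only', 'transpose1': 'shape_only', 'dense1': 'compute'}

    def propagate(node, precision, assigned=None):
        assigned = {} if assigned is None else assigned
        assigned[node] = precision        # stands in for setting result_t on the layer
        if KIND[node] != 'shape_only':    # real pass: `if not isinstance(node, (Reshape, Transpose)): return node`
            return assigned               # stop at actual compute layers
        return propagate(PRODUCER[node], precision, assigned)

    print(propagate('reshape1', 'fixed<16,6>'))
    # {'reshape1': 'fixed<16,6>', 'transpose1': 'fixed<16,6>', 'dense1': 'fixed<16,6>'}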
