3
3
4
4
import re
5
5
import typing
6
- from collections .abc import Generator , Sequence
6
+ from collections .abc import Sequence
7
7
from copy import copy
8
8
from functools import reduce , singledispatch
9
9
from math import ceil , log2 , prod
@@ -220,8 +220,16 @@ def _produce_kif(layer: Layer) -> KIF_t:
220
220
221
221
@_produce_kif .register
222
222
def _ (layer : Input ):
223
- k = np .ones (get_output_shape (layer ), dtype = np .int8 )
224
- i = f = np .full (get_output_shape (layer ), 126 , dtype = np .int8 )
223
+ shape = get_output_shape (layer )
224
+ if layer .attributes .get ('trusted' , False ):
225
+ precision : FixedPrecisionType = layer .get_output_variable ().type .precision
226
+ k , i , f = precision .signed , precision .integer - precision .signed , precision .fractional
227
+ k = np .full (shape , k , dtype = np .int8 )
228
+ i = np .full (shape , i , dtype = np .int8 )
229
+ f = np .full (shape , f , dtype = np .int8 )
230
+ else :
231
+ k = np .ones (shape , dtype = np .int8 )
232
+ i = f = np .full (shape , 126 , dtype = np .int8 )
225
233
return k , i , f
226
234
227
235
@@ -603,8 +611,8 @@ def kif_arrs_to_ints(arr: tuple[np.ndarray, np.ndarray, np.ndarray]):
603
611
return tuple (int (np .max (a )) for a in arr )
604
612
605
613
606
- def produce_kif (layer : Layer ) -> KIF_t :
607
- if layer .attributes .get ('_produce_kif' ):
614
+ def produce_kif (layer : Layer , force_reset = False ) -> KIF_t :
615
+ if layer .attributes .get ('_produce_kif' ) and not force_reset :
608
616
return layer .attributes ['_produce_kif' ]
609
617
kif = _produce_kif (layer )
610
618
layer .attributes ['_produce_kif' ] = kif
@@ -849,7 +857,9 @@ def transform(self, model: 'ModelGraph'):
849
857
for node in model .graph .values ():
850
858
if node .attributes .get ('bit_exact_transformed' ):
851
859
continue
852
- produce_kif (node ) # Shrink FixedPointQuantizer bits when possible to be used in backward flow (requested_kif).
860
+ produce_kif (
861
+ node , force_reset = True
862
+ ) # Shrink FixedPointQuantizer bits when possible to be used in backward flow (requested_kif).
853
863
854
864
for node in model .graph .values ():
855
865
if node .attributes .get ('bit_exact_transformed' ):
@@ -858,22 +868,29 @@ def transform(self, model: 'ModelGraph'):
858
868
node .attributes ['bit_exact_transformed' ] = True
859
869
860
870
for node in model .graph .values ():
861
- if node . attributes . get ( '_produce_kif' ) :
871
+ if '_produce_kif' in node . attributes :
862
872
del node .attributes ['_produce_kif' ]
863
- if node . attributes . get ( '_request_kif' ) :
873
+ if '_request_kif' in node . attributes :
864
874
del node .attributes ['_request_kif' ]
865
875
866
876
return True
867
877
868
878
869
- def get_output_quantizers (node : Layer ) -> Generator [FixedPointQuantizer , None , None ]:
879
+ def get_output_layers_and_quantizers (
880
+ node : Layer , layers : list | None = None , quantizers : list | None = None
881
+ ) -> tuple [list [Layer ], list [FixedPointQuantizer ]]:
882
+
883
+ layers = layers if layers is not None else []
884
+ quantizers = quantizers if quantizers is not None else []
870
885
for _node in get_output_layers (node ):
871
886
if isinstance (_node , FixedPointQuantizer ):
872
- yield _node
873
- elif isinstance (_node , (Reshape , Transpose )):
874
- yield from get_output_quantizers (_node )
887
+ quantizers .append (_node )
888
+ elif isinstance (_node , (Reshape , Transpose , Concatenate )):
889
+ layers .append (_node )
890
+ get_output_layers_and_quantizers (_node , layers , quantizers )
875
891
else :
876
892
raise ValueError (f'Layer { node .name } ({ node .class_name } ) unexpected input layer chain.' )
893
+ return layers , quantizers
877
894
878
895
879
896
class FixInputPrecision (OptimizerPass ):
@@ -885,17 +902,17 @@ def match(self, node: Layer):
885
902
return node .get_output_variable ().type .precision .width > 100
886
903
887
904
def transform (self , model , node : Layer ):
888
- out_layers = list ( get_output_quantizers ( node ) )
905
+ layers , out_quantizers = get_output_layers_and_quantizers ( node )
889
906
890
- if len (out_layers ) == 0 : # Input connected to nothing
907
+ if len (out_quantizers ) == 0 : # Input connected to nothing
891
908
new_type = to_hls4ml_fixed (0 , 0 , 1 , f'{ node .name } _t' )
892
909
node .get_output_variable ().type = new_type
893
910
node .model .config .layer_name_precision [node .name ] = str (new_type )
894
911
return False
895
912
896
- sat_modes = [l .SAT for l in out_layers ]
913
+ sat_modes = [l .SAT for l in out_quantizers ]
897
914
sat_modes_set = set (sat_modes )
898
- rnd_modes = [l .RND for l in out_layers ]
915
+ rnd_modes = [l .RND for l in out_quantizers ]
899
916
rnd_modes_set = set (rnd_modes )
900
917
illegal_sat_modes = sat_modes_set - {'WRAP' , 'SAT' , 'SAT_SYM' }
901
918
illegal_rnd_modes = rnd_modes_set - {'TRN' , 'RND' }
@@ -906,7 +923,7 @@ def transform(self, model, node: Layer):
906
923
if illegal_rnd_modes :
907
924
warn (f'Saturation mode { illegal_rnd_modes } may compromise bit-exactness. Forcing at maximum 24 fractional bits.' )
908
925
909
- kifs = [_produce_kif (l ) for l in out_layers ]
926
+ kifs = [_produce_kif (l ) for l in out_quantizers ]
910
927
i = np .max ([np .max (i ) for _ , i , _ in kifs ])
911
928
k = np .max ([np .max (k ) for k , _ , _ in kifs ])
912
929
if illegal_rnd_modes :
@@ -921,4 +938,15 @@ def transform(self, model, node: Layer):
921
938
new_type .precision .saturation_mode = 'SAT'
922
939
node .get_output_variable ().type = new_type
923
940
node .model .config .layer_name_precision [node .name ] = str (new_type )
941
+ node .attributes ['trusted' ] = True
942
+
943
+ for layer in layers :
944
+ produce_kif (layer , force_reset = True )
945
+ for layer in layers :
946
+ register_precision (layer )
947
+ for layer in layers :
948
+ if '_produce_kif' in layer .attributes :
949
+ del layer .attributes ['_produce_kif' ]
950
+ if '_request_kif' in layer .attributes :
951
+ del layer .attributes ['_request_kif' ]
924
952
return False
0 commit comments