@@ -3,7 +3,7 @@
 
 import re
 import typing
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
 from copy import copy
 from functools import reduce, singledispatch
 from math import ceil, log2, prod
@@ -866,6 +866,16 @@ def transform(self, model: 'ModelGraph'):
         return True
 
 
+def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
+    for _node in get_output_layers(node):
+        if isinstance(_node, FixedPointQuantizer):
+            yield _node
+        elif isinstance(_node, (Reshape, Transpose)):
+            yield from get_output_quantizers(_node)
+        else:
+            raise ValueError(f'Layer {node.name} ({node.class_name}) has an unexpected input layer chain.')
+
+
 class FixInputPrecision(OptimizerPass):
     def match(self, node: Layer):
         if not isinstance(node, Input):
@@ -875,11 +885,7 @@ def match(self, node: Layer):
         return node.get_output_variable().type.precision.width > 100
 
     def transform(self, model, node: Layer):
-        out_layers: list[FixedPointQuantizer] = get_output_layers(node)  # type: ignore
-        for layer in out_layers:
-            assert isinstance(
-                layer, FixedPointQuantizer
-            ), f'Input {node.name} connected to non-quantizer {layer.name} with non-trivial configuration'
+        out_layers = list(get_output_quantizers(node))
 
         if len(out_layers) == 0:  # Input connected to nothing
             new_type = to_hls4ml_fixed(0, 0, 1, f'{node.name}_t')
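
For readers following the change, here is a self-contained sketch of what the new helper buys: the old code asserted that every consumer of the `Input` was a `FixedPointQuantizer`, while `get_output_quantizers()` instead looks through value-preserving `Reshape`/`Transpose` layers to find the quantizers behind them. The stub classes and the `outputs` attribute below are hypothetical stand-ins for hls4ml's real `Layer` hierarchy and `get_output_layers()` graph lookup, not its actual API.

```python
# Minimal sketch, not hls4ml's API: stand-in classes showing how the
# recursive generator finds quantizers behind shape-only layers.
from collections.abc import Generator


class Layer:
    def __init__(self, name: str, outputs=None):
        self.name = name
        self.class_name = type(self).__name__
        self.outputs = outputs or []  # hypothetical: consumers stored inline


class FixedPointQuantizer(Layer):
    pass


class Reshape(Layer):
    pass


class Transpose(Layer):
    pass


def get_output_layers(node: Layer) -> list[Layer]:
    # Stand-in for hls4ml's graph lookup; here each layer lists its
    # consumers directly.
    return node.outputs


def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
    # Yield every FixedPointQuantizer reachable from `node`, looking
    # through Reshape/Transpose layers, which do not change values.
    for _node in get_output_layers(node):
        if isinstance(_node, FixedPointQuantizer):
            yield _node
        elif isinstance(_node, (Reshape, Transpose)):
            yield from get_output_quantizers(_node)
        else:
            raise ValueError(f'Layer {node.name} ({node.class_name}) has an unexpected input layer chain.')


# Input -> Reshape -> Transpose -> Quantizer: the quantizer is found
# even though it is not a direct consumer of the input.
q = FixedPointQuantizer('q')
inp = Layer('inp', [Reshape('r', [Transpose('t', [q])])])
assert list(get_output_quantizers(inp)) == [q]
```

Using a generator keeps the recursion flat and lazy; `transform()` simply materializes it with `list(...)`, so the empty-input case (`len(out_layers) == 0`) falls out naturally.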