 import re
 import typing
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
 from copy import copy
 from functools import reduce, singledispatch
 from math import ceil, log2, prod
@@ -824,6 +824,16 @@ def transform(self, model: 'ModelGraph'):
         return True


+def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
+    for _node in get_output_layers(node):
+        if isinstance(_node, FixedPointQuantizer):
+            yield _node
+        elif isinstance(_node, (Reshape, Transpose)):
+            yield from get_output_quantizers(_node)
+        else:
+            raise ValueError(f'Layer {node.name} ({node.class_name}) unexpected input layer chain.')
+
+
 class FixInputPrecision(OptimizerPass):
     def match(self, node: Layer):
         if not isinstance(node, Input):
@@ -833,11 +843,7 @@ def match(self, node: Layer):
         return node.get_output_variable().type.precision.width > 100

     def transform(self, model, node: Layer):
-        out_layers: list[FixedPointQuantizer] = get_output_layers(node)  # type: ignore
-        for layer in out_layers:
-            assert isinstance(
-                layer, FixedPointQuantizer
-            ), f'Input {node.name} connected to non-quantizer {layer.name} with non-trivial configuration'
+        out_layers = list(get_output_quantizers(node))

         if len(out_layers) == 0:  # Input connected to nothing
             new_type = to_hls4ml_fixed(0, 0, 1, f'{node.name}_t')
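
For reference, a minimal standalone sketch of the traversal the new helper performs. The Layer, Reshape, Transpose, and FixedPointQuantizer classes and get_output_layers below are stand-ins, not the real hls4ml API; only the recursion mirrors the added get_output_quantizers. Quantizers reached through Reshape/Transpose chains are treated as direct consumers of the Input node; any other downstream layer raises.

# Stand-in classes; a real model graph would provide Layer, Reshape,
# Transpose, FixedPointQuantizer and get_output_layers.
from collections.abc import Generator


class Layer:
    def __init__(self, name, outputs=()):
        self.name = name
        self.class_name = type(self).__name__
        self.outputs = list(outputs)


class Reshape(Layer):
    pass


class Transpose(Layer):
    pass


class FixedPointQuantizer(Layer):
    pass


def get_output_layers(node: Layer) -> list[Layer]:
    return node.outputs


def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
    # Walk every downstream layer; pass through Reshape/Transpose transparently,
    # yield quantizers, and reject any other layer type.
    for _node in get_output_layers(node):
        if isinstance(_node, FixedPointQuantizer):
            yield _node
        elif isinstance(_node, (Reshape, Transpose)):
            yield from get_output_quantizers(_node)
        else:
            raise ValueError(f'Layer {node.name} ({node.class_name}) unexpected input layer chain.')


# Input -> Reshape -> Transpose -> q1, and Input -> q2: both quantizers are found.
q1, q2 = FixedPointQuantizer('q1'), FixedPointQuantizer('q2')
inp = Layer('input', outputs=[Reshape('r', outputs=[Transpose('t', outputs=[q1])]), q2])
print([q.name for q in get_output_quantizers(inp)])  # ['q1', 'q2']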