
Commit 68de690 (1 parent: a5383c0)

allow transpose after inputs

1 file changed: +12 -6 lines

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 12 additions & 6 deletions
@@ -3,7 +3,7 @@

 import re
 import typing
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
 from copy import copy
 from functools import reduce, singledispatch
 from math import ceil, log2, prod
@@ -866,6 +866,16 @@ def transform(self, model: 'ModelGraph'):
         return True


+def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
+    for _node in get_output_layers(node):
+        if isinstance(_node, FixedPointQuantizer):
+            yield _node
+        elif isinstance(_node, (Reshape, Transpose)):
+            yield from get_output_quantizers(_node)
+        else:
+            raise ValueError(f'Layer {node.name} ({node.class_name}) unexpected input layer chain.')
+
+
 class FixInputPrecision(OptimizerPass):
     def match(self, node: Layer):
         if not isinstance(node, Input):
@@ -875,11 +885,7 @@ def match(self, node: Layer):
         return node.get_output_variable().type.precision.width > 100

     def transform(self, model, node: Layer):
-        out_layers: list[FixedPointQuantizer] = get_output_layers(node)  # type: ignore
-        for layer in out_layers:
-            assert isinstance(
-                layer, FixedPointQuantizer
-            ), f'Input {node.name} connected to non-quantizer {layer.name} with non-trivial configuration'
+        out_layers = list(get_output_quantizers(node))

         if len(out_layers) == 0:  # Input connected to nothing
             new_type = to_hls4ml_fixed(0, 0, 1, f'{node.name}_t')
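What the change does: instead of asserting that every consumer of an Input is a FixedPointQuantizer, FixInputPrecision now collects quantizers through the new get_output_quantizers helper, which looks through shape-only layers (Reshape, Transpose) before giving up, hence "allow transpose after inputs". Below is a minimal standalone sketch of the same traversal pattern; the toy Layer classes and get_output_layers here are simplified stand-ins for illustration, not hls4ml's real API:

    from collections.abc import Generator


    class Layer:
        # Toy stand-in for hls4ml's Layer: just a name and a list of consumers.
        def __init__(self, name, outputs=()):
            self.name = name
            self.class_name = type(self).__name__
            self.outputs = list(outputs)


    class Reshape(Layer):
        pass


    class Transpose(Layer):
        pass


    class FixedPointQuantizer(Layer):
        pass


    def get_output_layers(node: Layer) -> list[Layer]:
        # Stand-in for hls4ml's lookup of a node's consumers in the model graph.
        return node.outputs


    def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
        # Same shape as the committed helper: yield quantizers directly,
        # recurse through shape-only ops, and reject anything else.
        for _node in get_output_layers(node):
            if isinstance(_node, FixedPointQuantizer):
                yield _node
            elif isinstance(_node, (Reshape, Transpose)):
                yield from get_output_quantizers(_node)
            else:
                raise ValueError(f'Layer {node.name} ({node.class_name}) unexpected input layer chain.')


    # Input -> Transpose -> FixedPointQuantizer: the old assert rejected this,
    # the new helper resolves it to the quantizer behind the transpose.
    q = FixedPointQuantizer('q')
    inp = Layer('inp', outputs=[Transpose('t', outputs=[q])])
    assert list(get_output_quantizers(inp)) == [q]

Because the recursion passes through any depth of Reshape/Transpose layers, chains like Input -> Reshape -> Transpose -> quantizer resolve as well; an Input feeding any other layer type still fails, now with a ValueError instead of an AssertionError.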
