
Commit 7667750

allow transpose after inputs

1 parent 8a4f268 · commit 7667750

File tree

1 file changed: +12 -6 lines changed

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 12 additions & 6 deletions

@@ -3,7 +3,7 @@
 import re
 import typing
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
 from copy import copy
 from functools import reduce, singledispatch
 from math import ceil, log2, prod
@@ -824,6 +824,16 @@ def transform(self, model: 'ModelGraph'):
         return True
 
 
+def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
+    for _node in get_output_layers(node):
+        if isinstance(_node, FixedPointQuantizer):
+            yield _node
+        elif isinstance(_node, (Reshape, Transpose)):
+            yield from get_output_quantizers(_node)
+        else:
+            raise ValueError(f'Layer {node.name} ({node.class_name}) unexpected input layer chain.')
+
+
 class FixInputPrecision(OptimizerPass):
     def match(self, node: Layer):
         if not isinstance(node, Input):
@@ -833,11 +843,7 @@ def match(self, node: Layer):
         return node.get_output_variable().type.precision.width > 100
 
     def transform(self, model, node: Layer):
-        out_layers: list[FixedPointQuantizer] = get_output_layers(node) # type: ignore
-        for layer in out_layers:
-            assert isinstance(
-                layer, FixedPointQuantizer
-            ), f'Input {node.name} connected to non-quantizer {layer.name} with non-trivial configuration'
+        out_layers = list(get_output_quantizers(node))
 
         if len(out_layers) == 0: # Input connected to nothing
             new_type = to_hls4ml_fixed(0, 0, 1, f'{node.name}_t')
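In effect, the commit replaces a hard assertion that every consumer of an Input is a FixedPointQuantizer with a recursive search that looks through Reshape and Transpose layers, which is safe because those layers only rearrange elements and cannot change the precision the input requires. Below is a minimal, self-contained sketch of that traversal; the stub classes and the outputs-list version of get_output_layers are illustrative stand-ins for hls4ml's actual layer API, not the real implementation.

    from __future__ import annotations

    from collections.abc import Generator

    # Illustrative stand-ins for hls4ml's layer classes (not the real API).
    class Layer:
        def __init__(self, name: str, outputs: list[Layer] | None = None):
            self.name = name
            self.class_name = type(self).__name__
            self.outputs = outputs or []

    class FixedPointQuantizer(Layer): ...
    class Reshape(Layer): ...
    class Transpose(Layer): ...
    class Input(Layer): ...

    def get_output_layers(node: Layer) -> list[Layer]:
        # In hls4ml this resolves consumers through the model graph;
        # a plain attribute is enough for the sketch.
        return node.outputs

    def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
        for _node in get_output_layers(node):
            if isinstance(_node, FixedPointQuantizer):
                yield _node
            elif isinstance(_node, (Reshape, Transpose)):
                # Reshape/Transpose are transparent: keep searching past them.
                yield from get_output_quantizers(_node)
            else:
                raise ValueError(f'Layer {node.name} ({node.class_name}) unexpected input layer chain.')

    # Input -> Transpose -> FixedPointQuantizer: the quantizer is now found
    # through the transpose, which the old assert would have rejected.
    q = FixedPointQuantizer('q')
    inp = Input('inp', [Transpose('t', [q])])
    assert list(get_output_quantizers(inp)) == [q]

Note that the generator still raises on any consumer other than a quantizer, reshape, or transpose, so FixInputPrecision continues to reject inputs with a genuinely nontrivial consumer chain.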
