diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 4bc3fbe85..4414816fd 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -238,13 +238,28 @@ def parse_pytorch_model(config, verbose=True): # if a 'getitem' is the input to a node, step back in the graph to find the real source of the input elif "getitem" in node.args[0].name: - for tmp_node in traced_model.nodes: - if tmp_node.name == node.args[0].name: - if "getitem" in tmp_node.args[0].name: - raise Exception('Nested getitem calles not resolved at the moment.') - input_names = [inputs_map.get(str(tmp_node.args[0]), str(tmp_node.args[0]))] - input_shapes = [output_shapes[str(tmp_node.args[0])]] - node.args = [tmp_node.args[0]] + + def resolve_getitem_source(node_name, visited=None): + """Recursively resolve nested getitem calls to find the actual source node.""" + if visited is None: + visited = set() + + if node_name in visited: + raise Exception(f'Circular reference detected in getitem chain: {node_name}') + visited.add(node_name) + + for tmp_node in traced_model.nodes: + if tmp_node.name == node_name: + if "getitem" in tmp_node.args[0].name: + return resolve_getitem_source(tmp_node.args[0].name, visited) + else: + return tmp_node.args[0] + raise Exception(f'Could not find source node for getitem: {node_name}') + + source_node = resolve_getitem_source(node.args[0].name) + input_names = [inputs_map.get(str(source_node), str(source_node))] + input_shapes = [output_shapes[str(source_node)]] + node.args = [source_node] else: input_shapes = [output_shapes[str(i)] for i in node.args] # for Conv layers