
Conversation


@vloncar vloncar commented Aug 4, 2025

Description

Followup to #1322 to add support for Pad op in (Q)ONNX.

I cleaned up the test so that it checks both PyTorch and ONNX. Additionally, the channels-last converter seems to ignore the "off" setting, so I modified that as well. This may need an additional check.

Type of change

  • Bug fix (non-breaking change that fixes an issue) - The change in handling ChannelsLastConversion = off
  • New feature (non-breaking change which adds functionality) - Support for ONNX Pad op

Tests

Tests are included in test_zeropadding_pytorch_onnx.py

@vloncar vloncar requested review from jmitrevs and JanFSchulte August 4, 2025 18:54
@@ -13,8 +13,9 @@ class ChannelsLastConverter(OptimizerPass):

    def match(self, node):
        # If this parameter has not been set, this model does not need to be converted
        if 'ChannelsLastConversion' not in node.model.config.config['HLSConfig']['Model']:
            return False  # No littering of unused property
        do_convert = node.model.config.config['HLSConfig']['Model'].get('ChannelsLastConversion', 'off')
Contributor Author

@JanFSchulte Can you check this? I found that if I set ChannelsLastConversion to off as instructed by the docs, nothing happens: we still enter this optimizer and the node ends up with data_format changed to channels_last.
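
For reference, a minimal sketch of the user-facing toggle being discussed; the toy model and exact helper arguments are illustrative and may differ between hls4ml versions, they are not taken from this PR:

    import torch.nn as nn
    import hls4ml

    # Toy model just for illustration
    model = nn.Sequential(nn.ConstantPad1d((1, 1), 0.0))

    config = hls4ml.utils.config_from_pytorch_model(model, (2, 4))
    # The switch documented for disabling the conversion; before this fix the
    # optimizer still ran and forced data_format = 'channels_last' on the nodes.
    config['Model']['ChannelsLastConversion'] = 'off'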

Contributor

Yeah, that seems like a bug, thanks for catching it. I'm not quite sure how that happened, since it looks like I implemented the switches in the config without ever propagating the 'off' setting to the optimizer. But this change looks good to me.

hls_model_pytorch.compile()

onnx_path = str(test_root_path / 'hls4mlprj_constpad_1d/pad1d.onnx')
torch.onnx.export(model, torch.randn(1, 2, 4), onnx_path, opset_version=10)
Contributor

Isn't the recommended usage now with dynamo=True? I have previously followed the recipe from https://docs.pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html successfully, though I can't guarantee that it would work here.

Contributor Author

I saw that the docs now push dynamo=True a lot, but I wasn't sure what our stance on it is. I can change it if that's what you prefer. I realize now that I have to change this line anyway: opset_version=10 is a workaround for my local env and should not be in the code.

Contributor

Looking at the docs, they do push it pretty heavily indeed, so I'd be in favor of going with the dynamo=True option just to be sure we support the recommended usage.
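
For reference, a minimal sketch of the dynamo-based export flow from that tutorial, assuming torch>=2.5 with onnxscript installed (the toy module is illustrative, not the model from this test):

    import torch
    import torch.nn as nn

    model = nn.ConstantPad1d((1, 1), 0.0)  # illustrative toy module
    example_input = torch.randn(1, 2, 4)

    # With dynamo=True the exporter returns an ONNXProgram that can be saved to disk
    onnx_program = torch.onnx.export(model, (example_input,), dynamo=True)
    onnx_program.save('pad1d_dynamo.onnx')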

@vloncar vloncar mentioned this pull request Aug 5, 2025
8 tasks
@vloncar vloncar added the please test Trigger testing by creating local PR branch label Aug 6, 2025
@JanFSchulte JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Sep 8, 2025
hls_model_pytorch.compile()

onnx_path = str(test_root_path / 'hls4mlprj_constpad_1d/pad1d.onnx')
torch.onnx.export(model, torch.randn(1, 2, 4), onnx_path, dynamo=True)
Contributor

torch.onnx.export needs the onnxscript module, which we don't have in the test environment, so that needs to be added for these tests to work.

Contributor

Just adding onnxscript to the toml file should be fine, right? There's no need for a new container for this.

Contributor

Yes, that should work. We would also need to require torch>=2.5 to make sure that the dynamo argument is available for the export. Trying that out just royally blew up my testing environment, so I think I need to rebuild it from scratch; I'll try to come up with a setup that works.

Contributor

Apparently there is an issue right now with ONNX and ml-dtypes, so adding these 3 lines to the pyproject.toml gets the test to work:

  "torch>=2.5",
  "onnxscript",
  "ml-dtype>=0.5.3"

Now that the test actually runs it reveals an actual issue:

    @onnx_handler('Pad')
    def parse_pad_layer(node, input_names, input_shapes, graph):
        layer = {}
        layer['name'] = node.name
        layer['class_name'] = 'ZeroPadding'
        layer['inputs'] = input_names
        layer['outputs'] = list(node.output)
        layer['data_format'] = (
            'channels_last' if any(node.domain == 'qonnx.custom_op.channels_last' for node in graph.node) else 'channels_first'
        )
    
        mode = get_onnx_attribute(node, 'mode')
        if mode is not None and mode != 'constant':
            raise RuntimeError(f'Unsupported padding mode: {mode} in node {node.name}')
    
        pads = get_onnx_attribute(node, 'pads')
    
        dim = 0
        if len(input_shapes[0]) == 3:
            dim = 1  # 2D input (batch, channels, width), will use ZeroPadding1D
            if layer['data_format'] == 'channels_first':
                _, channels, width = input_shapes[0]
>               pad_left, pad_right = pads[2], pads[5]
E               TypeError: 'NoneType' object is not subscriptable

Contributor Author

In their infinite wisdom, the ONNX team changed pads to be an input rather than an attribute, so depending on the PyTorch setup you have, the exporter will emit one or the other.
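
One hedged sketch of how both variants could be read (get_onnx_attribute is the helper already used in the handler above; reading the pads constant from the graph initializers is an assumption about how exporters typically emit it, not the exact implementation in this PR):

    from onnx import numpy_helper

    def get_pads(node, graph):
        # opset < 11: 'pads' is an attribute on the Pad node
        pads = get_onnx_attribute(node, 'pads')
        if pads is None and len(node.input) > 1:
            # opset >= 11: 'pads' is the second input, typically a constant initializer
            pads_name = node.input[1]
            for initializer in graph.initializer:
                if initializer.name == pads_name:
                    pads = numpy_helper.to_array(initializer).tolist()
                    break
        if pads is None:
            raise RuntimeError(f'Could not determine pads for node {node.name}')
        return pads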

Contributor

Given that they are at opset 23 already, I think requiring >= 11 seems reasonable.

Contributor Author

See here. Tools move far slower than opset versions. We can support both; supporting it as an input is far more annoying, though, since that's a separate node.

Contributor

We can do both, of course. But it looks like if we want to support the preferred opset version, the variant where it's an input is required either way :/

Contributor

Apparently FINN already requires 13 or later, so there is a qonnx PR to update the preferred opset to 13. (qonnx should probably also remove the warning in the GEMM-to-MatMul converter that suggests an old opset.)

Contributor Author

We now support opset >= 11. But the problem with the dependencies remains: if I put the ml-dtypes>=0.5.3 dependency in the testing optional dependencies, TensorFlow breaks and most of the tests fail. Should we split the env into two in a separate PR?

@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Sep 11, 2025