
@DiagRisker I found the answer, though it took some time. After going through JAX's implementation of `conv_transpose`, I found that JAX does not automatically transform the kernel the way PyTorch's `conv_transpose2d` does. PyTorch's transposed convolution behaves like the gradient of a regular convolution, so it effectively flips the kernel's spatial axes (a 180-degree rotation) and swaps its input/output channel axes. Setting `transpose_kernel=True` in the `jax.lax.conv_transpose` call applies that flip and swap, and the result then matches the PyTorch implementation.
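For reference, here is a sketch of the identity behind the fix. It assumes a single image, a square stride $s$, symmetric padding $p$, and no `output_padding`, dilation, or groups. PyTorch defines `conv_transpose2d` by scattering each input pixel, scaled by the weight $w \in \mathbb{R}^{C_\text{in} \times C_\text{out} \times k_h \times k_w}$, into the output:

$$
y[c_o,\; s\,i + a - p,\; s\,j + b - p] \mathrel{+}= x[c_i, i, j]\, w[c_i, c_o, a, b]
$$

Rewriting that scatter as a gather gives an ordinary cross-correlation (the operation `conv2d` computes):

$$
y[c_o, u, v] = \sum_{c_i}\sum_{a=0}^{k_h-1}\sum_{b=0}^{k_w-1} \tilde{x}[c_i,\, u + a,\, v + b]\; w'[c_o, c_i, a, b],
\qquad
w'[c_o, c_i, a, b] = w[c_i, c_o, k_h - 1 - a, k_w - 1 - b]
$$

Here $\tilde{x}$ is $x$ with $s-1$ zeros inserted between neighbouring pixels, padded by $k_h-1-p$ rows and $k_w-1-p$ columns on each side. $w'$ is $w$ with both spatial axes flipped and the channel axes swapped, which is the kernel transformation that `transpose_kernel=True` performs.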

```
# JAX implementation
# L.conv_transpose(lhs, rhs, stride, padding, dimension_numbers=dimension_numbers, transpose_kernel=True)

Array([[[[-1.2976712 ,  7.6736045 ,  3.9103894 , ...,  4.156144  ,
          -3.4796333 , -0.91263664],
         [ 4.741432  ,  3.938364  ,  
```

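As a rough check of the fix, the sketch below compares `jax.lax.conv_transpose` against a plain NumPy loop written from PyTorch's scatter definition of `conv_transpose2d`, so PyTorch itself is not needed. It assumes NCHW inputs, the weight kept in PyTorch's `(C_in, C_out, kH, kW)` layout, integer stride and padding applied equally to both spatial axes, and no `output_padding`, dilation, or groups. The helper names `torch_conv_transpose2d_ref` and `jax_conv_transpose2d` are hypothetical. Two details are assumptions based on how `lax.conv_transpose` builds on `lax.conv_general_dilated`, not something stated above: with `transpose_kernel=True` the PyTorch weight is labelled `"OIHW"` (it is described as the kernel of the forward convolution), and PyTorch's `padding=p` corresponds to explicit `(k - 1 - p, k - 1 - p)` padding.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def torch_conv_transpose2d_ref(x, w, stride, padding):
    """Hypothetical NumPy reference following PyTorch's conv_transpose2d definition.

    x: (N, C_in, H, W); w: (C_in, C_out, kH, kW), PyTorch's weight layout.
    Every input pixel scatters a scaled copy of the kernel into the output.
    """
    n, c_in, h, wd = x.shape
    _, c_out, kh, kw = w.shape
    full = np.zeros((n, c_out, (h - 1) * stride + kh, (wd - 1) * stride + kw))
    for i in range(h):
        for j in range(wd):
            # (N, C_in) x (C_in, C_out, kH, kW) -> (N, C_out, kH, kW)
            full[:, :, i * stride:i * stride + kh, j * stride:j * stride + kw] += \
                np.einsum("nc,cokl->nokl", x[:, :, i, j], w)
    # PyTorch's `padding` crops that many rows/columns from each border.
    return full[:, :, padding:full.shape[2] - padding, padding:full.shape[3] - padding]


def jax_conv_transpose2d(x, w, stride, padding):
    """Hypothetical helper calling lax.conv_transpose with PyTorch-style arguments."""
    kh, kw = w.shape[2:]
    # Assumption: explicit padding applies to the stride-dilated input, so
    # (k - 1 - p) on each side reproduces PyTorch's `padding=p`.
    pads = [(kh - 1 - padding,) * 2, (kw - 1 - padding,) * 2]
    return lax.conv_transpose(
        jnp.asarray(x), jnp.asarray(w),
        strides=(stride, stride),
        padding=pads,
        # Assumption: with transpose_kernel=True the kernel is described as the
        # forward-conv kernel, so PyTorch's (C_in, C_out, kH, kW) weight is "OIHW".
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        transpose_kernel=True,  # flip spatial axes, swap in/out channel axes
        precision=lax.Precision.HIGHEST,  # avoid reduced-precision matmuls on GPU/TPU
    )


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 2, 4, 5), dtype=np.float32)  # N, C_in, H, W
w = rng.standard_normal((2, 3, 3, 3), dtype=np.float32)  # C_in, C_out, kH, kW
ref = torch_conv_transpose2d_ref(x, w, stride=2, padding=1)
out = np.asarray(jax_conv_transpose2d(x, w, stride=2, padding=1))
print(out.shape, np.allclose(out, ref, atol=1e-4))
```

Dropping `transpose_kernel=True` from the helper makes `lax` use the kernel unflipped with its axes read as-is, which generally gives a different result, or a shape error when `C_in != C_out`.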
Answer selected by SanjithKumar2
Labels: bug (Something isn't working)
3 participants

This discussion was converted from issue #32566 on October 13, 2025 14:45.