Using reduce_window with 0.0 as its init value causes a Linearization error when jit compiling only a grad-transformed version of the function. Below is the example code with the error. The error doesn't happen if -jnp.inf is used or if the function is not jit compiled.
Interestingly, removing jit compilation from the above makes it not throw this error, even though the function is the same. Below is the working version, in which both the grad and jit functions work.
def maxpool_step(x, kernel_size, stride):
    return jax.lax.reduce_window(
        x,
        -jnp.inf,  # 0.0, jnp.float32(0.0), jnp.array(0.0) or jnp.array(-jnp.inf) all cause errors
        jax.lax.max,
        window_dimensions=(1, 1, kernel_size, kernel_size),
        window_strides=(1, 1, stride, stride),
        padding="VALID",
    )
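For reference, here is a minimal sketch of how the working version might be checked under both plain grad and jit-compiled grad. The `loss` wrapper, the 1x1x4x4 input, and the kernel size and stride of 2 are assumptions chosen for illustration; they are not taken from the original failing code.

```python
import jax
import jax.numpy as jnp


def maxpool_step(x, kernel_size, stride):
    # -jnp.inf is the identity element for max, passed as a plain Python scalar
    return jax.lax.reduce_window(
        x,
        -jnp.inf,
        jax.lax.max,
        window_dimensions=(1, 1, kernel_size, kernel_size),
        window_strides=(1, 1, stride, stride),
        padding="VALID",
    )


def loss(x):
    # Hypothetical scalar-valued wrapper so that jax.grad can be applied
    return jnp.sum(maxpool_step(x, kernel_size=2, stride=2))


# Assumed NCHW input with distinct values, so each window has a unique maximum
x = jnp.arange(16, dtype=jnp.float32).reshape(1, 1, 4, 4)

grad_eager = jax.grad(loss)(x)            # grad without jit
grad_jitted = jax.jit(jax.grad(loss))(x)  # jit applied to the grad-transformed function

print(grad_eager.shape)
print(jnp.allclose(grad_eager, grad_jitted))
```

With this init value, both calls should return a gradient with the same shape as x.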
My questions are:
Why does it throw an error only for certain values (0.0, jnp.float32(0.0), jnp.array(0.0), jnp.array(-jnp.inf))?
In the first case, why does it throw an error only when the grad function is jit compiled?