Can you elaborate here? How does this trip up the compiler? Note that
We are given `arr = jax.numpy.array(...)`, where `arr.shape` can be anything, and we seek another array `idxs` whose shape is `(*arr.shape, 1)`, having roughly the property that `idxs[i] == i` for all valid `arr`-indices `i`. The precise intended behavior is that of the following code:

Thus, by including `idxs(arr.shape)` along with `arr` as arguments to a `vmap`, it is useful anywhere one wants to know the index at which they are operating, like having `enumerate` in standard Python.

The problem with the above `idxs` code is that, although its shape behavior is obviously statically determined by that of its argument `arr`, the construction passes through `arr.shape`, which trips up the JIT compiler. There are functions `zeros_like`, `full_like`, and so on, that directly take `arr` (not `arr.shape`) and are respected by the JIT compiler. Is there some analogue ("`idxs_like`") here?
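For concreteness, here is a minimal sketch of what an `idxs_like` could look like. The name `idxs_like` and the helper `scale_by_position` are hypothetical, not existing JAX functions, and the sketch assumes that `arr` is the traced argument of the jitted function. In that case `arr.shape` is an ordinary tuple of Python ints during tracing, so passing it to `jnp.indices` should be fine under `jax.jit`. For a 1-D `arr` the result has shape `(*arr.shape, 1)` as described above; for higher ranks this sketch returns `(*arr.shape, arr.ndim)`, which may or may not match the behavior originally intended.

```python
import jax
import jax.numpy as jnp


def idxs_like(arr):
    """Hypothetical helper: index grid of shape (*arr.shape, arr.ndim)."""
    # arr.shape is a tuple of Python ints even when arr is a tracer,
    # so it is a static value while jax.jit traces the function.
    grid = jnp.indices(arr.shape)      # shape (arr.ndim, *arr.shape)
    return jnp.moveaxis(grid, 0, -1)   # shape (*arr.shape, arr.ndim)


@jax.jit
def scale_by_position(arr):
    # vmap over the leading axis of both arr and its index grid,
    # similar in spirit to enumerate() in plain Python.
    return jax.vmap(lambda x, i: x * i[0])(arr, idxs_like(arr))


arr = jnp.array([10.0, 20.0, 30.0])
idxs = idxs_like(arr)          # shape (3, 1)
out = scale_by_position(arr)   # each element multiplied by its own index
```

If the shape were instead passed in as a separate jitted argument (rather than read off `arr` inside the function), it would be traced and no longer static, which is one situation where an error can appear.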