
Do you have experience in both JAX and PyTorch? Why do you prefer JAX?


Not OP. I prefer JAX for non-AI tasks in scientific computing because its mental model differs from PyTorch's. In JAX, you think about functions and gradients of functions. In PyTorch, you think about tensors that accumulate a gradient while being manipulated through functions. JAX just suits my way of thinking much better.
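A minimal sketch of the "functions and gradients of functions" model, assuming a recent JAX install (the functions `f` and `g` are made up for illustration). `jax.grad` takes a function and returns a new function; nothing is stored on the arrays themselves, unlike PyTorch's `requires_grad` / `.backward()` / `x.grad` workflow.

```python
import jax
import jax.numpy as jnp

# An ordinary Python function of an array: f(x) = sum(sin(x) * x)
def f(x):
    return jnp.sum(jnp.sin(x) * x)

# jax.grad turns f into a new function that computes its gradient.
df = jax.grad(f)

x = jnp.array([0.0, 1.0, 2.0])
print(df(x))  # elementwise cos(x) * x + sin(x)

# Transformations compose: the second derivative of a scalar function.
def g(t):
    return t ** 3

d2g = jax.grad(jax.grad(g))
print(d2g(2.0))  # mathematically 6 * t = 12
```

Because the gradient is just another function, it can be passed to `jax.jit`, `jax.vmap`, or `jax.grad` again without special cases.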

I also like that jax.jit forces you to write "functional" functions free of side effects or in-place array updates. It might feel weird at first (and not every algorithm is suited for this style) but ultimately it leads to clearer and faster code.
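A small sketch of what "functional" means in practice, assuming a recent JAX version (the names `set_first` and `noisy` are hypothetical). JAX arrays are immutable, so item assignment fails; the `.at[...].set(...)` syntax returns a new array instead. Python side effects inside a jitted function only run while JAX traces it, not on every call.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)
try:
    x[0] = 1.0  # JAX arrays are immutable: this raises TypeError
except TypeError as err:
    print("in-place update rejected:", type(err).__name__)

@jax.jit
def set_first(arr, value):
    # Functional update: returns a new array; `arr` is left untouched.
    return arr.at[0].set(value)

y = set_first(x, 1.0)
print(x)  # still all zeros
print(y)  # first element is 1.0

@jax.jit
def noisy(arr):
    print("tracing")  # Python side effect: runs at trace time only
    return arr * 2

noisy(x)
noisy(x)  # same shape/dtype, so the cached compiled version runs; no second print
```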

I am surprised that JIT in PyTorch gets so little attention. Maybe it's less impactful for PyTorch's usual use case of large networks, as opposed to general scientific computing?


>I also like that jax.jit forces you to write "functional" functions free of side effects or in-place array updates. It might feel weird at first (and not every algorithm is suited for this style) but ultimately it leads to clearer and faster code.

It's not weird. It's actually the most natural way of doing things for me. You just write down your math equations as JAX and you're done.
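As a sketch of "write down your math equations as JAX", here is a mean-squared-error loss for a linear model, L(w) = ||Xw - y||^2 / n, typed in almost exactly as the formula, with its JAX gradient checked against the hand-derived one, dL/dw = (2/n) X^T (Xw - y). The data are random and the function names are illustrative only.

```python
import jax
import jax.numpy as jnp

# L(w) = ||X w - y||^2 / n, written directly from the formula
def loss(w, X, y):
    r = X @ w - y
    return jnp.mean(r ** 2)

# Hand-derived gradient for comparison: dL/dw = (2 / n) X^T (X w - y)
def loss_grad_analytic(w, X, y):
    return (2.0 / X.shape[0]) * (X.T @ (X @ w - y))

# Gradient with respect to the first argument (w), compiled with XLA
grad_loss = jax.jit(jax.grad(loss))

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k1, (5, 3))
y = jax.random.normal(k2, (5,))
w = jnp.ones(3)

print(grad_loss(w, X, y))
print(loss_grad_analytic(w, X, y))  # should match up to float32 rounding
```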


> You just write down your math equations as JAX and you're done.

It's natural when your basic unit is a whole vector (tensor), manipulated by some linear algebra expression. It's less natural if your basic unit is an element of a vector.

If you're solving sudoku, for example, the obvious 'update' is in-place.
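To make the contrast concrete, here is a hedged sketch of a single sudoku step in JAX (the helpers `is_valid` and `place` are hypothetical, and it runs eagerly, without `jax.jit`). Where NumPy code would write `board[r, c] = v`, JAX returns a new board via `board.at[r, c].set(v)`; outside of jit, that is a copy of the whole board for every placement.

```python
import jax.numpy as jnp

def is_valid(board, r, c, v):
    """Can digit v go at (r, c) on a 9x9 board (0 = empty)?"""
    br, bc = 3 * (r // 3), 3 * (c // 3)  # top-left corner of the 3x3 box
    in_row = jnp.any(board[r, :] == v)
    in_col = jnp.any(board[:, c] == v)
    in_box = jnp.any(board[br:br + 3, bc:bc + 3] == v)
    return not bool(in_row | in_col | in_box)

def place(board, r, c, v):
    # Functional "update": returns a new board, the input is unchanged.
    return board.at[r, c].set(v)

board = jnp.zeros((9, 9), dtype=jnp.int32)
board2 = place(board, 0, 0, 5)

print(is_valid(board2, 0, 8, 5))  # False: 5 is already in row 0
print(is_valid(board2, 4, 4, 5))  # True: center box, row and column are empty
print(int(board[0, 0]), int(board2[0, 0]))  # original board still has 0 there
```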

In-place updates are also often the right answer for performance reasons, such as writing the output of a .map() operation directly to the destination tensor. JAX leans heavily on compile-time optimizations to turn the mathematically-nice code into computer-nice code, so the delta between eager JAX and compiled JAX is much larger than the delta between eager PyTorch and compiled PyTorch.
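A sketch of the "compiler turns it into an in-place write" idea, assuming a recent JAX (the function `fill_squares` is hypothetical). The element-by-element writes are expressed functionally with `.at[i].set` inside `jax.lax.fori_loop`; under `jax.jit`, XLA can typically perform these updates in place, and `donate_argnums=0` additionally lets the compiled function reuse the input buffer for the output. Whether donation is honored depends on the backend; JAX warns and copies when it cannot use a donated buffer.

```python
import jax
import jax.numpy as jnp
from functools import partial

# Write xs[i] ** 2 into a preallocated buffer, one element at a time.
# Functional on the surface; under jit XLA can update the buffer in place.
@partial(jax.jit, donate_argnums=0)
def fill_squares(out, xs):
    def body(i, buf):
        return buf.at[i].set(xs[i] ** 2)
    return jax.lax.fori_loop(0, xs.shape[0], body, out)

xs = jnp.arange(5.0)
out = jnp.zeros_like(xs)
result = fill_squares(out, xs)  # `out` was donated: don't use it after this call
print(result)  # values 0, 1, 4, 9, 16
```

Run eagerly, the same loop would allocate a fresh array on every `.at[i].set`, which is the eager-versus-compiled gap the comment describes. In real code the vectorized `xs ** 2` is the idiomatic form; the loop only mirrors element-wise writes.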


Not OP. I have production/scale experience in PyTorch and toy/hobby experience in JAX. I wish I had the time or liberty to use JAX more. It consists of a small, orthogonal set of ideas that combine like Lego blocks. I can attempt to reason about performance from first principles. The documentation is super readable and strives to make you understand things.
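A minimal sketch of the "Lego blocks" point, assuming a recent JAX (the model and data are made up). Three independent transformations, `grad`, `vmap` and `jit`, stack to give per-example gradients of a loss written for a single example.

```python
import jax
import jax.numpy as jnp

# Squared error of a tiny linear model for ONE example
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

# grad: derivative w.r.t. w for one example
# vmap: map it over a batch of (x, y) pairs, sharing w (in_axes=None)
# jit:  compile the composed function with XLA
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([0.0, 1.0, 2.0])

g = per_example_grads(w, xs, ys)
print(g.shape)  # (3, 2): one gradient per example
```

Each row equals 2(w·x - y)x for that example, so the result can be checked by hand.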

JAX seems well engineered. One could argue TensorFlow was too. But the ideas behind JAX were developed outside Google (in Autograd), so it has struck the right balance of staying close to idiomatic Python/NumPy.

PyTorch is where the tailwinds are, though. It is a wildly successful project that has accumulated a ton of code over the years, so it is a little harder to figure out from first principles how something works (say, torch.compile).

