
I switched from PyTorch to JAX just before Triton appeared. Does anyone know how JAX compares to this autotuning machinery in PyTorch? I know JAX does JIT compilation, but I don't have a good intuition for whether JIT is better than this type of autotuning.


Pallas is the Triton equivalent in JAX land. There are some old autotuning prototypes if you search for Pallas, like this one: https://github.com/jax-ml/jax-triton/pull/108
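
For intuition on the question above: JIT traces a function and hands it to a compiler, whereas Triton-style autotuning benchmarks several candidate configurations of the same kernel (block sizes, for example) and keeps the fastest. Here is a rough sketch of that second idea in plain Python. It is not the JAX, Pallas, or Triton API; `blocked_sum`, `autotune`, and the candidate block sizes are hypothetical names made up for illustration.

```python
import time


def blocked_sum(data, block_size):
    """Sum `data` in chunks of `block_size` (a stand-in for a tunable kernel)."""
    total = 0
    for start in range(0, len(data), block_size):
        total += sum(data[start:start + block_size])
    return total


def autotune(fn, args, configs, repeats=3):
    """Time fn(*args, **config) for every config; return the fastest one."""
    timings = {}
    for config in configs:
        best = float("inf")
        for _ in range(repeats):
            t0 = time.perf_counter()
            fn(*args, **config)
            best = min(best, time.perf_counter() - t0)
        # Use a hashable key so each config maps to its best observed time.
        timings[tuple(sorted(config.items()))] = best
    fastest = min(configs, key=lambda c: timings[tuple(sorted(c.items()))])
    return fastest, timings


data = list(range(100_000))
configs = [{"block_size": b} for b in (16, 256, 4096)]
best_config, timings = autotune(blocked_sum, (data,), configs)
print("fastest config on this machine:", best_config)
```

Which config wins depends on the machine and the input size, which is the whole point of tuning at runtime instead of hard-coding one choice. The two approaches are not mutually exclusive: an autotuner can pick the config and a JIT can still compile the chosen variant.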



