How to scale your model: A systems view of LLMs on TPUs
Comments
3abiton
I am really looking forward to JAX taking over PyTorch/CUDA over the next few years. The whole PTX kerfuffle with the DeepSeek team shows the value of investing in lower-level approaches to squeeze the most out of your hardware.
kadushka
Most PyTorch users don't even bother with the simplest performance optimizations, and you're talking about PTX.
throwaway287391
I like JAX but I'm not sure how an ML framework debate like "JAX vs PyTorch" is relevant to DeepSeek/PTX. The JAX API is at a similar level of abstraction to PyTorch [0]. Both are Python libraries and sit a few layers of abstraction above PTX/CUDA and their TPU equivalents.
[0] Although PyTorch arguably encompasses two levels: a pure functional library akin to the JAX API, plus a "neural network" framework on top of it. JAX doesn't include the latter and leaves it to separate libraries like Flax.
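To make the "similar level of abstraction" point concrete: here is the same gradient computed in both APIs (a minimal sketch, not from the article):

    import jax
    import jax.numpy as jnp
    import torch

    # JAX: pure-functional -- transform a function to get its gradient
    g = jax.grad(lambda x: (x ** 2).sum())(jnp.ones(3))

    # PyTorch: imperative -- tensors carry gradient state
    x = torch.ones(3, requires_grad=True)
    (x ** 2).sum().backward()

    print(g, x.grad)  # both are [2., 2., 2.]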
jdeaton
The interesting thing about this comment is that JAX is actually even higher-level than PyTorch, generally. Since everything is compiled, you just express a logical program and let the compiler (XLA) worry about the rest.
Are you suggesting that XLA would be where this "lower level" approach would reside, since it can do more automatic optimization?
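As a hedged illustration of "express the logical program and let XLA worry about the rest" (toy function, not from the guide):

    import jax
    import jax.numpy as jnp

    @jax.jit                    # traced once per input shape, then compiled by XLA
    def predict(w, x):
        return jnp.tanh(w @ x)  # you write the math...

    # ...XLA decides fusion, memory layout, and device placement.
    predict(jnp.ones((4, 4)), jnp.ones(4))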
Scene_Cast2
I'm curious, what does paradigmatic JAX look like? Is there an equivalent of picoGPT [1] for JAX?
jdeaton
yeah it looks exactly like that file but replace "import numpy as np" with "import jax.numpy as np" :)
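Slightly less glibly: a sketch of the swap, using the standard tanh-approximate GELU rather than code copied from picoGPT:

    import jax.numpy as np  # the only line that changes

    def gelu(x):
        # standard tanh-approximate GELU, as used in GPT-2
        return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))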
achierius
What PTX kerfuffle are you referring to?
saagarjha
You do understand that PTX is part of CUDA, right?
lordswork
This has been my bible for performance work internally at Google. Kind of surprised they released it publicly, but I guess they removed all the Gemini-specific details.
memhole
This is awesome! Can't wait to read it. I've been very curious about why we don't hear more about LLMs on TPUs.
jdeaton
Something nice about this guide is that it generally transfers to GPU directly thanks to JAX/XLA.
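For instance (a minimal sketch; device names vary by backend and JAX version):

    import jax
    import jax.numpy as jnp

    print(jax.devices())  # e.g. [TpuDevice(id=0), ...] on TPU, CUDA devices on GPU

    @jax.jit              # the same compiled function targets either backend
    def matmul(a, b):
        return a @ b

    matmul(jnp.ones((8, 8)), jnp.ones((8, 8)))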
brap
Not strictly related, but does anyone know why JAX uses tracing and not AST via reflection?
shoyer
The short answer is that tracing is way, way easier to implement in a predictable and reliably performant way. This especially matters for distributed computation and automatic differentiation, two areas where JAX shines.
AST parsing via reflection means your ML compiler needs to re-implement all of Python, which is not a small language. This is a lot of work and hard to do well with abstractions that were not designed for these use cases. (I believe Julia's whole-language autodiff systems struggle for essentially the same reason.)
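To illustrate: tracing just runs the Python function once with abstract values and records the primitives it hits, so the compiler never needs to understand Python syntax. A minimal sketch (the printed jaxpr is abbreviated):

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) ** 2 + 1.0

    print(jax.make_jaxpr(f)(2.0))
    # roughly: { lambda ; a:f32[]. let b = sin a; c = integer_pow[y=2] b;
    #            d = add c 1.0 in (d,) }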
almostgotcaught
> AST via reflection
I literally am a paid ML compiler engineer and I have no idea what this means. You understand that reflection, a la looking in a mirror, is about being able to identify a type's type at runtime? It has nothing to do with the AST.
brap
Congratulations, but that's not what reflection means.
Wikipedia: "reflection is the ability of a process to examine, introspect, and modify its own structure and behavior."
Would you say inspect.getsource(func) fits the definition of reflection?
Would you say ast.parse(inspect.getsource(func)) has something to do with the AST?
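Concretely, in a script (a minimal, self-contained sketch):

    import ast
    import inspect

    def f(x):
        return x + 1

    src = inspect.getsource(f)   # the process examining its own structure
    tree = ast.parse(src)        # ...and recovering the AST from that source
    print(type(tree.body[0]))    # <class 'ast.FunctionDef'>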
almostgotcaught
> Would you say inspect.getsource(func) fits the definition of reflection?
I would say that reflection is absolutely meaningless in an interpreted runtime because you can always query the runtime.
> Would you say ast.parse(inspect.getsource(func)) has something to do with the AST?
It has something to do with the AST but it doesn't have much to do with reflection.
bronxbomber92
An example of “AST via reflection”
https://docs.scala-lang.org/scala3/reference/metaprogramming...
mattjjatgoogle
An author's tweet thread: https://x.com/jacobaustin132/status/1886844716446007300
awongh
Here in the thread (https://x.com/jacobaustin132/status/1886844724339675340) he says: "5 years ago, there were many ML architectures, but today, there is (mostly) only one [transformers]."
To what degree is this actually true, and what else is on the horizon that might become as popular as transformers?
swyx
it's quite true. the convergence of all archs to transformers is well documented by karpathy. SSMs were once touted as transformer killers, but increasingly look like just optional supplements.
perfobotto
What an amazing write up! Thank you very much!
eamag
Any way to convert this Jekyll site to a PDF?
atomala
There are plans to release a PDF version; need to fix some formatting issues + convert the animated diagrams into static images.
hassleblad23
Great writeup. Congrats.
whatever1
How do they make these fancy animations?
alevskaya
Nothing fancy. I made these with some pretty simple hand-written JavaScript rendering to canvas: lots of fiddly little boxes moving around are simpler to script than to hand-animate. (If I were to do much more of this I might rewrite these in Blender, since it has much nicer authoring tooling and export control.)
nicodjimenez
Shameless request for help: if anybody has experience with seq2seq on TPU, and you want to do a cool project to deploy a world class Pytorch image parsing model to TPU (and do this quickly), please contact me immediately for a well paid and interesting job opportunity at nico [at] mathpix.com.
jdeaton
if you're using tpu why are you using pytorch
hustwindmaple1
There is limited TPU support in PyTorch via torch_xla.
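Roughly, the flow looks like the sketch below; MyModel, loader, and the optimizer setup are hypothetical placeholders, but xm.xla_device and xm.optimizer_step are real torch_xla entry points:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()          # a TPU core exposed as a torch device
    model = MyModel().to(device)      # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for x, y in loader:               # hypothetical data loader
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()
        xm.optimizer_step(optimizer)  # steps and flushes the lazily built XLA graph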
jdeaton
Sounds limited