How to scale your model: A systems view of LLMs on TPUs

184 points
12 days ago
by mattjjatgoogle

Comments


3abiton

I am really looking forward to JAX taking over PyTorch/CUDA over the next few years. The whole PTX kerfuffle with the DeepSeek team shows the value of investing in lower-level approaches to squeeze the most out of your hardware.

12 days ago

kadushka

Most PyTorch users don't even bother with the simplest performance optimizations, and you are talking about PTX.

12 days ago

throwaway287391

I like JAX but I'm not sure how an ML framework debate like "JAX vs PyTorch" is relevant to DeepSeek/PTX. The JAX API is at a similar level of abstraction to PyTorch [0]. Both are Python libraries and sit a few layers of abstraction above PTX/CUDA and their TPU equivalents.

[0] Although PyTorch arguably encompasses two levels: a pure functional library similar to the JAX API, plus a "neural network" framework on top of it. JAX doesn't include the latter and leaves it to separate libraries like Flax.
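
To make that concrete, here's a rough sketch (mine, not from the article) of the split: plain JAX is pure functions over arrays with parameters passed explicitly, and Flax supplies the Module-style framework on top of that same functional core.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn  # the separate "neural network" library mentioned above

    # Pure-functional JAX: parameters are just explicit arguments.
    def dense(w, b, x):
        return x @ w + b

    # Flax layers a Module abstraction on top of the functional core.
    class MLP(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.relu(nn.Dense(64)(x))
            return nn.Dense(10)(x)

    params = MLP().init(jax.random.PRNGKey(0), jnp.ones((1, 32)))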

12 days ago

jdeaton

The interesting thing about this comment is that JAX is generally even higher-level than PyTorch. Since everything is compiled, you just express a logical program and let the compiler (XLA) worry about the rest.

Are you suggesting that XLA would be where this "lower level" approach would reside since it can do more automatic optimization?
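
For illustration (my example, not the article's), the "express the logical program, let XLA handle the rest" idea is basically jax.jit:

    import jax
    import jax.numpy as jnp

    @jax.jit  # trace once, hand the whole program to XLA to fuse and optimize
    def mlp(params, x):
        w1, w2 = params
        return jax.nn.relu(x @ w1) @ w2

    key = jax.random.PRNGKey(0)
    k1, k2, k3 = jax.random.split(key, 3)
    params = (jax.random.normal(k1, (128, 256)), jax.random.normal(k2, (256, 10)))
    x = jax.random.normal(k3, (32, 128))
    print(mlp(params, x).shape)  # (32, 10)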

12 days ago

Scene_Cast2

I'm curious, what does paradigmatic JAX look like? Is there an equivalent of picoGPT [1] for JAX?

[1] https://github.com/jaymody/picoGPT/blob/main/gpt2.py

12 days ago

jdeaton

Yeah, it looks exactly like that file, but replace "import numpy as np" with "import jax.numpy as np" :)
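
For example, the first few helpers in that file carry over essentially verbatim (sketched from memory, not copied from the repo):

    import jax.numpy as np  # the one-line swap: was "import numpy as np"

    def gelu(x):
        return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

    def softmax(x):
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def attention(q, k, v, mask):  # scaled dot-product attention with a causal mask
        return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v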

12 days ago

achierius

What PTX kerfuffle are you referring to?

11 days ago

saagarjha

You do understand that PTX is part of CUDA right?

12 days ago

lordswork

This has been my bible for performance work internally at Google. Kind of surprised they released it publicly, but I guess they removed all the Gemini-specific details.

12 days ago

memhole

This is awesome! Can't wait to read it. I've been very curious about why we don't hear more about LLMs on TPUs.

12 days ago

jdeaton

Something nice about this guide is that it generally transfers to GPU directly thanks to JAX/XLA.

12 days ago

brap

Not strictly related, but does anyone know why JAX uses tracing and not AST via reflection?

12 days ago

shoyer

The short answer is that tracing is way, way easier to implement in a predictable and reliably performant way. This especially matters for distributed computation and automatic differentiation, two areas where JAX shines.

AST parsing via reflection means your ML compiler needs to re-implement all of Python, which is not a small language. This is a lot of work and hard to do well with abstractions that are not designed for those use-cases. (I believe Julia's whole-language autodiff systems struggle for essentially the same reason.)
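
A quick illustration of what "tracing" means here (my sketch): JAX calls your function with abstract stand-in values and records the primitive operations it sees, so it never has to understand Python syntax at all.

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * 2.0 + x

    # make_jaxpr runs f on tracer values and returns the recorded program,
    # which is what then gets differentiated, sharded, and compiled.
    print(jax.make_jaxpr(f)(1.0))
    # roughly: { lambda ; a:f32[]. let b = sin a; c = mul b 2.0; d = add c a in (d,) }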

12 days ago

almostgotcaught

> AST via reflection

I am literally a paid ML compiler engineer and I have no idea what this means. You understand that reflection, a la looking in a mirror, is about being able to identify a type's type at runtime? It has nothing to do with the AST.

11 days ago

brap

Congratulations, but that's not what reflection means.

Wikipedia: "reflection is the ability of a process to examine, introspect, and modify its own structure and behavior."

Would you say inspect.getsource(func) fits the definition of reflection?

Would you say ast.parse(inspect.getsource(func)) has something to do with the AST?
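
i.e., roughly this (toy example):

    import ast, inspect

    def f(x):
        return x * 2 + 1

    source = inspect.getsource(f)   # the process examining its own code
    tree = ast.parse(source)        # ...which certainly involves the AST
    print(ast.dump(tree.body[0], indent=2))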

10 days ago

almostgotcaught

> Would you say inspect.getsource(func) fits the definition of reflection?

I would say that reflection is absolutely meaningless in an interpreted runtime because you can always query the runtime.

> Would you say ast.parse(inspect.getsource(func)) has something to do with the AST?

It has something to do with the AST but it doesn't have much to do with reflection.

10 days ago

awongh

In the thread he says (https://x.com/jacobaustin132/status/1886844724339675340): `5 years ago, there were many ML architectures, but today, there is (mostly) only one [transformers].`

To what degree is this actually true, and what else is on the horizon that might become as popular as transformers?

12 days ago

swyx

It's quite true. The convergence of all architectures to transformers is well documented by Karpathy. SSMs were once touted as transformer killers, but increasingly look like just optional supplements.

12 days ago

perfobotto

What an amazing write up! Thank you very much!

12 days ago

eamag

Any way to convert this Jekyll site to a PDF?

12 days ago

atomala

There are plans to release a PDF version; we need to fix some formatting issues and convert the animated diagrams into static images.

12 days ago

hassleblad23

Great writeup. Congrats.

12 days ago

whatever1

How do they make these fancy animations?

12 days ago

alevskaya

Nothing fancy. I made these with some pretty simple hand-written JavaScript scripts rendering to canvas: lots of fiddly little boxes moving around are simpler to script than to hand-animate. (If I were to do much more of this I might rewrite them in Blender, since it has much nicer authoring tooling and export control.)

12 days ago

nicodjimenez

Shameless request for help: if anybody has experience with seq2seq on TPU and wants to do a cool project deploying a world-class PyTorch image-parsing model to TPU (and do it quickly), please contact me immediately for a well-paid and interesting job opportunity at nico [at] mathpix.com.

12 days ago

jdeaton

If you're using TPU, why are you using PyTorch?

12 days ago

hustwindmaple1

There is limited TPU support in PyTorch via torch_xla.
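
Roughly this pattern (a sketch of typical torch_xla usage, not from the article):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()             # a TPU core exposed as a torch device
    x = torch.randn(4, 4, device=device)
    y = (x @ x).sum()
    xm.mark_step()                       # flush the lazily built graph to XLA
    print(y.item())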

11 days ago

jdeaton

Sounds limited

10 days ago