Penzai: JAX research toolkit for building, editing, and visualizing neural nets

261 points
12 days ago
by mccoyb

Comments


yklcs

I like JAX and find most of its core functionality, as an "accelerated NumPy", great. Ecosystem fragmentation and interop difficulties make adopting JAX hard, though.

There's too much fragmentation within the JAX NN library space, which penzai isn't helping with. I wish everyone using JAX could agree on a single set of libraries for NN, optimization, and data loading.

PyTorch code can't be called from JAX, so a lot of reimplementation is needed when extending and iterating on prior work, which describes most research. Custom CUDA kernels are a bit fiddly too; I haven't been able to bring Gaussian Splatting to JAX yet.

12 days ago

ddjohnson

I'm curious what interop difficulties you've run into in JAX? In my experience, the JAX ecosystem is quite modular and most JAX libraries work pretty well together. Penzai's core visualization tooling should work for most JAX NN libraries out of the box, and Penzai's neural net components are compatible with existing JAX optimization libraries (like Optax) and data loaders (like tfds/seqio or grain).

(Interop with PyTorch seems more difficult, of course!)

11 days ago

yklcs

It's mostly an ecosystem thing: being unable to use existing methods. In my experience, research goes something like this:

1. Milestone paper introducing novel method is published with green-field implementation

2. Bunch of papers extend milestone paper with brown-field implementation

3. Goto 1

Most things in 1 are written in PyTorch, meaning 2 also has to be in PyTorch. I know this isn't JAX's fault, but I don't think JAX's philosophy of staying unopinionated and low-level is helping. It seems like the community agreeing on a single set of DL libraries around JAX would help it gain some momentum.

11 days ago

cs702

That's my experience as well. PyTorch dominates the ecosystem.

Which is a shame, because JAX's approach is superior.[a]

---

[a] In my experience, anytime I've had to do anything in PyTorch that isn't well supported out of the box, I've quickly found myself tinkering with Triton, which usually becomes... very frustrating. Meanwhile, JAX offers decent parallelization of anything I write in plain Python, plus really nice primitives like jax.lax.while_loop, jax.lax.associative_scan, jax.lax.select, etc. And yet, I keep using PyTorch... because of the ecosystem.
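For readers who haven't used them, here is a minimal sketch of the three `jax.lax` primitives named above. The toy functions (`count_halvings`, `prefix_sums`, `relu_select`) are illustrative only; the point is that loops, scans, and branches become traced, device-side operations rather than Python control flow.

```python
import jax
import jax.numpy as jnp
from jax import lax


def count_halvings(x):
    # lax.while_loop(cond_fun, body_fun, init_val): the loop runs on-device,
    # so the trip count can depend on traced values.
    def cond(state):
        value, _ = state
        return value > 1.0

    def body(state):
        value, steps = state
        return value / 2.0, steps + 1

    _, steps = lax.while_loop(cond, body, (x, 0))
    return steps


def prefix_sums(xs):
    # lax.associative_scan computes every prefix reduction of an
    # associative operator with logarithmic parallel depth.
    return lax.associative_scan(jnp.add, xs)


def relu_select(x):
    # lax.select(pred, on_true, on_false) is an elementwise choice with
    # no Python branching, so it composes with jit and vmap.
    return lax.select(x > 0, x, jnp.zeros_like(x))


count_halvings_jit = jax.jit(count_halvings)
print(int(count_halvings_jit(jnp.float32(16.0))))          # 4
print(prefix_sums(jnp.arange(1, 6)).tolist())               # [1, 3, 6, 10, 15]
print(relu_select(jnp.array([-1.0, 2.0, -3.0])).tolist())   # [0.0, 2.0, 0.0]
```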

11 days ago

jacktang

The best is not always popular. The JAX idea is very much like the Erlang programming language.

11 days ago

cs702

> The best is not always popular.

I agree. Network effects routinely overpower better technology.

11 days ago

abhgh

Another issue I've personally faced is debugging, although I'm saying this from experience more than a year ago, and maybe things are better now. I've mostly used it for optimization, and the error messages weren't helpful.

12 days ago

catgary

I’ve only been reading through the docs for a few moments, but I’m pleasantly surprised to find that the authors are using effect handlers to handle effectful computations in ML models. I was in the process of translating a model from torch to Jax using Equinox; this makes me think penzai could be a better choice.

12 days ago

patrickkidger

I was just reading this too! I think it's a really interesting choice in the design space.

So to elucidate this a little bit, the trade-off is that this is now incompatible with e.g. `jax.grad` or `lax.scan`: you can't compose things in the order `discharge_effect(jax.grad(your_model_here))`, or put an effectful `lax.scan` inside your forward pass, etc. The effect-discharging process only knows how to handle traversing pytree structures. (And they do mention this at the end of their docs.)

This kind of thing was actually something I explicitly considered later on in Equinox, but in part decided against as I couldn't see a way to make that work either. The goal of Equinox was always absolute compatibility with arbitrary JAX code.

Now, none of that should be taken as a bash at Penzai! They've made a different set of trade-offs, and if the above incompatibility doesn't affect your goals then indeed their effect system is incredibly elegant, so certainly give it a try. (Seriously, it's been pretty cool to see the release of Penzai, which explicitly acknowledges how much it's inspired by Equinox.)

12 days ago

ddjohnson

Author of Penzai here! In idiomatic Penzai usage, you should always discharge all effects before running your model. While it's true you can't do `discharge_effect(jax.grad(your_model_here))`, you can still do `jax.grad(discharge_effect(your_model_here))`, which is probably what you meant to do anyway in most cases. Once you've wrapped your model in a handler layer, it has a pure interface again, which makes it fully compatible with all arbitrary JAX transformations. The intended use of effects is as an internal helper to simplify plumbing of values into and out of layers, not as something that affects the top-level interface of using the model!

(As an example of this, the GemmaTransformer example model uses the SideInput effect internally to do attention masking. But it exposes a pure functional interface by using a handler internally, so you can call it anywhere you could call an Equinox model, and you shouldn't have to think about the effect system at all as a user of the model.)
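A minimal sketch of the "handle first, then transform" ordering in plain JAX, under a toy model of what an effect is. `masked_sum_layer` and `handle_side_input` are hypothetical stand-ins, not Penzai's API: the layer needs a value it doesn't receive through its input, and the handler binds that value, leaving a pure function that `jax.grad` can differentiate as usual.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins, NOT Penzai's API. The "effectful" layer needs a
# value (a mask) that does not arrive through its ordinary input x.
def masked_sum_layer(x, side_input):
    mask = side_input["mask"]
    return jnp.sum(jnp.where(mask, x ** 2, 0.0))


def handle_side_input(layer, mask):
    # The "handler" supplies the requested value by closure, so the returned
    # function is pure in x alone.
    def pure_model(x):
        return layer(x, {"mask": mask})
    return pure_model


mask = jnp.array([True, False, True])
pure_model = handle_side_input(masked_sum_layer, mask)

# Handle first, then transform: jax.grad only ever sees a pure function.
grads = jax.grad(pure_model)(jnp.array([1.0, 2.0, 3.0]))
print(grads.tolist())  # [2.0, 0.0, 6.0]
```

Reversing the order (transforming the layer while it still expects an unhandled side input) is the composition the thread says isn't supported.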

It's not clear to me what the semantics of ordinary JAX transformations like `lax.scan` should be if the model has side effects. But if you don't have any effects in your model, or if you've explicitly handled them already, then it's perfectly fine to use `lax.scan`. This is similar to how it works in ordinary JAX: if you try to do a `lax.scan` over a function that mutates external Python state, you'll probably hit an error or get something unexpected. But if the Python state you mutate is created and used entirely inside the scanned function, it works fine.
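A small illustration of the ordinary-JAX analogy, assuming default `lax.scan` behaviour (the body is traced, not executed once per iteration) and default 32-bit dtypes. Appending to an outer Python list therefore does not happen per step, while a list that lives entirely inside the body is harmless. The function names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

trace_log = []  # Python state living OUTSIDE the scanned function


def step_mutating_outer(carry, x):
    # Side effect on outer Python state: it runs while JAX traces the body,
    # not once per scan iteration, so the log does not track the loop.
    trace_log.append("traced")
    return carry + x, carry


def step_local_state(carry, x):
    # Python state created and consumed inside the body is fine: it only
    # shapes the traced computation.
    terms = []
    terms.append(x)
    terms.append(2.0 * x)
    new_carry = carry + sum(terms)
    return new_carry, new_carry


xs = jnp.arange(5.0)
init = jnp.float32(0.0)
total, _ = lax.scan(step_mutating_outer, init, xs)
total_local, _ = lax.scan(step_local_state, init, xs)
print(float(total), len(trace_log))  # total is 10.0; the log did not grow once per step
print(float(total_local))            # 30.0
```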

I'll also note that adding support for higher-order layer combinators (like "layer scan") is something that's on the roadmap! The goal would be to support some of the fancier features of libraries like Flax when you need them, while still admitting a simple purely-functional mental model when you don't.

12 days ago

ddjohnson

Thanks! This is one of the more experimental design choices I made in designing Penzai, but so far I've found it to be quite useful.

The effect system does come with a few sharp edges at the moment if you want to use JAX transformations inside the forward pass of your model (see my reply to Patrick), but I'm hoping to make it more flexible as time goes on. (Figuring out how effect systems should compose with function transformations is a bit nontrivial!)

Please let me know if you run into any issues using Penzai for your model! (Also, most of Penzai's visualization and patching utilities should work with Equinox too, so you shouldn't necessarily need to fully commit to either one.)

12 days ago

catgary

This is something I’ve thought about in the past, since I messed around with trying to add monads to JAX. I think you made the right call with effect handlers. You might want to take a look at what Koka does; that was the best implementation of effect handlers the last time I checked.

12 days ago

pizza

I remember PyTorch has some pytree capability, no? So is it safe to say that any pytree-compatible modules here are already compatible with PyTorch?

12 days ago

ddjohnson

Author here! I didn't know PyTorch had its own pytree system. It looks like it's separate from JAX's pytree registry, though, so Penzai's tooling probably won't work with PyTorch out of the box.

12 days ago

sillysaurusx

I implemented Jax’s pytrees in pure python. You can use it with whatever you want. https://github.com/shawwn/pytreez

The readme is a todo, but the tests are complete. They’re the same that Jax itself uses, but zero dependencies. https://github.com/shawwn/pytreez/blob/master/tests/test_pyt...

The concept is simple. The hard part is cross-pollination. Suppose you wanted to literally use Jax pytrees with PyTorch. Now you’ll have to import Jax, or my library, and register your modules with it. But anything else that ever uses pytrees needs to use the same pytree library, because the registry (the thing that keeps track of pytree-compatible classes) lives in the library you choose. They don’t share registries.

A better way of phrasing it is that if you use a jax-style pytree interface, it should work with any other pytree library. But to my knowledge, the only pytree library besides Jax itself is mine here, and only I use it. So when you ask if pytree-compatible modules are compatible with PyTorch, it’s equivalent to asking whether PyTorch projects use jax, and the answer tends to be no.
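A sketch of what registering a class with JAX's own registry looks like, using `jax.tree_util.register_pytree_node_class`; `Linear` is a hypothetical toy module. The registration is visible only to JAX's tree utilities, which is the cross-pollination problem described above: a different pytree library keeps its own registry and would not see it.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Toy module whose weights are pytree leaves (illustrative only)."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def tree_flatten(self):
        # (children, aux_data): array leaves plus static metadata (none here)
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __call__(self, x):
        return x @ self.w + self.b


layer = Linear(jnp.ones((2, 3)), jnp.zeros(3))
# The class is now in JAX's registry, so JAX's tree utilities see its leaves.
print(len(jax.tree_util.tree_leaves(layer)))  # 2
doubled = jax.tree_util.tree_map(lambda a: 2 * a, layer)
print(type(doubled).__name__)  # Linear
```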

EDIT: perhaps I’m outdated. OP says that PyTorch has pytree functionality now. https://news.ycombinator.com/item?id=40109662 I guess yet again I was ahead of the times by a couple years; happy to see other ecosystems catch up. Hopefully seeing a simple implementation will clarify the tradeoffs.

The best approach for a universal pytree library would be to assume that any class with tree_flatten and tree_unflatten methods is pytreeable, and not require those classes to be explicitly registered. That way you don’t have to worry whether you’re using Jax or PyTorch pytrees. But I gave up trying to make library-agnostic ML modules; in practice it’s better just to choose Jax or PyTorch and be done with it, since making PyTorch modules run in Jax automatically (and vice versa) is a fool’s errand (I was the fool, and it was an errand) for many reasons, not the least of which is that Jax builds an explicit computation graph via jax.jit, a feature PyTorch has only recently (and reluctantly) embraced. But of course, that means if you pick the wrong ecosystem, you’ll miss out on the best tools: hello React vs Vue, or Unreal Engine vs Unity, or dozens of other examples.
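The registry-free idea can be sketched in a few lines of pure Python. This is a hypothetical scheme, not the behaviour of JAX, pytreez, dm-tree, or optree: exact lists and tuples, and any object whose class defines `tree_flatten`/`tree_unflatten`, are treated as internal nodes without registration; everything else (including dicts and namedtuples, omitted for brevity) is a leaf.

```python
from typing import Any, List, Tuple


def duck_flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Hypothetical registry-free flatten. Exact lists/tuples and any object
    whose class defines tree_flatten/tree_unflatten are internal nodes;
    everything else is a leaf. Returns (leaves, spec)."""
    if type(tree) in (list, tuple):
        kind, children, aux = type(tree), list(tree), None
    elif hasattr(type(tree), "tree_flatten") and hasattr(type(tree), "tree_unflatten"):
        children, aux = tree.tree_flatten()
        kind, children = type(tree), list(children)
    else:
        return [tree], None  # a spec of None marks a leaf
    leaves, child_specs = [], []
    for child in children:
        child_leaves, child_spec = duck_flatten(child)
        leaves.extend(child_leaves)
        child_specs.append((len(child_leaves), child_spec))
    return leaves, (kind, aux, child_specs)


def duck_unflatten(spec: Any, leaves: List[Any]) -> Any:
    """Inverse of duck_flatten: rebuild the structure from spec and leaves."""
    if spec is None:
        return leaves[0]
    kind, aux, child_specs = spec
    children, start = [], 0
    for count, child_spec in child_specs:
        children.append(duck_unflatten(child_spec, leaves[start:start + count]))
        start += count
    if kind in (list, tuple):
        return kind(children)
    return kind.tree_unflatten(aux, children)


class Pair:
    """Defines the two methods but is never registered with any library."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def tree_flatten(self):
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


leaves, spec = duck_flatten([Pair(1, 2), (3, [4])])
print(leaves)  # [1, 2, 3, 4]
rebuilt = duck_unflatten(spec, [v * 10 for v in leaves])
print(rebuilt[0].b, rebuilt[1])  # 20 (30, [40])
```

The trade-off is that duck typing can misfire on unrelated classes that happen to define methods with these names, which is one reason explicit registries exist.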

12 days ago

albertzeyer

There are a couple more such libraries. One was inside TensorFlow (nest) and was then extracted into the standalone dm-tree: https://github.com/deepmind/tree

Or also: https://github.com/metaopt/optree

I think ideally you would try to use mostly standard types (dict, list, tuple, etc) which are supported by all those libraries in mostly the same way, so it's easy to switch.

You have to be careful in some of the small differences though. E.g. what basic types are supported (e.g. dataclass, namedtuple, other derived instances from dict, tuple, etc), or how None is handled.
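To make the None difference concrete, here is how JAX's `jax.tree_util` treats it: `None` is an empty subtree with no leaves unless you opt in via `is_leaf`. Other libraries may default differently (optree, for instance, exposes a `none_is_leaf` option), so it's worth checking before switching. The tree below is illustrative.

```python
import collections

import jax

Point = collections.namedtuple("Point", ["x", "y"])
tree = {"b": Point(1, 2), "a": None, "c": [3, (4,)]}

# Dicts are flattened in sorted-key order; None is an empty subtree in JAX,
# so key "a" contributes no leaves.
default_leaves = jax.tree_util.tree_leaves(tree)
print(default_leaves)  # [1, 2, 3, 4]

# Opting in with is_leaf makes None an explicit leaf instead.
none_as_leaf = jax.tree_util.tree_leaves(tree, is_leaf=lambda x: x is None)
print(none_as_leaf)  # [None, 1, 2, 3, 4]
```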

12 days ago

ubj

Does anyone know if and how well Penzai can work with Diffrax [1]? I currently use Diffrax + Equinox for scientific machine learning. Penzai looks like an interesting alternative to Equinox.

[1]: https://docs.kidger.site/diffrax/

12 days ago

thatguysaguy

Not sure on the specific combination, but since everything in Jax is functionally pure it's generally really easy to compose libraries. E.g. I've written code which embedded a flax model inside a haiku model without much effort.

12 days ago

patrickkidger

IIUC then penzai is (deliberately) sacrificing support for higher-order operations like `lax.{while_loop, scan, cond}` or `diffrax.diffeqsolve`, in return for some of the other new features it is trying out (treescope, effects).

So it's slightly more framework-y than Equinox and will not be completely compatible with arbitrary JAX code. However I have already had a collaborator demonstrate that as long as you don't use any higher-order operations, then treescope will actually work out-of-the-box with Equinox modules!

So I think the answer to your question is "sort of":

* As long as you only try to inspect things that are happening outside of your `diffrax.diffeqsolve` then you should be good to go. And moreover can probably do this simply by using e.g. Penzai's treescope directly alongside your existing Equinox code, without needing to move things over wholesale.

* But anything inside probably isn't supported + if I understand their setup correctly can never be supported. (Not bashing Penzai there, which I think genuinely looks excellent -- I think it's just fundamentally tricky at a technical level.)

12 days ago

ddjohnson

Author of Penzai here. I think the answer is a bit more nuanced (and closer to "yes") than this:

- If you want to use the treescope pretty-printer or the pz.select tree manipulation utility, those should work out-of-the-box with both Equinox and Diffrax. Penzai's utilities are designed to be as modular as possible (we explicitly try not to be "frameworky") so they support arbitrary JAX pytrees; if you run into any problems with this please file an issue!

- If you want to call a Penzai model inside `diffrax.diffeqsolve`, that should also be fully supported out of the box. Penzai models expose a pure functional interface when called, so you should be able to call a Penzai model anywhere that you'd call an Equinox model. From the perspective of the model user, you should be able to think of the effect system as an implementation detail. Again, if you run into problems here, please file an issue.

- If you want to write your own Penzai layer that uses `diffrax.diffeqsolve` internally, that should also work. You can put arbitrary logic inside a Penzai layer as long as it's pure.

- The specific thing that is not currently fully supported is: (1) defining a higher-order Penzai combinator layer that uses `diffrax.diffeqsolve` internally, (2) and having that layer run one of its sublayers inside the `diffrax.diffeqsolve` function, (3) while simultaneously having that internal sublayer use an effect (like random numbers, state, or parameter sharing), (4) where the handler for that effect is placed outside of the combinator layer. This is because the temporary effect implementation node that gets inserted while a handler is running isn't a JAX array type, so you'll get a JAX error when you try to pass it through a function transformation.

This last case is something I'd like to support as well, but I still need to figure out what the semantics of it should be. (E.g. what does it even mean to solve a differential equation that has a local state variable in it?) I think having side effects inside a transformed function is fundamentally hard to get right!
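A minimal sketch of the failure mode in the last bullet, assuming only that the temporary node is an ordinary Python object. `EffectImpl` and `sublayer` are hypothetical, not Penzai classes. `jax.jit` rejects an unregistered, non-array object passed as an argument with a `TypeError`, whereas a function that captures the object by closure, so that only arrays cross the transformation boundary, traces fine.

```python
import jax
import jax.numpy as jnp


class EffectImpl:
    """Hypothetical stand-in for a temporary effect-implementation node.
    It is an ordinary Python object, not a JAX array or registered pytree."""

    def __init__(self, scale):
        self.scale = scale


def sublayer(x, impl):
    return x * impl.scale


impl = EffectImpl(3.0)
x = jnp.arange(3.0)

# Passing the object *through* the transformation as an argument fails:
# jit can only trace arrays (or pytrees whose leaves are arrays).
try:
    jax.jit(sublayer)(x, impl)
    passed_through = True
except TypeError:
    passed_through = False

# For contrast: if the object never crosses the boundary (captured by closure),
# the transformed function sees only arrays and compiles fine.
handled = jax.jit(lambda x: sublayer(x, impl))
print(passed_through, handled(x).tolist())  # False [0.0, 3.0, 6.0]
```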

12 days ago

felarof

I have a small YT channel that teaches JAX bit-by-bit, check it out! https://www.youtube.com/@TwoMinuteJAX

11 days ago

ein0p

Looks great, but outside Google I do not personally know anyone who uses Jax, and I work in this space.

12 days ago

logicchains

Not at Google but currently using Jax to leverage TPUs, because AWS's GPU pricing is eye-gougingly expensive. For the lower-end A10 GPUs, the price-per-gpu for a 4 GPU machine is 1.5x the price for a 1 GPU machine, and the price-per-gpu for an 8 GPU machine is 2x the price of a 1 GPU machine! If you want an A100 or H100, the only option is renting an 8 GPU instance. With properly TPU-optimised code you get something like 30-50% cost saving on GCP TPUs compared to AWS (and I say that as someone who otherwise doesn't like Google as a company and would prefer to avoid GCP if there wasn't such a significant cost advantage).

12 days ago

KeplerBoy

I use it for GPU accelerated signal processing. It really delivers on the promise of "Numpy but for GPU" better than all competing libraries out there.

12 days ago

MasterScrat

We've built our startup from scratch on JAX, selling text-to-image model finetuning, and it's given us a consistent edge not only in terms of pure performance but also in terms of "dollars per unit of work".

12 days ago

Havoc

Is that gain from TPU usage or something else?

12 days ago

MasterScrat

Mostly from the tight JAX-TPU integration, yeah.

11 days ago

_ntka

Isn't JAX the most widely used framework in the GenAI space? Most companies there use it -- Cohere, Anthropic, CharacterAI, xAI, Midjourney etc.

12 days ago

smhx

most of the GenAI players use both PyTorch and JAX, depending on the hardware they are running on. Character, Anthro, Midjourney, etc. are dual shops (they use both). xAI only uses JAX afaik.

12 days ago

mistrial9

just guessing that tech leadership at all of those traces back to Google somehow

12 days ago

error9348

Jax trends on Papers with Code:

https://paperswithcode.com/trends

12 days ago

nostrademons

Was gonna ask "What's that MindSpore thing that seems to be taking the research world by storm?" but I Googled and it's apparently Huawei's open-source AI framework. Going from 1% to 7% market share in 2 years is nothing to sneeze at; that's a growth rate similar to Chrome or Facebook in their heyday.

It's telling that Huawei-backed MindSpore can go from 1% to 7% in 2 years, while Google-backed Jax is stuck at 2-3%. Contrary to popular narrative in the Western world, Chinese dominance is alive and well.

12 days ago

logicchains

>It's telling that Huawei-backed MindSpore can go from 1% to 7% in 2 years, while Google-backed Jax is stuck at 2-3%. Contrary to popular narrative in the Western world, Chinese dominance is alive and well.

MindSpore has an advantage there because of its integrated support for Huawei's Ascend 910B, the only Chinese GPU that comes close to matching the A100. Given the US banned export of A100 and H100s to China, this creates artificial demand for the Ascend 910B chips and the MindSpore framework that utilises them.

12 days ago

bigcat12345678

No, MindSpore is rising because of the chip embargo.

No one is going to use stuff whose supply could be cut off one day.

This is one signal of why Huawei was listed by Nvidia as a competitor in 4 out of 5 categories in Nvidia's earnings.

12 days ago

ein0p

Its meteoric rise started well before the chip embargo. I've looked into it, it liberally borrows ideas from other frameworks, both PyTorch and Jax, and adds some of its own. You lose some of the conceptual purity, but it makes up for it in practical usability, assuming it works as it says on the tin, which it may or may not. PyTorch also has support for Ascend as far as I can tell https://github.com/Ascend/pytorch, so that support does not necessarily explain MindSpore's relative success. Why MindSpore is rising so rapidly is not entirely clear to me. Could be something as simple as preferring a domestic alternative that is adequate to the task and has better documentation in Chinese. Could be cost of compute. Could be both. Nowadays, however, I do agree that the various embargoes would help it (as well as Huawei) a great deal. As a side note I wish Huawei could export its silicon to the West. I bet that'd result in dramatically cheaper compute.

12 days ago

creato

This data might just be unreliable. It had a weird spike in Dec 2021 that looks unusual compared to all the other frameworks.

12 days ago

VHRanger

China publishes a looooootttttt of papers. A lot of it is careerist crap.

To be fair, a lot of US papers are also crap, but Chinese crap research is on another level. There's a reason a lot of top US researchers are Chinese - there's brain drain going on.

12 days ago

niihau_island

When I looked into a random sampling of these uses, my impression was that it was a common kind of project in China to take a common paper (or another repo) and implement it in Mindspore. That accounted for the vast majority of the implementations.

12 days ago

ein0p

Note that most of Jax’s minuscule share is Google.

12 days ago

j7ake

I’m in academia and I use jax because it’s the closest thing to translating maths directly into code.

12 days ago

hyperbovine

Same, Jax is extremely popular with the applied math/modeling crowd.

12 days ago

sudosysgen

I use it all the time, and there's also a few classes at my uni that use Jax. It's really great for experimentation and research, you can do a lot of things in Jax you just can't in, say, PyTorch.

12 days ago

p1esk

Like what?

12 days ago

sudosysgen

Anytime you want to make something GPU accelerated that doesn't fit as standard operations on tensors. For example, I often write RL environments in Jax, which is something you can't do in PyTorch. There's also things you can do in PyTorch but that would be far more difficult, for example an efficient implementation of MCTS.
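A toy sketch of what writing an RL environment in JAX can look like: the environment is a pure step function, so `lax.scan` runs a whole rollout on-device and `jax.vmap` batches many environments. The 1-D "walk to the goal" environment (`env_step`, `rollout`, `GOAL`) is hypothetical and only meant to show the pattern; real environments would add resets, randomness, and a policy.

```python
import jax
import jax.numpy as jnp

GOAL = 5  # hypothetical 1-D corridor: positions 0..GOAL


def env_step(position, action):
    # Pure transition: action 1 moves right, action 0 moves left.
    move = jnp.where(action == 1, 1, -1)
    new_position = jnp.clip(position + move, 0, GOAL)
    reward = jnp.where(new_position == GOAL, 1.0, 0.0)
    return new_position, reward  # (next carry, per-step output)


def rollout(start, actions):
    # lax.scan runs the whole episode on-device; no Python loop per step.
    return jax.lax.scan(env_step, start, actions)


# vmap batches independent environments; jit compiles the batched rollout.
batched_rollout = jax.jit(jax.vmap(rollout))

starts = jnp.zeros(3, dtype=jnp.int32)
action_seqs = jnp.ones((3, 5), dtype=jnp.int32)  # every env always moves right
final_positions, rewards = batched_rollout(starts, action_seqs)
print(final_positions.tolist())  # [5, 5, 5]
```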

I also used Jax a lot for differential equations, not even sure how I would do that with PyTorch.

Basically, Torch is a lot more like a specialization of Numpy for neural networks, while Jax feels a lot more like if you could just write CUDA as Python, and also get the Jacobians (jacs! jax!) and jvp for free (of everything, you can even differentiate hyperparameters through your optimizer which is crazy).
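A short sketch of the transformations mentioned: `jax.jacfwd` for a full Jacobian, `jax.jvp` for a Jacobian-vector product without materialising the Jacobian, and `jax.grad` through a few unrolled gradient-descent steps to get the derivative of the final loss with respect to the learning rate. The functions `f`, `loss`, and `final_loss_after_training` are illustrative.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.array([x[0] ** 2, x[0] * x[1]])


x = jnp.array([3.0, 2.0])
print(jax.jacfwd(f)(x).tolist())  # [[6.0, 0.0], [2.0, 3.0]]

# jvp returns f(x) and J(x) @ v in one forward pass.
_, jv = jax.jvp(f, (x,), (jnp.array([1.0, 0.0]),))
print(jv.tolist())  # [6.0, 2.0]


def loss(w):
    return (w - 4.0) ** 2


def final_loss_after_training(lr, steps=3):
    # Unrolled gradient descent: lr is an ordinary traced input, so the
    # final loss can be differentiated with respect to it.
    w = 0.0
    for _ in range(steps):
        w = w - lr * jax.grad(loss)(w)
    return loss(w)


hypergrad = jax.grad(final_loss_after_training)(0.1)
print(float(hypergrad))
```

For this quadratic the result can be checked by hand: after k steps from w = 0 the final loss is 16(1 - 2·lr)^(2k), so its derivative with respect to lr is -64k(1 - 2·lr)^(2k-1), about -62.91 for k = 3 and lr = 0.1.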

At the end, when you're doing fundamental research and coming up with something new, I think Jax is just better. If all I had to do was implementation, then I would be a happy PyTorch user.

12 days ago

polygamous_bat

A small addendum: the only people I know who use Jax are people who work at Google, or people who had a big GCP grant and needed to use TPUs as a result.

12 days ago

mccoyb

That's cool -- but wouldn't it be more constructive to discuss "the ideas" in this package anyways?

For instance, it would be interesting to discern if the design of PyTorch (and their modules) preclude or admit the same sort of visualization tooling? If you have expertise in PyTorch, perhaps you could help answer this sort of question?

JAX's Pytrees are like "immutable structs, with array leaves" -- does PyTorch have a similar concept?
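On the JAX side, a minimal sketch of the "immutable struct with array leaves" idea, using a `NamedTuple` (which JAX treats as a pytree without registration). `MLPParams` is an illustrative name. Tree utilities give a generic, structure-aware view that tooling can walk, and edits produce new structs rather than mutating the old one.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MLPParams(NamedTuple):
    # An immutable container whose leaves are arrays.
    w1: jnp.ndarray
    b1: jnp.ndarray
    w2: jnp.ndarray


params = MLPParams(jnp.ones((4, 8)), jnp.zeros(8), jnp.ones((8, 1)))

# Generic traversal: the kind of information a visualiser or editor can use.
shapes = jax.tree_util.tree_map(lambda a: a.shape, params)
print(shapes)

# "Editing" returns a new struct; the original is untouched.
scaled = params._replace(w2=params.w2 * 0.5)
print(float(params.w2[0, 0]), float(scaled.w2[0, 0]))  # 1.0 0.5
```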

12 days ago

fpgamlirfanboy

> does PyTorch have a similar concept

of course https://github.com/pytorch/pytorch/blob/main/torch/utils/_py...

12 days ago

ein0p

Idk if you need that immutability actually. You could probably reconstruct enough to do this kind of viz from the autograd graph, or capture the graph and intermediates in the forward pass using hooks. My hunch is it should be doable.

12 days ago

eli_gottlieb

If JAX had affine_grid() and grid_sample(), I'd be using it instead of PyTorch for my current project.
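For what it's worth, a rough 2-D approximation of the `affine_grid` + `grid_sample` pattern can be sketched with `jax.scipy.ndimage.map_coordinates`, which supports linear (`order=1`) interpolation. This is an assumption-laden sketch, not a drop-in port: it uses pixel coordinates rather than PyTorch's normalised [-1, 1] grid, handles a single 2-D image, and ignores `align_corners`. `affine_warp_2d` is a hypothetical helper.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def affine_warp_2d(image, matrix, offset):
    """Rough stand-in for affine_grid + grid_sample on one 2-D image.

    For each output pixel p = (row, col), sample the input at matrix @ p + offset
    (i.e. the affine map goes from output to input coordinates) with bilinear
    interpolation and zero padding outside the image.
    """
    rows, cols = image.shape
    rr, cc = jnp.meshgrid(jnp.arange(rows), jnp.arange(cols), indexing="ij")
    coords = jnp.stack([rr.ravel(), cc.ravel()]).astype(image.dtype)  # (2, N)
    src = matrix @ coords + offset[:, None]                             # (2, N)
    sampled = map_coordinates(image, [src[0], src[1]], order=1,
                              mode="constant", cval=0.0)
    return sampled.reshape(rows, cols)


img = jnp.arange(12.0).reshape(3, 4)
identity = affine_warp_2d(img, jnp.eye(2), jnp.zeros(2))
# Sampling one column to the right shifts the content left and zero-fills.
shifted = affine_warp_2d(img, jnp.eye(2), jnp.array([0.0, 1.0]))
```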

12 days ago

yshvrdhn

it would be great if we could have intelligent tools for building neural networks in pytorch.

12 days ago

Edmond

would a comprehensive object construction platform with schema support and the ability to hook up to a compiler (i.e., turn object data into code, for instance) be a useful tool in this domain?

ex: https://www.youtube.com/watch?v=fPnD6I9w84c

I am the developer, happy to answer questions.

12 days ago