felarof
Hey HN, we recently fine-tuned the llama3.1 405B model on 8xAMD MI300x GPUs using JAX instead of PyTorch. JAX's advanced sharding APIs allowed us to achieve great performance. Check out our blog post to learn about the cool sharding tricks we used. We've also open-sourced the code: https://github.com/felafax/felafax
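For readers unfamiliar with JAX's sharding APIs, here's a minimal sketch (a toy example, not code from the Felafax repo) of the GSPMD-style approach the post refers to: you describe a device mesh and per-array partition specs, and XLA inserts the collectives for you.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all local devices (e.g., the 8 MI300X GPUs), with a
# named axis used for tensor (model) parallelism.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# Shard a weight matrix column-wise across the "model" axis.
w = jax.device_put(jnp.zeros((8192, 8192)), NamedSharding(mesh, P(None, "model")))

@jax.jit
def matmul(x, w):
    # XLA propagates the shardings and inserts any needed collectives.
    return x @ w

x = jnp.ones((16, 8192))
print(matmul(x, w).sharding)
```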

We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).

Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., the `torch.cuda` namespace, or `scaled_dot_product_attention`, which is essentially an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, there's a lot of "de-NVIDIAfying" that needs to be done.
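To make the "de-NVIDIAfying" point concrete, here's a hypothetical snippet (illustrative only, not from any particular codebase) of the kind of CUDA-centric PyTorch code that needs attention on non-NVIDIA backends:

```python
import torch
import torch.nn.functional as F

# Device selection is written against the CUDA API. On ROCm this only works
# because AMD's PyTorch build aliases torch.cuda; on TPUs/Trainium it doesn't
# apply at all.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

q = k = v = torch.randn(1, 8, 128, 64, device=device)

# scaled_dot_product_attention dispatches to fused attention kernels that are
# primarily tuned for NVIDIA GPUs (with slower fallbacks elsewhere).
out = F.scaled_dot_product_attention(q, k, v)
```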

Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, model code compiles to hardware-independent HLO graphs, which the XLA compiler optimizes before any hardware-specific lowering. This clean separation allowed us to run the same Llama 3.1 JAX code on both Google TPUs and AMD GPUs with no changes.
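A small illustration of that separation (a toy function, not the Llama port): a jitted JAX program can be lowered to backend-neutral HLO/StableHLO text before XLA does any hardware-specific compilation.

```python
import jax
import jax.numpy as jnp

def mlp(x, w1, w2):
    return jax.nn.relu(x @ w1) @ w2

x, w1, w2 = jnp.ones((4, 16)), jnp.ones((16, 32)), jnp.ones((32, 8))

# Lower to HLO/StableHLO text; the same program is then compiled by XLA
# for TPU, ROCm, or CUDA targets.
print(jax.jit(mlp).lower(x, w1, w2).as_text()[:400])
```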

Our strategy as a company is to invest upfront in porting models to JAX, then lean on the XLA compiler and its kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX; the same JAX model now runs on both TPUs and AMD GPUs without modification.

We'd love to hear your thoughts on our vision and repo!

chillee
To be clear, this performance is quite bad (presumably because you didn't manage to get compilation working).

You're getting 35 tokens/s for a 405B model, which comes out to about 85 Teraflops. 8 MI300x GPUs comes out to 10.4 Petaflops, so you're getting about 0.8% MFU (which is about 40-50x worse than decent training performance of 30-40% MFU).
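A back-of-the-envelope version of that estimate (assuming ~6 FLOPs per parameter per token for training and ~1.3 PFLOPS peak bf16 per MI300X):

```python
params = 405e9                        # Llama 3.1 405B
tokens_per_s = 35                     # reported throughput
achieved = tokens_per_s * 6 * params  # ~8.5e13 FLOP/s ≈ 85 TFLOPS
peak = 8 * 1.3e15                     # 8x MI300X ≈ 10.4 PFLOPS peak bf16
print(f"MFU ≈ {achieved / peak:.1%}") # ≈ 0.8%
```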

For AMD's sake, I hope that it's your software stack that's limiting perf.

3abiton
Firstly, great work! I dabbled with AMD GPUs and ROCm support a year ago, and it was obvious AMD was still a long way from catching up with Nvidia. While opting for JAX is an interesting approach, what were the challenges in deviating from PyTorch (the standard library for ML)?
latchkey
Nice work! I was just playing with the inference side of things with 405B myself this weekend [0].

I'm not convinced that `torch.cuda` is really that bad, since the AMD build of PyTorch just translates it for you. It's more of a naming problem than anything. The fact is that it's just as easy to grab the rocm:pytorch container as it is the rocm:jax container.

I don't see very many numbers posted. What MFU did you get?

[0] https://x.com/HotAisle/status/1837580046732874026

steeve
We (ZML) measured MI300X at 30% faster than H100. These are great chips!
brutus1213
Does any cloud provider have an 8xAMD MI300 host that one can rent? I use AWS for a lot of my professional work and was hoping to try out an AMD GPU.
yeahwhatever10
Where is the performance data?
Stem0037
If possible, it would be interesting to explore ways to overcome the memory constraints and run a JIT-compiled version. This could potentially lead to further performance improvements.
yieldcrv
Is AMD any closer to extracting value from this with large orders of their GPUs causing a shortage?

I’m getting the impression of “no”

system2
Why is obsidian (a note-taking app) doing this?
varispeed
How do you buy such a GPU, or is it still reserved for the rich so they can get ahead of the game once the plebs get their unwashed hands on these cards?
manojlds
Thought this was a post from Obsidian at first. Why haven't they done the GitHub.com vs GitHub.io thing yet?
abalaji
@dang: could we get the URL to include the username, since this isn't about Obsidian itself but rather a user-generated blog?