You're getting 35 tokens/s for a 405B model, which comes out to about 85 TFLOP/s. Eight MI300X GPUs peak at about 10.4 PFLOP/s, so you're getting about 0.8% MFU (roughly 40-50x worse than a decent training MFU of 30-40%).
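For anyone who wants to check that arithmetic, here is a rough sketch. It assumes the usual approximation of about 6 FLOPs per parameter per token for training (forward plus backward), which is what the 85 figure implies, and a per-GPU peak of about 1.3 PFLOP/s, which is just the 10.4 figure divided by 8. The function and variable names are made up for illustration.

```python
def training_flops_per_second(params: float, tokens_per_second: float) -> float:
    """Approximate training throughput in FLOP/s.

    Assumption: ~6 FLOPs per parameter per token (forward + backward),
    the common rule of thumb for dense transformers.
    """
    return 6 * params * tokens_per_second


def mfu(achieved_flops: float, peak_flops: float) -> float:
    """Model FLOPs utilization: achieved throughput over hardware peak."""
    return achieved_flops / peak_flops


params = 405e9          # 405B-parameter model
tokens_per_s = 35       # throughput quoted above
n_gpus = 8
peak_per_gpu = 1.3e15   # assumed peak per MI300X (10.4 PFLOP/s / 8)

achieved = training_flops_per_second(params, tokens_per_s)
peak = n_gpus * peak_per_gpu

print(f"achieved: {achieved / 1e12:.0f} TFLOP/s")
print(f"peak:     {peak / 1e15:.1f} PFLOP/s")
print(f"MFU:      {mfu(achieved, peak):.1%}")
```

If the 35 tokens/s figure is for inference rather than training, the per-token cost is closer to 2 FLOPs per parameter and the utilization would be about a third of this.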
For AMD's sake, I hope that it's your software stack that's limiting perf.
I'm not convinced that `torch.cuda` is really that bad, since the AMD version of PyTorch just translates that for you. It's more of a naming problem than anything. The fact is that it's just as easy to grab the rocm:pytorch container as it is the rocm:jax container.
I don't see very many numbers posted. What MFU did you get?
I’m getting the impression that the answer is “no”.
We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).
Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., the `torch.cuda` namespace, or `scaled_dot_product_attention`, which is an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, there's a lot of "de-NVIDIAfying" that needs to be done.
Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, ML model code compiles to hardware-independent HLO graphs, which are then optimized by the XLA compiler before hardware-specific optimization. This clean separation allowed us to run the same Llama 3 JAX code on both Google TPUs and AMD GPUs with no changes.
Our strategy as a company is to invest upfront in porting models to JAX, then leverage its framework and XLA kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX, and now the same JAX model works great on TPUs and runs perfectly on AMD GPUs.
We'd love to hear your thoughts on our vision and repo!