I Rebuilt Karpathy's NanoChat in JAX. Here's What XLA Gets Right and What It Gets Dead Wrong.
The article describes porting Andrej Karpathy's NanoChat from PyTorch to JAX with Flax NNX, with the goal of running scalable AI experiments on both GPU and TPU via XLA compilation. The JAX version trains quickly on a single GPU and runs unchanged on TPU, though it still lacks some of the original's optimizations, such as vLLM-based inference and Flash Attention 3. The author weighs XLA's strengths, notably runtime performance once a function has been JIT-compiled, against its weaknesses, chiefly how hard it is to debug code inside compiled functions.
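To make the two sides of that trade-off concrete, here is a minimal sketch (not the article's actual model; the module, its sizes, and the function names are illustrative) of a Flax NNX module run under JIT. The first call triggers XLA compilation and later calls reuse the compiled program, which is where the speed comes from; the flip side is that an ordinary Python `print()` inside the jitted function only fires during tracing, so inspecting runtime values requires `jax.debug.print`.

```python
import jax
import jax.numpy as jnp
from flax import nnx

class TinyMLP(nnx.Module):
    """Illustrative two-layer MLP in Flax NNX (not NanoChat's architecture)."""

    def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(din, dhidden, rngs=rngs)
        self.fc2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        return self.fc2(jax.nn.relu(self.fc1(x)))

model = TinyMLP(16, 64, 8, rngs=nnx.Rngs(0))

@nnx.jit  # XLA compiles on the first call; subsequent calls reuse the program
def forward(model, x):
    # print(x.mean()) here would only show an abstract tracer, once, at trace
    # time; jax.debug.print is the supported way to see runtime values in jit.
    jax.debug.print("batch mean = {m}", m=x.mean())
    return model(x)

y = forward(model, jnp.ones((4, 16)))  # first call: traced + compiled
y = forward(model, jnp.ones((4, 16)))  # second call: runs the cached program
```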
Omotayo Aina for Google Developer Experts · Posted on May 1

#ai #machinelearning #deeplearning #python

AI (3 Part Series)
1. Gemma-SRE: Self-Hosted vLLM Infrastructure Agent
2. Self-hosted Gemma 4 on TPU with vLLM, MCP, ADK, and Gemini CLI
3. I Rebuilt Karpathy's NanoChat in JAX. Here's What XLA Gets Right and What It Gets Dead Wrong.
…