
[Experiment] ROCm backend #2300

Open
NripeshN wants to merge 339 commits into ml-explore:main from NripeshN:rocm-support

Conversation

@NripeshN

@NripeshN NripeshN commented Jun 16, 2025

Contributor

Experiment with ROCm backend.

Install MLX with the ROCm backend using:

mkdir build && cd build
cmake -DMLX_BUILD_ROCM=ON \
      -DCMAKE_PREFIX_PATH=/opt/rocm \
      -DCMAKE_HIP_ARCHITECTURES="gfx90a;gfx1100" \
      ..
make -j$(nproc)

closes #2556

Inspired by @zcbenz

@NripeshN NripeshN changed the title from "[Experiment] ROCm backend initial push" to "[Experiment] ROCm backend" Jun 16, 2025
@lin72h

lin72h commented Jun 17, 2025

What an unexpected and amazing surprise! I'm absolutely thrilled.

@NripeshN

Contributor Author

@awni
What do you think of this PR? Does this have the potential to be merged into main? I can turn this PR from experimental to WIP if so.

@angeloskath

Member

I think this is good to stay as an experiment branch for some time while we work on core and CUDA. I don't think we have the bandwidth to merge this for a few months at least. Sorry if this is disappointing, @NripeshN; I don't mean to discourage you from working on it.

@akshat2602

I would love to see the ROCm backend get more traction. The new AI series of processors by AMD has a similar advantage to Apple Silicon with unified memory, and getting MLX to run on those processors would be neat.

@countradooku

Stole my idea :(

@goniz

goniz commented Jan 22, 2026

How is this even possible for such an awesome PR to be left like this?

Copilot AI review requested due to automatic review settings January 24, 2026 17:08

Copilot AI left a comment

Pull request overview

This PR adds experimental ROCm backend support to MLX, enabling execution on AMD GPUs. The implementation mirrors the CUDA backend structure, providing HIP-based implementations of core operations, memory management, and device handling.

Changes:

  • Added ROCm backend infrastructure with device management, memory allocation, and stream handling
  • Implemented HIP kernels for unary, binary, ternary operations, reductions, normalization (softmax, layer_norm, rms_norm), RoPE, and sorting
  • Updated build system (CMake) to support ROCm compilation with configurable GPU architectures

Reviewed changes

Copilot reviewed 59 out of 59 changed files in this pull request and generated 13 comments.

| File | Description |
| --- | --- |
| CMakeLists.txt | Added MLX_BUILD_ROCM option and ROCm library detection |
| mlx/CMakeLists.txt | Integrated ROCm backend build configuration |
| mlx/device.cpp | Added ROCm device availability checks |
| mlx/backend/rocm/*.hip | HIP kernel implementations for various operations |
| mlx/backend/rocm/device.* | ROCm device and stream management |
| mlx/backend/rocm/allocator.* | ROCm-specific memory allocator using HIP unified memory |
| mlx/backend/rocm/worker.* | Async task execution worker for stream synchronization |
| mlx/backend/rocm/utils.* | HIP utility functions and error handling |
| mlx/backend/rocm/jit_module.* | JIT compilation support using HIPRTC |
| mlx/backend/rocm/device/*.hpp | Device-side utility functions and type definitions |
| mlx/backend/rocm/CMakeLists.txt | ROCm backend build configuration |

Comment thread mlx/backend/rocm/softmax.hip Outdated
Comment thread mlx/backend/rocm/device.cpp Outdated
Comment thread mlx/backend/rocm/layer_norm.hip Outdated
Comment thread mlx/backend/rocm/rope.hip Outdated
Comment thread mlx/backend/rocm/softmax.hip Outdated
Comment thread mlx/backend/rocm/allocator.cpp Outdated
Comment thread CMakeLists.txt Outdated
Comment thread mlx/backend/rocm/binary.hip Outdated
Comment thread mlx/backend/rocm/rms_norm.hip Outdated
Comment thread mlx/backend/rocm/layer_norm.hip Outdated
@goniz

goniz commented Jan 24, 2026

👑👑👑

@NripeshN

Contributor Author

Can anyone run

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES={based on your GPU}" pip install -e .

Replace {based on your GPU} with your GPU architecture

You can run

rocm-smi

to get your GPU information

@goniz

goniz commented Jan 24, 2026

I'm getting this CMake error:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151" pip install -e .

      -- Configuring done (4.8s)
      CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
      Please set them or make sure they are set and tested correctly in the CMake files:
      /home/goniz/Work/mlx/LAPACK_INCLUDE_DIRS
         used as include directory in directory /home/goniz/Work/mlx
      
      CMake Error in CMakeLists.txt:
        HIP_ARCHITECTURES is empty for target "mlx".
      
      
      CMake Error in CMakeLists.txt:
        HIP_ARCHITECTURES is empty for target "mlx".
      
      
      -- Generating done (0.0s)
      CMake Generate step failed.  Build files cannot be regenerated correctly.

Running on Strix Halo (gfx1151)

@NripeshN

Contributor Author

I'm getting this CMake error:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151" pip install -e .
     -- Configuring done (4.8s)
     CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
     Please set them or make sure they are set and tested correctly in the CMake files:
     /home/goniz/Work/mlx/LAPACK_INCLUDE_DIRS
        used as include directory in directory /home/goniz/Work/mlx
     
     CMake Error in CMakeLists.txt:
       HIP_ARCHITECTURES is empty for target "mlx".
     
     
     CMake Error in CMakeLists.txt:
       HIP_ARCHITECTURES is empty for target "mlx".
     
     
     -- Generating done (0.0s)
     CMake Generate step failed.  Build files cannot be regenerated correctly.

Running on Strix Halo (gfx1151)

Could you retry with the latest push, please? (P.S. Keep your fingers crossed while it compiles; it worked for me on the 138th try.) 😅

@goniz

goniz commented Jan 25, 2026

  Created wheel for mlx: filename=mlx-0.30.4.dev20260125+cadf18c1-0.editable-cp314-cp314-linux_x86_64.whl size=4722 sha256=72c664adbfc4fb9ec317522a8d83b84f85d599d08bd691d7fec3abfdb6f3a5e9
  Stored in directory: /tmp/pip-ephem-wheel-cache-nt7w6bq0/wheels/8a/63/d1/d7d629a5ff73457822bb71aa527c083674bb19ca314735cd05
Successfully built mlx
Installing collected packages: mlx
Successfully installed mlx-0.30.4.dev20260125+cadf18c1

Now what can I test? 😍

@goniz

goniz commented Jan 25, 2026

I'm getting this:

ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: _ZN3mlx4core11Convolution8eval_gpuERKSt6vectorINS0_5arrayESaIS3_EERS3_

@NripeshN

Contributor Author

I'm getting this:

ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: _ZN3mlx4core11Convolution8eval_gpuERKSt6vectorINS0_5arrayESaIS3_EERS3_

I forgot to test the Python build, my bad. Can you try it now?

Unfortunately I might not be able to help after it compiles, I don't have an AMD GPU to run tests😔 I've tried replicating most things from cuda, so hopefully it works

@goniz

goniz commented Jan 26, 2026

Now fails on load with this:

>>> import mlx.core
Traceback (most recent call last):
  File "<python-input-0>", line 1, in <module>
    import mlx.core
ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: hiprtcCompileProgram

@goniz

goniz commented Jan 26, 2026

Unfortunately I might not be able to help after it compiles, I don't have an AMD GPU to run tests😔 I've tried replicating most things from cuda, so hopefully it works

Omg I don't believe you did it without AMD card 😱😱

@NripeshN

NripeshN commented Jan 26, 2026

Contributor Author

Now fails on load with this:

The latest push hopefully fixes the undefined symbol error. Found the issue, working on the fix 😩

Omg I don't believe you did it without AMD card 😱😱

Haha docker literally saves me and humbles me at the same time

@goniz

goniz commented Jan 26, 2026

image

@goniz

goniz commented Jan 26, 2026

I might have gotten overexcited:
image

@NripeshN

Contributor Author

Wait it works?😅

Ah, unfortunately, unless a magic fairy sends me a PC with an AMD GPU, I cannot help after this 😭 With RAM prices as they are, I doubt the magic fairy has the funds either 🥲

@goniz

goniz commented Jan 26, 2026

Latest commit broke something:
image

@NripeshN

Contributor Author

Lemme try adding a fix for both the issues above actually. I had just made a stub implementation earlier.

@NripeshN

Contributor Author

@goniz give the last push a try, maybe. It might not work, but you will definitely not have the same error at least ☺️

@goniz

goniz commented Jan 26, 2026

mlx rocm-support ? ❯︎ python3 qwen3.py 
Fetching 9 files: 100%|██████| 9/9 [00:00<00:00, 201864.90it/s]
Download complete: : 0.00B [00:00, ?B/s]              ?, ?it/s]
==========
Traceback (most recent call last):
  File "/home/goniz/Work/mlx/qwen3.py", line 15, in <module>
    text = generate(model, tokenizer, prompt=prompt, verbose=True)
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 762, in generate
    for response in stream_generate(model, tokenizer, prompt, **kwargs):
                    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 699, in stream_generate
    for n, (token, logprobs, from_draft) in enumerate(token_generator):
                                            ~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 689, in <genexpr>
    (token, logprobs, False) for token, logprobs in token_generator
                                                    ^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 432, in generate_step
    mx.eval([c.state for c in prompt_cache])
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unsupported dtype for affine_dequantize

@NripeshN

Contributor Author

Might fix it(????)

@goniz

goniz commented Jan 26, 2026

mlx rocm-support ? ❯︎ python3 qwen3.py 
Fetching 9 files: 100%|███████| 9/9 [00:00<00:00, 28575.88it/s]
Download complete: : 0.00B [00:00, ?B/s]              ?, ?it/s]
==========
Traceback (most recent call last):
  File "/home/goniz/Work/mlx/qwen3.py", line 15, in <module>
    text = generate(model, tokenizer, prompt=prompt, verbose=True)
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 762, in generate
    for response in stream_generate(model, tokenizer, prompt, **kwargs):
                    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 699, in stream_generate
    for n, (token, logprobs, from_draft) in enumerate(token_generator):
                                            ~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 689, in <genexpr>
    (token, logprobs, False) for token, logprobs in token_generator
                                                    ^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 432, in generate_step
    mx.eval([c.state for c in prompt_cache])
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: QuantizedMatmul has no ROCm implementation.

Geramy added 30 commits June 23, 2026 20:58
qmm.hip: v_dot4_i32_iu8 integer-dot quantized matmul for 4-bit affine decode,
on both the dense QuantizedMatmul path and the MoE GatherQMM path. Activation
is quantized to int8 (per-group scale + exact sum) in a prepass; weight nibbles
feed __builtin_amdgcn_sudot4 with int32 accumulate, then affine dequant
(scale*dx*<dot> + bias*Sx). Gated by MLX_QMV_IDOT (dense) and
MLX_QMV_IDOT_GATHER (gather, experimental); both off by default. Coherent
(~0.27% L2 vs scalar). +7% on dense 4B; neutral on MoE (gather is latency-bound).
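
For readers following the arithmetic, the affine-dequant identity in this commit message can be checked on the host. The sketch below is a pure-Python model of one quantization group, not the HIP kernel: the function names are made up for illustration, and symmetric int8 rounding with a max-abs scale is an assumption about how the prepass quantizes activations. With weights w_i = scale*q_i + bias and activations x_i ≈ dx*xq_i, the dot product splits into scale*dx*sum(q_i*xq_i) + bias*Sx, where Sx is the exact sum of the unquantized activations.

```python
import math

def quantize_activation(x):
    """Quantize one group to int8; return (xq, dx, sx).

    Assumption: symmetric max-abs scaling. sx is the exact float sum.
    """
    amax = max(abs(v) for v in x)
    dx = amax / 127.0 if amax > 0 else 1.0
    xq = [max(-127, min(127, round(v / dx))) for v in x]
    return xq, dx, sum(x)

def idot_group(q, scale, bias, x):
    """Integer-dot path: scale * dx * <q, xq> + bias * Sx."""
    xq, dx, sx = quantize_activation(x)
    dot = sum(qi * xi for qi, xi in zip(q, xq))  # integer accumulate
    return scale * dx * dot + bias * sx

def reference_group(q, scale, bias, x):
    """Scalar path: dequantize each weight, then multiply-accumulate."""
    return sum((scale * qi + bias) * xi for qi, xi in zip(q, x))

q = [i % 16 for i in range(32)]          # 4-bit weight codes
x = [math.sin(i) for i in range(32)]     # one group of activations
approx = idot_group(q, 0.05, -0.4, x)
exact = reference_group(q, 0.05, -0.4, x)
print(abs(approx - exact))               # small, bounded by the int8 rounding error
```

Only the scale term carries quantization error; the bias term uses the exact activation sum, which is presumably why the prepass keeps it alongside the per-group scale.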

copy.hip: env-gated (MLX_COUNT_COPIES) histogram of general copies by shape,
to attribute copy_gg/copy_g kernels to model ops. No overhead when disabled.

gather_qmv_wide_kernel: flat per-(batch,col) warp with 128-bit (uint4 = 32
nibbles) weight loads instead of the per-group 32-bit loop, for 4-bit affine
MoE-expert decode. Per-group affine applied inline; activation read from global
(L2-cached). Requires K%32==0, group_size in {32,64,128}. Gated MLX_QMV_WIDE,
off by default. Microbench: scattered-expert DRAM read 130->295 GB/s on R9700
(2.3x), 98->150 on gfx1151. Coherent. In-engine TPS-neutral on the 35B (experts
are the smaller matmul share; dense qmv_tiled already uses wide loads + dot4),
kept as a measured, available path.
…% dense

qmv_wide_kernel: one warp per output column, flat full-K loop with 128-bit
(uint4 = 32 nibbles) weight loads, per-group affine applied inline, activation
from global. Replaces qmv_tiled on the 4-bit affine decode path under
MLX_QMV_WIDE. Root cause: qmv_tiled achieves only ~108 GB/s on the real 8MB
per-layer matmul while a simple wide-load streaming kernel sustains ~175 (a
kernel inefficiency, not matmul size or interleaving — verified by rotating-DRAM
microbench). Requires K%32==0, group_size in {32,64,128}.
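
A host-side model of the access pattern described above may help: one flat loop over K, 32 nibbles per 128-bit word, with the per-group scale and bias applied inline. This is a sketch under assumptions, not the kernel itself: `pack_nibbles` and `qmv_wide_column` are illustrative names, and low-nibble-first packing is assumed rather than taken from the source. It also suggests one reading of the K%32==0 and group_size in {32,64,128} requirement: every 32-nibble load then falls entirely inside a single group.

```python
def pack_nibbles(vals):
    """Pack 32 values in [0, 15] into one 128-bit integer (low nibble first, assumed)."""
    assert len(vals) == 32
    word = 0
    for i, v in enumerate(vals):
        word |= (v & 0xF) << (4 * i)
    return word

def qmv_wide_column(words, scales, biases, x, group_size):
    """Dot one quantized weight column with x, one 128-bit word at a time."""
    assert group_size % 32 == 0 and len(x) == 32 * len(words)
    acc = 0.0
    for w_idx, word in enumerate(words):
        base = 32 * w_idx
        g = base // group_size            # the whole word belongs to one group
        s, b = scales[g], biases[g]
        for i in range(32):
            qv = (word >> (4 * i)) & 0xF  # bit-level nibble unpack
            acc += (s * qv + b) * x[base + i]
    return acc

# K = 64, group_size = 32: two words, two groups.
q = [(3 * i) % 16 for i in range(64)]
x = [0.25 * (i % 5) for i in range(64)]
words = [pack_nibbles(q[0:32]), pack_nibbles(q[32:64])]
scales, biases = [0.5, 0.25], [-1.0, 2.0]
print(qmv_wide_column(words, scales, biases, x, 32))
```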

Measured (gfx1151, greedy decode), MLX_QMV_WIDE on vs off:
  dense Qwen3.5-4B:   54.2 -> 61.9 tok/s (+14%)
  dense Qwen3.6-27B:  11.3 -> 13.5 tok/s (+20%)
  35B-A3B MoE:        38.4 -> 40.3 (+5%, dense proj is smaller share)
All coherent. Win scales with dense matmul share. MLX_QMV_WIDE now covers both
this dense path and the gather path.

The dense qmv_wide path is a measured +14-20% win on dense 4-bit decode and
never regresses, so enable it by default for the qualifying decode case (4-bit
affine, K%32==0, group_size in {32,64,128}, M<=8). Both the dense qmv_wide and
the gather_qmv_wide paths now default on; set MLX_QMV_NO_WIDE=1 to revert to
qmv_tiled / the warp-shared gather.

Behind MLX_GRAPH_REPLAY=1 (default off): once the decode topology is locked
(two-consecutive-token key match), buffer a token's kernel params instead of
rebuilding the graph node-by-node, then re-point a drained cached exec
(hipGraphExecKernelNodeSetParams) and relaunch — eliminating ~1000
hipGraphAddKernelNode calls/token. Replay execs are kept in a dedicated pool
(replay_pool_) whose src_nodes are the instantiation graph's nodes, so re-point
is valid (unlike ExecUpdate-managed exec_pool_ slots).

The re-point HIT path works; the first-build/grow self-instantiate path still
hangs in hipGraphInstantiate (WIP). Default (graph + per-token rebuild) path
unchanged.
…pt-in, WIP)

Replaces the earlier per-node ExecKernelNodeSetParams approach. Behind
MLX_GRAPH_REPLAY=1 (default off): once the decode topology locks (two-consecutive
key match), the normal path force-grows the exec pool to N slots
(MLX_GRAPH_REPLAY_SLOTS, default 4), each keeping its OWN persistent source graph.
Replay then re-points a DRAINED slot's source nodes (hipGraphKernelNodeSetParams,
idle so by-pointer safe) + hipGraphExecUpdate + relaunch — no AddKernelNode
rebuild. ExecUpdate is the standard, by-pointer replay path (source graph + kernarg
Packs kept alive per slot).

HIT path works (replays multiple tokens coherently). Remaining WIP: after ~2 HITs
the engine spin-waits on the replayed token's AtomicEvent completion atomic (never
set past 2 tokens) — a completion-signaling interaction with the one-behind
pipeline, not a kernel fault or OOM. Default (per-token rebuild) path unchanged.
… + force-grow

The pool is force-grown to N instantiated slots (MLX_GRAPH_REPLAY_SLOTS, default 4)
via the normal path; these are never ExecUpdate'd, so their src_nodes are the exec's
own instantiation nodes. Replay re-points a drained slot's exec directly with
hipGraphExecKernelNodeSetParams + relaunch (no rebuild). This replaces the
KernelNodeSetParams-on-persistent-source + ExecUpdate path, which returned success
but produced a faulting exec.

Replays are now bit-identical to eager for the first couple tokens, then the GPU
faults (use-after-free / graph-exec-relaunch lifetime interaction — see
GRAPH_DECODE_PLAN.md, completion-gated ownership). Still opt-in (MLX_GRAPH_REPLAY=1,
default off); default per-token-rebuild path unchanged and coherent.

Re-introduce gpu_kv_pos_set / gpu_kv_pos_increment — in-place mutation of a
fixed-address [1] int32 device scalar (the decode position). Keeping the address
fixed and advancing the value in place (between replays, loop-owned) lets a
capture-once HIP graph bake the address and have RoPE/causal-mask/KV-write read
the position device-side, so they advance every pure replay (no SetParams, no
rebuild). Foundation for documented capture-once + pure-relaunch decode; unused
until the capture path is wired.
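
The fixed-address idea can be illustrated on the host. In the toy model below (plain Python, with a one-element list standing in for the [1] int32 device scalar; `capture_step` is a hypothetical name, not an MLX or HIP call), a captured step that baked in the position value repeats itself on replay, while one that baked in the address sees each in-place increment.

```python
def capture_step(pos_cell, bake_value):
    """Return a replayable step; it bakes either the value or the 'address'."""
    if bake_value:
        frozen = pos_cell[0]           # value copied at capture time
        return lambda: frozen
    return lambda: pos_cell[0]         # cell dereferenced at replay time

pos = [0]                              # stands in for the fixed-address scalar
by_value = capture_step(pos, bake_value=True)
by_address = capture_step(pos, bake_value=False)

seen_value, seen_address = [], []
for _ in range(3):
    seen_value.append(by_value())
    seen_address.append(by_address())
    pos[0] += 1                        # in-place increment between replays

print(seen_value, seen_address)        # [0, 0, 0] [0, 1, 2]
```
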
The wide-load dense and gather QMV kernels are default-on (opt out with
MLX_QMV_NO_WIDE); no MLX_QMV_WIDE env is ever read. Correct the two kernel
header comments that still described the old opt-in prototype gating.

Adds an opt-in bump allocator that hands out identical device addresses for an
identical allocation sequence across token resets. During steady-state decode
the op tape is identical every token, so rewinding a fixed backing region to
its base before each token makes allocation #N land at the same address every
token — the precondition for relaunching a graph built once (its nodes' baked
input/output pointers stay valid and outputs land where the next token's eval
expects them).

API: decode_arena_begin/reset/end (+ active/high_water/overflowed), bridged to
mlx::core for the engine loop. malloc/malloc_async serve from the arena when
armed; free no-ops arena buffers (identified by pointer range, so RocmBuffer
stays POD); on overflow the alloc falls back to the normal pool. Fully inert
and byte-identical when disarmed (default).
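
A minimal host-side sketch of the bump-allocator idea, assuming a fixed base address, a fixed capacity, and 256-byte alignment (all three are illustrative choices, as are the class and method names; the real API is the decode_arena_* family named above). It shows the property the commit relies on: after a reset, an identical allocation sequence returns identical addresses, and an allocation that does not fit is reported so the caller can fall back to the normal pool.

```python
class DecodeArena:
    """Toy deterministic bump allocator over a fixed address range."""

    def __init__(self, base, capacity, align=256):
        self.base, self.capacity, self.align = base, capacity, align
        self.offset = 0
        self.high_water = 0
        self.overflowed = False

    def reset(self):
        """Rewind to the base so the next token reuses the same addresses."""
        self.offset = 0

    def alloc(self, size):
        """Return an address, or None on overflow (caller falls back)."""
        start = -(-self.offset // self.align) * self.align  # round up to alignment
        if start + size > self.capacity:
            self.overflowed = True
            return None
        self.offset = start + size
        self.high_water = max(self.high_water, self.offset)
        return self.base + start

    def owns(self, addr):
        """Identify arena buffers by pointer range, so a free can no-op them."""
        return self.base <= addr < self.base + self.capacity

arena = DecodeArena(base=0x10000, capacity=4096)
token_a = [arena.alloc(n) for n in (100, 300, 64)]
arena.reset()
token_b = [arena.alloc(n) for n in (100, 300, 64)]
print(token_a == token_b)  # True
```

Identifying arena memory by address range, as `owns` does here, mirrors the commit's note that the buffer type can stay plain data with no extra ownership flag.
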
Adds a clean build-once decode path that pairs with the deterministic decode
arena. record mode lets each per-token commit instantiate as normal but keeps
the exec + its source graph + kernarg packs in an ordered chain. replay mode
relaunches the next recorded exec verbatim at each commit and discards the
freshly-built nodes — no hipGraphExecKernelNodeSetParams, no hipGraphExecUpdate
(both corrupt by-value-struct kernels on ROCm). Because the arena makes every
per-token buffer address identical, the recorded exec's baked pointers stay
valid every token, so a bit-identical relaunch is correct.

maybe_commit() chunks deterministically (same node sequence → same boundaries),
so the chain aligns one-to-one between record and replay; a chunk-count mismatch
disables pure mode and falls back to the normal build path. Bridged to mlx::core
as decode_pure_record/replay/off/chain_len for the engine loop. Inert by default
(mode 0); the production path is unchanged and verified coherent.
… input

Feeds the freshly-sampled token (on a transient arena address) into the fixed-
address graph-decode input buffer so the build-once decode graph's embedding
gather reads the new token without the buffer being reallocated between
relaunches. Runs on-stream, ordered after the producing forward.

decode_pure_chain_ becomes a 2-slot array (record/replay take a slot) and adds
gpu_buffer_copy (immediate device buffer copy) so the decode loop can move GDN
scratch state into the read buffer between relaunches. With the engine's
decoupled state write-back this makes one-graph build-once decode produce output
bit-matching eager on gfx1151.
…t loss

On the integrated APU eager decode (~65 TPS) beats the rebuild-every-token graph
path (~49 TPS): graph build/instantiate overhead exceeds the launch-batching win
when launches are cheap. Default off; opt in with MLX_USE_HIP_GRAPHS=1 (launch-
bound discrete GPU, or build-once replay). Also adds dormant
decode_pure_relaunch_all for the upcoming relaunch-only replay path.

Pack each expert-run's tokens into the WMMA M dimension (BM=64) so MoE prefill
runs matrix-matrix instead of M=1 gather-QMV, reading each expert's q4 weights
once and reusing across its tokens. Register-tiled (MT4/NT2), 128-bit uint4
weight+activation loads, bit-level nibble unpack.

Indices copied host-side via hipMemcpy (device pointers; reading .data() on the
host corrupted the heap on discrete gfx1201). Gated behind MLX_ROCM_GROUPED_PREFILL
(default off) while kernel throughput + GPU-side grouping are tuned.

1021-tok prefill, Qwen3.6-35B-A3B q4: gfx1151 358->415 tok/s (+16%),
gfx1201 305->400 (+31%). Output coherent.

Replace the per-call host sort + device sync with on-device histogram/scan/scatter
(keys = expert in [0,E)); runs map 1:1 to experts so the GEMM reads run_id as its
expert. Only a 4-byte max_run_len readback remains (to bound grid.x to the largest
run instead of B -> avoids ~98% empty token-tiles). Drops the 2xB host index copy
and the full-pipeline sync. Parity at 417/401 tok/s with cleaner async path; sets
up the occupancy rewrite. Output coherent.
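
The histogram/scan/scatter grouping can be modelled sequentially as a counting sort keyed by expert id. The sketch below is plain Python rather than device code, and `group_tokens_by_expert` is an illustrative name; on the GPU each of the three passes would run in parallel. Because the keys are expert ids in [0, E), run r holds exactly the tokens routed to expert r, which is why the GEMM can read run_id as its expert.

```python
def group_tokens_by_expert(expert_ids, num_experts):
    """Return (counts, offsets, order) such that
    order[offsets[e] : offsets[e] + counts[e]] lists the tokens of expert e."""
    # 1) histogram: how many tokens each expert receives
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1
    # 2) exclusive scan: where each expert's run starts
    offsets, running = [], 0
    for c in counts:
        offsets.append(running)
        running += c
    # 3) stable scatter: place each token index into its expert's run
    cursor = list(offsets)
    order = [0] * len(expert_ids)
    for tok, e in enumerate(expert_ids):
        order[cursor[e]] = tok
        cursor[e] += 1
    return counts, offsets, order

counts, offsets, order = group_tokens_by_expert([2, 0, 2, 1, 0], num_experts=3)
print(counts, offsets, order)  # [2, 1, 2] [0, 2, 3] [1, 4, 3, 0, 2]
```
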
use_hip_graphs() now branches on the engine's decode-region flag: prefill
(multi-token) consults MLX_HIP_GRAPH_PREFILL, decode (single-token) consults
MLX_HIP_GRAPH_DECODE. Legacy MLX_USE_HIP_GRAPHS=1 enables both. Default off.

Lets prefill be graph-captured while decode stays eager (decode graph was a net
loss on the APU). Grouped MoE prefill is made capture-safe: under graphs it uses
the data-independent grid (ceil(B/64)) instead of the host max_run_len readback
(a host sync is illegal inside capture). gfx1151 256-tok prefill 247->265 with
grouped+prefill-graph; output coherent, decode unaffected.
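
The gating described above amounts to a small decision function. The following is a hedged sketch of that logic in Python, not the C++ implementation: the environment variable names come from the commit message, but treating the value "1" as the only enabling value for the two per-phase variables is an assumption, and passing the environment in as a dict is just for testability.

```python
def use_hip_graphs(in_decode_region, env):
    """Decide whether to use HIP graphs for the current phase."""
    if env.get("MLX_USE_HIP_GRAPHS") == "1":   # legacy switch enables both phases
        return True
    name = "MLX_HIP_GRAPH_DECODE" if in_decode_region else "MLX_HIP_GRAPH_PREFILL"
    return env.get(name) == "1"                # default off

env = {"MLX_HIP_GRAPH_PREFILL": "1"}
print(use_hip_graphs(False, env), use_hip_graphs(True, env))  # True False
```

With only the prefill variable set, prefill is graph-captured while decode stays eager, which is the configuration the commit message motivates for the APU.
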
… gather)

Block goes 1->4 warps; the 4 warps share one gathered sX tile (each owns a 32-col
slab of a 128-col block tile), so the activation gather runs once per 4 N-tiles
instead of once per N-tile (~4x less X traffic) and 4 warps are resident per WGP
for latency hiding. The MoE GEMM was running at ~5% of peak (occupancy-starved).
Per-column bounds added for N%128!=0. gfx1151 verified coherent:
2371-tok 468->587 (+25%), 8529-tok 472->611 (+29%). Closes most of the gap to
llama.cpp Vulkan (703 @ 2367 on gfx1151).

Long-context prefill on the R9700 (gfx1201) produced intermittent garbage (~50%
of runs, identical input) — independent of precision (e4m3/bf16 both fail) and
GEMM library (hipBLASLt/rocBLAS both fail). Root-caused to the stream-ordered
async pool: hipMallocAsync re-hands a block whose prior owner's work hasn't
drained. Confirmed it's the pool's reuse, not our free discipline:
synchronizing the stream before every hipFreeAsync does NOT help; only avoiding
hipMallocAsync (MLX_ROCM_NO_ASYNC_POOL) does. The existing code comment already
warns hipMallocAsync faults inside the driver under graph workloads — same root.

Gate the pool off for gfx1201 (falls back to the unified/slab path; the pool
gives ~0% on the APU and only the racy +5% on the R9700). gfx1151 keeps it.
Override with MLX_ROCM_FORCE_ASYNC_POOL=1. dev1 2371-tok prefill now 0/10 garbage
(was ~5/10), coherent at 8529 tok.

Optional override: gfx1201 uses fp8 e4m3 for half-precision activation GEMMs by
default (preferred_gemm_precision). e4m3 trades mantissa for range; this env
forces bf16 for accuracy-sensitive runs. Default off (keeps e4m3 on RDNA4).

Two fixes so HIP graph mode no longer aborts at build:
- Skip true no-op launches (any zero grid/block dim): hipLaunchKernel tolerates
  them but hipGraphAddKernelNode rejects them as invalid-argument.
- When hipGraphAddKernelNode still fails (some single-block reduction/new
  kernels are rejected), graph-split like launch_kernel() instead of aborting:
  flush the accumulated graph, run the kernel eagerly (ordered, same stream),
  continue building fresh. Graph decode now builds + runs coherent on gfx1151
  and gfx1201 (was: hipGraphAddKernelNode invalid argument).

Note: the new decode kernels are not yet graph-addable, so they split the graph
each layer -> decode-graph is currently slower than eager until they're migrated
to add_kernel_node-compatible launches.

The kernel now grid-strides over each expert-run's token-tiles, so grid.x is a
bounded occupancy knob instead of a correctness requirement. This removes the
max_run_len readback, which needed a full pipeline-draining enc.synchronize()
per gather_qmm call (~2x/layer x N layers) and serialized prefill — the main
drag on small-input PP/s. grid.x = ~4x the average run's tiles (env
MLX_GRIDX_MULT): enough parallelism for routing imbalance, far fewer empty
blocks than ceil(B/64), no sync. Coherent on both archs through 8529 tok.
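
Why grid-striding turns grid.x into a tuning knob can be seen with a tiny model. In the Python sketch below (illustrative names, and a single run of tiles rather than the per-expert runs the kernel handles), block b processes tiles b, b + grid_x, b + 2*grid_x, and so on, so every tile is visited exactly once whatever grid_x is chosen.

```python
def tiles_for_block(block_idx, grid_x, num_tiles):
    """Tiles handled by one block under a grid-stride loop."""
    return list(range(block_idx, num_tiles, grid_x))

def covered_tiles(grid_x, num_tiles):
    """All tiles visited by a launch of grid_x blocks, in sorted order."""
    visited = []
    for b in range(grid_x):
        visited.extend(tiles_for_block(b, grid_x, num_tiles))
    return sorted(visited)

# Too few blocks: each block loops. Too many: the extra blocks do nothing.
print(tiles_for_block(1, 4, 10))                  # [1, 5, 9]
print(covered_tiles(16, 10) == list(range(10)))   # True
```
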

Qwen3.6-35B-A3B q4 prefill (grouped), vs prior readback path:
  gfx1151: 298t 300->341, 2371t 601->668, 8529t 610->701 (~= llama.cpp 706)
  gfx1201: 8529t 918 (coherent; was garbage pre pool-fix)

QuantizedMatmul prefill (the dense q4/q6/q8 affine projections) can now run as a
single fused WMMA GEMM via add_kernel_node instead of dequant+rocBLAS. Unpacks
packed weights in-kernel (128-bit loads, bit-level; q6 cross-32b-boundary) — no
bf16 weight materialization — so it (a) is a single graph node (no rocBLAS graph
split, the prerequisite for whole-graph prefill replay) and (b) avoids the dequant
memory traffic. 4 warps/block share the gathered x tile; grid-strides over M.

Gated: MLX_ROCM_WMMA_QMM=1, or auto-on under HIP graphs; eager default keeps
rocBLAS (MLX_ROCM_WMMA_QMM=0 to force off). q4 verified coherent + on par with
rocBLAS eager (341 vs 340 @298) and clean in graph mode; q6/q8 follow the
validated standalone unpack (no q6/q8 model on hand to run).

Generalizes the decode build-once/replay machinery to prefill chunks: prefill_key_
+ pending_prefill_key_ lock the stable full-chunk topology, use_execupdate +
replay activation extended to prefill (MLX_GRAPH_PREFILL_REPLAY=1), plus an
anti-hang drain in the replay slot-search (prefill's coarse pipeline can miss a
drained slot; sync instead of instantiate-during-replay which hangs on ROCm).

INCOMPLETE: prefill chunks still split into many fragments (conv via gemm_conv,
etc.), and the single-key replay assumes ~1 fragment per chunk — so enabling it
currently spins over the per-fragment instantiate. Needs conv as a graph node
(WMMA QMM already removed the matmul splits) to reach ~1 fragment. Gated off;
default path unaffected.
…ill, gated)

The manual hipGraphAddKernelNode rejects some of our kernels with invalid-argument
(WMMA QMM, conv unfold/naive_gemm) — they were graph-splitting and shattering the
prefill chunk into ~200 fragments, which made single-key replay spin. Mirror
CUDA's capture model: when the node API rejects a kernel (add_kernel_node_kp) or
for library launches (launch_kernel), wrap the launch in hipStreamBeginCapture/
EndCapture and fold the captured child graph in via add_child_graph_node. Gated to
prefill replay mode (MLX_GRAPH_PREFILL_REPLAY=1); eager + plain graph mode keep
the split path.

Result: prefill chunk goes from ~208 fragments to ~1 — coherent, 620 tok/s @2371
(capture, no replay). This is the prerequisite for whole-chunk replay; the replay
machinery itself (single-key build-skip) still needs work and remains gated off.
…ExecUpdate

Build-skip replay (replay_active_, buffer params/no launch) is incompatible with
the prefill capture path (capture needs the launch), so it's now decode-only.
Prefill replay instead re-captures each chunk + ExecUpdate's a cached exec
(cheaper than re-instantiate). Runs correctly + coherent, no spin (was: spin from
the build-skip/capture conflict). Graph prefill is whole (~1 fragment) but still
~5-8% under eager (per-chunk capture/build cost); a true build-once cross-prompt
replay needs captured nodes to support build-skip — the remaining architectural
piece. All gated behind MLX_GRAPH_PREFILL_REPLAY (default off).

Capture the entire single-token decode forward into one HIP graph (graphs
off so every kernel launches on the stream), instantiate once, and relaunch
that exec verbatim each token. Output is bit-identical to eager.

- CommandEncoder::decode_capture_begin/end_record/replay/destroy: whole-step
  stream capture + instantiate + relaunch.
- worker: skip hipLaunchHostFunc while capturing (not capturable, and a
  captured host node would re-fire on every replay); signal inline instead.
- decode arena: back with fine-grained device memory via rocm_unified_malloc
  instead of hipMallocManaged (managed added GPU page-migration cost to every
  relaunch); add a replay floor so per-token sampling allocates above the
  recorded exec's baked buffers.
- decode_inline_launch_count() bridge + captured-graph node-type debug.

Gated by the engine's MLX_DECODE_GRAPH_PURE; default decode path unchanged.

decode_capture_end_record/replay take a slot (0/1) and the encoder keeps one
captured exec per parity, so the engine can record two graphs (read [0] write
[2]; read [2] write [0]) and relaunch by token parity with no per-token state
copy. Single-slot behavior is unchanged (slot 0).
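
The parity scheme can be sketched with two buffers and two pre-recorded steps. This toy Python model is only an illustration: the buffer keys 0 and 2 follow the indices in the commit message, while `step`, `decode_with_parity`, and `decode_with_copy` are hypothetical stand-ins for the captured forward pass and the engine loop. Graph 0 reads buffer 0 and writes buffer 2, graph 1 does the reverse, and alternating by token parity gives the same result as copying state back after every token.

```python
def decode_with_parity(step, state, num_tokens):
    """Ping-pong between two fixed buffers; no per-token state copy."""
    buffers = {0: state, 2: None}
    graphs = [(0, 2), (2, 0)]            # (read, write) per parity slot
    for t in range(num_tokens):
        src, dst = graphs[t % 2]
        buffers[dst] = step(buffers[src])
    return buffers[graphs[(num_tokens - 1) % 2][1]] if num_tokens else state

def decode_with_copy(step, state, num_tokens):
    """Single-slot reference: write to scratch, then copy back each token."""
    for _ in range(num_tokens):
        scratch = step(state)
        state = scratch                  # the copy the parity scheme avoids
    return state

step = lambda s: 3 * s + 1
print(decode_with_parity(step, 1, 5) == decode_with_copy(step, 1, 5))  # True
```
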

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add ROCm Support for AMD GPUs