Low-level control in Jax with shard_map and Pallas

2025-01-17T14:25

Notes from the OpenXLA talk "JAX: Low-level control with shard_map and Pallas". See the references for the full video.

shard_map

"A souped-up parallel map" that parallelizes functions on multiple devices, with each device receiving a different shard from the input data.

"shmapped"

Custom Kernels

triton_call

Calls Triton kernels from JAX.
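A rough sketch of the jax-triton pattern: a Triton kernel wrapped with jt.triton_call. The add kernel, block size, and the assumption that the array length divides evenly by the block size are illustrative, not from the talk.

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    y = tl.load(y_ptr + offsets)
    tl.store(out_ptr + offsets, x + y)

def add(x, y, block_size=128):
    # Assumes x.size is a multiple of block_size.
    return jt.triton_call(
        x, y,
        kernel=add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.size // block_size,),
        BLOCK_SIZE=block_size)

x = jnp.arange(1024, dtype=jnp.float32)
print(add(x, x))
```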

pallas_call

Recommended. Expresses kernels in JAX itself.

Custom kernels take in GPU array references (buffers), not arrays. Values must be loaded from the refs before computations can run. Refs support mutation, unlike the rest of JAX!
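A minimal pallas_call sketch showing this ref-based convention (the add-one kernel and shapes are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Arguments are refs (buffers), not arrays: load, compute, then store.
    x = x_ref[...]          # load the whole block from the input ref
    o_ref[...] = x + 1.0    # refs can be mutated, unlike ordinary JAX values

@jax.jit
def add_one(x):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)

print(add_one(jnp.zeros((8, 128), jnp.float32)))
```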

Pallas concepts like grids and block specs allow us to automatically pipeline memory access on TPU and run across multiple asynchronous threads on GPU.

Because matmul can be implemented recursively, we can break a big matmul down into smaller ones. The smaller kernels can then be utilized more effectively on the hardware.

Grid: specifies how many times we execute the kernel and which instance of the kernel is executing. On GPU, the grid is executed asynchronously in parallel; on TPU, it is executed sequentially and pipelined.

Block shape: how we break down the inputs and outputs into smaller blocks to be operated on by the kernel. Pallas will automatically carve up arrays into the right block shape.

Index map: for a particular kernel instance in the grid, which blocks of the inputs and outputs it should operate on. The blocked matmul sketch below ties all three concepts together.
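A blocked matmul sketch using a grid, block shapes, and index maps. The block size (128), shapes, and the even-divisibility assumption are illustrative, and it assumes a recent JAX where BlockSpec takes (block_shape, index_map):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
    # One grid instance computes one (block, block) tile of the output.
    o_ref[...] = x_ref[...] @ y_ref[...]

def matmul(x, y, block=128):
    m, k = x.shape
    _, n = y.shape
    return pl.pallas_call(
        matmul_kernel,
        grid=(m // block, n // block),  # how many kernel instances to run
        in_specs=[
            # block shape + index map: instance (i, j) reads row-block i of x ...
            pl.BlockSpec((block, k), lambda i, j: (i, 0)),
            # ... and column-block j of y
            pl.BlockSpec((k, block), lambda i, j: (0, j)),
        ],
        # ... and writes output block (i, j)
        out_specs=pl.BlockSpec((block, block), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
    )(x, y)

x = jnp.ones((256, 512), jnp.float32)
y = jnp.ones((512, 256), jnp.float32)
print(matmul(x, y).shape)  # (256, 256)
```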

Why Pallas?

Is there a programming-model difference between Pallas and Triton?

Main difference: Pallas is higher level with respect to memory access. You declare up front what memory you're going to use, which allows for automatic scheduling.

Pallas does not have autotuning like Triton does.


References