shard_map and Pallas
2025-01-17T14:25
Notes from the OpenXLA talk "JAX: Low-level control with shard_map and Pallas". See references for the full video.
"A souped-up parallel map" that parallelizes functions on multiple devices, with each device receiving a different shard from the input data.
"shmapped"
triton_call
Calls Triton kernels from JAX (provided by the jax-triton library).
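A minimal sketch of what this looks like, loosely following the jax-triton README (the add kernel and block size are illustrative, and it requires a GPU plus the `triton` and `jax_triton` packages):

```python
import triton
import triton.language as tl
import jax
import jax.numpy as jnp
import jax_triton as jt

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, block_size: tl.constexpr):
    # Standard Triton: each program instance handles one block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * block_size + tl.arange(0, block_size)
    tl.store(out_ptr + offsets, tl.load(x_ptr + offsets) + tl.load(y_ptr + offsets))

def add(x, y):
    # Assumes x.size is a multiple of block_size.
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    block_size = 8
    return jt.triton_call(
        x, y,
        kernel=add_kernel,
        out_shape=out_shape,
        grid=(x.size // block_size,),
        block_size=block_size,
    )

print(add(jnp.arange(8.0), jnp.arange(8.0)))
```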
pallas_call
Recommended. Expresses kernels in JAX.
Custom kernels take in references to GPU buffers (Refs), not arrays. Values must be loaded from the Refs before computations can run. Refs support mutation, unlike the rest of JAX!
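A minimal sketch to make the Ref idea concrete (the add kernel itself is illustrative, not from the talk):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # The arguments are mutable buffer references (Refs), not arrays.
    x = x_ref[...]        # load values out of the input buffers
    y = y_ref[...]
    o_ref[...] = x + y    # store (mutate!) into the output buffer

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

print(add(jnp.arange(8.0), jnp.arange(8.0)))
```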
Pallas concepts like grids and block specs let Pallas automatically pipeline memory accesses on TPU and run the kernel across multiple asynchronous threads on GPU.
Because matmul can be implemented recursively (blockwise), we can break a big matmul down into many smaller ones. The small kernels can then be utilized more effectively on the hardware (see the sketch after the index-map note below).
Grid: specifies how many times we execute the kernel, and identifies which instance of the kernel is currently executing. On GPU, the grid is executed asynchronously in parallel; on TPU, it is executed sequentially and pipelined.
Block shape: how we break the inputs and outputs down into smaller blocks for the kernel to operate on. Pallas will automatically carve arrays up into the right block shape.
Index map: for a particular kernel instance in the grid, which blocks it gets as inputs and outputs.
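Putting grid, block shape, and index map together, a rough sketch of a blocked matmul (assumes the newer `pl.BlockSpec(block_shape, index_map)` API, block sizes that evenly divide the matrix dimensions, and leaves the K dimension unblocked for simplicity):

```python
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
    # Each instance sees one (block_m, K) block of x and one (K, block_n)
    # block of y, and writes one (block_m, block_n) block of the output.
    o_ref[...] = x_ref[...] @ y_ref[...]

@functools.partial(jax.jit, static_argnames=("block_m", "block_n"))
def matmul(x, y, *, block_m=128, block_n=128):
    m, k = x.shape
    _, n = y.shape
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        # Grid: one kernel instance per (i, j) output tile.
        grid=(m // block_m, n // block_n),
        in_specs=[
            # Instance (i, j) reads row-block i of x (all of K)...
            pl.BlockSpec((block_m, k), lambda i, j: (i, 0)),
            # ...and column-block j of y.
            pl.BlockSpec((k, block_n), lambda i, j: (0, j)),
        ],
        # Instance (i, j) writes output tile (i, j).
        out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
    )(x, y)

x = jnp.ones((512, 256), jnp.float32)
y = jnp.ones((256, 512), jnp.float32)
print(matmul(x, y).shape)  # (512, 512)
```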
Main difference from Triton: Pallas is higher level with respect to memory access. You declare up front what memory you're going to use, which allows for automatic scheduling.
Pallas does not have autotuning like Triton does.