How to use PyMC with MLX

How does one activate PyTensor/PyMC to use the new MLX functionality?

pm.sample(compile_kwargs=dict(mode="MLX"))

Note that many operations are still missing, and MLX seems comparatively slow for small datasets/models compared to the other backends.


Let us know if any ops you need are missing, or, if not, how it went.

Will do.

MLX is running on the GPU. Amazing. I tried a large model with 242,380 observations for fun.

It is slower than nutpie but faster than the native sampler: nutpie took ~1 h 10 min, the PyMC native sampler ~3 h 55 min, and MLX ~1 h 51 min.
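Converting the reported wall-clock times to minutes gives the relative speedups directly. This is a small sketch; the dictionary keys are only illustrative labels, and the numbers are the approximate times reported above.

```python
# Approximate wall-clock times from the thread, converted to minutes.
timings_min = {
    "nutpie": 1 * 60 + 10,       # ~1 h 10 min
    "pymc_native": 3 * 60 + 55,  # ~3 h 55 min
    "pymc_mlx": 1 * 60 + 51,     # ~1 h 51 min
}

# Speedup of each sampler relative to the PyMC native sampler.
baseline = timings_min["pymc_native"]
speedups = {name: baseline / minutes for name, minutes in timings_min.items()}

for name, minutes in timings_min.items():
    print(f"{name:12s} {minutes:4d} min  {speedups[name]:.2f}x vs. native sampler")
```

With these numbers, MLX is about 2.1x faster than the native sampler (235 / 111). Nutpie is about 3.4x faster (235 / 70), which makes it roughly 1.6x faster than MLX.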


Nice! Thanks for testing it out. Hopefully we can speed that up even more; this is just the first rough cut. Did you have any heavy matmuls in the model, or was it mostly basic element-wise operations?

Heavy matmul, I think:
