nixpkgs/pkgs/development/python-modules/jax/test-cuda.nix
Samuel Ainsworth 2f4ac6f803 python3Packages.jax: add CUDA tests in passthru
Organize CUDA-enabled tests for jax and jaxlib into passthru scripts to facilitate testing.
2024-02-28 02:50:45 +00:00

{ jax
, jaxlib
, pkgs
}:
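# Builds a `jax-test-cuda` executable that checks that JAX sees a GPU and
# can run a small matrix multiplication on it.
pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; } ''
import jax
from jax import random

# Fail fast if JAX did not pick up a CUDA device.
assert jax.devices()[0].platform == "gpu"

# Run a small matmul; JAX dispatches asynchronously, so block until the
# result has actually been computed on the device.
rng = random.PRNGKey(0)
x = random.normal(rng, (100, 100))
(x @ x).block_until_ready()
print("success!")
''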