google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Stars: 20278 | Watchers: 289 | Forks: 1861 | Issues: 1103

jax Recent Issues

Issue title | State | Comments | Created | Updated
Hard to find documentation of predefined jax checkpoint policies | open | 0 | 2022-09-19 | 2022-09-20
TPU VM stops working after stop / start | open | 0 | 2022-09-19 | 2022-09-20
Running pytest locally -- against AttributeError: num_generated_cases | open | 2 | 2022-09-19 | 2022-09-20
unexpected nan with jax.vjp | open | 4 | 2022-09-18 | 2022-09-20
`jax.numpy.nan_to_num` function appears twice in Documentation | open | 0 | 2022-09-18 | 2022-09-20
Feature Request bessel function in scipy ( J0 J1 J2 Y0 Y1 Y2) which have applied in autograd. | open | 1 | 2022-09-17 | 2022-09-20
single function which subsumes different vmapped functions | closed | 2 | 2022-09-16 | 2022-09-20
scipy.optimize not found | closed | 1 | 2022-09-16 | 2022-09-11
jax.numpy.average can't take tuple axis | open | 1 | 2022-09-16 | 2022-09-20
Problem mixing checkify.check and jax.lax.switch | open | 2 | 2022-09-16 | 2022-09-20
TraceAnnotation not showing inside jax.jit | open | 6 | 2022-09-15 | 2022-09-20
Remove device_buffer and device_buffer properties from Array | open | 0 | 2022-09-15 | 2022-09-20
Add the ability to pass an `is_leaf` argument to `tree_reduce`. | open | 2 | 2022-09-15 | 2022-09-20
implicit uint64 array creation doesn't work like numpy | closed | 2 | 2022-09-15 | 2022-09-20
JAX Sets Basic Logging Config | closed | 2 | 2022-09-15 | 2022-09-20
Test documentation fails | open | 2 | 2022-09-14 | 2022-09-11
jax.random.randint silently ignores dtype of int64 instead of generating warning | open | 1 | 2022-09-14 | 2022-09-20
lax.scatter inconsistent results between CPU and GPU | closed | 2 | 2022-09-14 | 2022-09-20
Allow pjit to close over values sharded on multiple devices | open | 0 | 2022-09-14 | 2022-09-20
Unexpected CPU memory allocation when running on GPU with torch | open | 3 | 2022-09-13 | 2022-09-20
float0 should support addition, subtraction, and scalar multiplication | open | 1 | 2022-09-13 | 2022-09-20
jax.config.update('jax_platforms', 'gpu') fails | closed | 2 | 2022-09-12 | 2022-09-20
`jnp.squeeze` will fail unexpectedly with `jit` compilation | closed | 2 | 2022-09-12 | 2022-09-20
Large discrepancy in gradients between jitted/non-jitted code | closed | 2 | 2022-09-12 | 2022-09-20
[TPU] TypeError with JAX 0.3.17 in Google Collab | open | 1 | 2022-09-12 | 2022-09-20
Possible bug in `jax.lax.index_take` | closed | 1 | 2022-09-10 | 2022-09-20
Custom PytreeNode constructors were given object() when used with vmap. | closed | 3 | 2022-09-10 | 2022-09-20
jax[cuda] installation replaces current jax version with old jax-0.2.22 version | closed | 6 | 2022-09-09 | 2022-09-20
checkify does not work in pytree constructors | open | 1 | 2022-09-09 | 2022-09-20
Misleading error message in broadcast | closed | 1 | 2022-09-09 | 2022-09-20
Different outputs every time fft2 is called | open | 11 | 2022-09-08 | 2022-09-11
Add C++ Array support to pmap | open | 0 | 2022-09-08 | 2022-09-11
Add support to handle device and backend argument to lower_sharding_computation | open | 0 | 2022-09-08 | 2022-09-27
issubclass(jnp.float32, jnp.floating) does not work | closed | 5 | 2022-09-08 | 2022-09-27
Move GPU computation dispatch into a separate thread. | open | 0 | 2022-09-08 | 2022-09-27
jet of dynamic slice does not match jet of gather | open | 1 | 2022-09-07 | 2022-09-06
`jax.pure_callback` crashes on TPU VM | open | 1 | 2022-09-07 | 2022-09-06
Suggestion: jax.distributed.initialize should return an error/warning if JAX backend was already initialized | open | 0 | 2022-09-07 | 2022-09-06
Pickling a JAX array does not preserve its original device | closed | 6 | 2022-09-07 | 2022-09-15
Multiplying Nan by False give 0. instead of NaN | open | 5 | 2022-09-06 | 2022-09-06
Implementation of `scipy.signal.savgol_filter` (or at least `scipy.signal.savgol_coeffs`). | open | 2 | 2022-09-04 | 2022-09-06
Testable jit cache performance | open | 1 | 2022-09-02 | 2022-09-05
slice-based indexing is slow for repeated indexing | open | 1 | 2022-09-01 | 2022-09-11
More concise naming for tree utilities | open | 1 | 2022-09-01 | 2022-09-02
pure_callback passes jax.DeviceArray to the callback on CPU when not jitted | closed | 2 | 2022-09-01 | 2022-09-02
Deleting class instance doesn't free memory used by jitted method | open | 0 | 2022-08-31 | 2022-09-02
Build failure for jax v0.3.17 on Windows 10 | open | 6 | 2022-08-31 | 2022-09-02
jax.jnp gives NaN values | closed | 5 | 2022-08-31 | 2022-09-02
cannot import name 'stax' from 'jax.experimental' (in Jupyter notebook) | closed | 1 | 2022-08-30 | 2022-08-31
Gradient leakage through masked convolutions | open | 3 | 2022-08-30 | 2022-08-31
