Experimental Support for GPUs/TPUs
The current development branch dev/jax implements
experimental support for GPUs/TPUs.
Although OQuPy is built on top of the backend-agnostic TensorNetwork library,
it uses vanilla NumPy and SciPy throughout its implementation.
The dev/jax branch adds support for GPUs/TPUs via the JAX library.
A new module, oqupy.backends.numerical_backend, handles the breaking changes in JAX NumPy,
while the rest of the modules use the numpy and scipy.linalg instances it provides,
without explicitly importing any JAX-based libraries.
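As a rough illustration of this design, the sketch below shows how a single module can choose between NumPy/SciPy and JAX and hand out `np` and `la` objects to the rest of a code base. The function `load_backend` is hypothetical and is not OQuPy's implementation. Only the module names `numpy`, `scipy.linalg`, `jax.numpy` and `jax.scipy.linalg` are real; reading the `OQUPY_BACKEND` environment variable inside this function is an assumption made for illustration.

```python
import importlib
import os


def load_backend(name=None):
    """Return (np, la) modules for the requested backend.

    Hypothetical sketch: if no name is given, the backend is taken from the
    OQUPY_BACKEND environment variable, falling back to vanilla NumPy/SciPy.
    """
    name = name or os.environ.get("OQUPY_BACKEND", "numpy")
    if name == "jax":
        # JAX's NumPy and SciPy-linalg namespaces (requires jax to be installed).
        np_module = importlib.import_module("jax.numpy")
        la_module = importlib.import_module("jax.scipy.linalg")
    else:
        np_module = importlib.import_module("numpy")
        la_module = importlib.import_module("scipy.linalg")
    return np_module, la_module


# Other modules would import np and la from this one place instead of
# importing numpy or scipy.linalg directly.
np, la = load_backend("numpy")
identity = np.eye(3)
inverse = la.inv(identity)
```

Keeping the choice in one place means switching backends only changes what this function returns, not the import lines of every other module.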
Enabling Experimental Features
To enable experimental features, switch to the dev/jax branch and use
from oqupy.backends import enable_jax_features
enable_jax_features()
Alternatively, the OQUPY_BACKEND environment variable may be set to jax to initialize the JAX backend by default.
Contributing Guidelines
To contribute features compatible with the JAX backend, please adhere to the following set of guidelines:
- Avoid wildcard imports of NumPy and SciPy.
- Use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np`, and use the alias `default_np` in cases where vanilla NumPy is explicitly required.
- Use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except for non-symmetric eigendecomposition, for which `scipy.linalg.eig` should be used.
- Use `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`).
- Convert lists or tuples to arrays when passing them as arguments inside functions.
- Use `array = np.update(array, indices, values)` instead of `array[indices] = values`.
- Use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`.
- Declare signatures for `np.vectorize` explicitly.
- Avoid directly changing the `shape` attribute of an array (use `.reshape` instead).
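To make a few of these guidelines concrete, the following sketch uses vanilla NumPy so that it runs without JAX. The `update` function is a hypothetical stand-in for the behaviour described for `np.update`, not OQuPy's implementation: it returns a modified copy instead of mutating its input, which is the style JAX's immutable arrays require (in JAX itself the equivalent is `array.at[indices].set(values)`). The `trace_times` helper is likewise only an example.

```python
import numpy as np  # vanilla NumPy stands in for the backend module here


def update(array, indices, values):
    """Hypothetical functional update: return a modified copy of `array`.

    With JAX arrays the equivalent would be array.at[indices].set(values).
    """
    new_array = np.array(array, copy=True)
    new_array[indices] = values
    return new_array


# Functional update instead of `a[indices] = values`; `a` is left unchanged.
a = np.zeros(4)
b = update(a, np.array([1, 3]), 1.0)


# Explicit signature for np.vectorize: each call maps one (n, n) matrix and
# one scalar to one scalar, so the function is applied per matrix, not per entry.
def trace_times(matrix, scale):
    return np.trace(matrix) * scale


vtrace = np.vectorize(trace_times, signature="(n,n),()->()")
stack = np.stack([np.eye(2), 2 * np.eye(2)])
traces = vtrace(stack, 0.5)

# Use .reshape (which returns a new array) rather than assigning to .shape.
flat = np.arange(6)
matrix = flat.reshape(2, 3)

# Convert lists or tuples to arrays before passing them on.
dims = np.array([2, 3])
total = np.prod(dims)
```

Code written in this functional style stays valid under both backends, since NumPy permits but does not require in-place mutation.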