jax.devices
- jax.devices(backend=None)
Returns a list of all devices for a given backend.
Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.process_index to the value returned by jax.process_index().
If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.
- Parameters:
backend (str | xla_client.Client | None) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
- Return type:
list[xla_client.Device]
- Returns:
List of Device subclass instances.
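A minimal sketch of the behaviour described above, assuming a standard JAX install and a single-process run (so every device is local). It lists the devices on the default backend, checks that the list length matches jax.device_count(), and filters local devices by comparing Device.process_index with jax.process_index(). The variable names are illustrative only.

```python
import jax

# All devices on the default backend ('gpu' or 'tpu' if available, else 'cpu').
devices = jax.devices()

# The returned list has the same length as jax.device_count() for that backend.
assert len(devices) == jax.device_count()

# Local devices are those whose process_index equals this process's index.
local = [d for d in devices if d.process_index == jax.process_index()]

for d in local:
    print(d.id, d.platform, d.process_index)
```

In a multi-process job, `local` holds only this process's devices; jax.local_devices() offers the same filtering as a built-in convenience.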
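A sketch of passing the backend argument explicitly, assuming a standard JAX install where the 'cpu' backend is always present. The helper `devices_or_empty` is hypothetical, and it assumes that requesting a backend with no matching devices raises RuntimeError; since the parameter is marked experimental, treat that behaviour as version-dependent.

```python
import jax
import jax.numpy as jnp

# Explicitly request the CPU backend, independent of the default backend.
cpu_devices = jax.devices("cpu")
print([d.platform for d in cpu_devices])

# Place an array on the first CPU device.
x = jax.device_put(jnp.arange(4), cpu_devices[0])


def devices_or_empty(name):
    """Hypothetical helper: return devices for `name`, or [] if that backend is absent.

    Assumption: jax.devices raises RuntimeError for an unavailable backend.
    """
    try:
        return jax.devices(name)
    except RuntimeError:
        return []


gpus = devices_or_empty("gpu")
print(f"{len(gpus)} GPU device(s) found")
```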