torch.cuda
This package adds support for CUDA tensor types, which implement the same functions as CPU tensors but utilize GPUs for computation.
It is lazily initialized, so you can always import it and use is_available() to determine if your system supports CUDA.
See CUDA semantics for more details about working with CUDA.
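For instance, a minimal sketch of the lazy-initialization pattern; the CUDA branch only runs when a device is present:

>>> import torch
>>> if torch.cuda.is_available():        # safe to call on CPU-only machines
...     x = torch.randn(2, 2).cuda()     # allocate on the current GPU
...     print(torch.cuda.current_device())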
class torch.cuda.device(device)
Context-manager that changes the selected device.
Parameters: device (torch.device or int) – device index to select. It's a no-op if this argument is a negative integer or None.
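For instance, a sketch assuming at least two GPUs; the allocation inside the block lands on device 1, and the previous device is restored on exit:

>>> import torch
>>> if torch.cuda.device_count() >= 2:
...     with torch.cuda.device(1):
...         y = torch.randn(3, device='cuda')  # allocated on GPU 1
...     # outside the block, the previously selected device is active again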
torch.cuda.device_ctx_manager
Alias of torch.cuda.device.
class torch.cuda.device_of(obj)
Context-manager that changes the current device to that of the given object.
You can use both tensors and storages as arguments. If a given object is not allocated on a GPU, this is a no-op.
Parameters: obj (Tensor or Storage) – object allocated on the selected device.
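For instance, a sketch assuming a CUDA device is available; the new tensor lands on the same device as x:

>>> import torch
>>> if torch.cuda.is_available():
...     x = torch.randn(2, device='cuda')
...     with torch.cuda.device_of(x):
...         y = torch.zeros(2, device='cuda')  # same device as x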
torch.cuda.empty_cache()
Releases all unoccupied cached memory currently held by the caching allocator so that it can be used by other GPU applications and becomes visible in nvidia-smi.
Note: empty_cache() doesn't increase the amount of GPU memory available for PyTorch. See Memory management for more details about GPU memory management.
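A sketch of how the cache shrinks after freeing tensors (assumes a CUDA device); freed memory returns to the cache, not the OS, until empty_cache() is called:

>>> import torch
>>> if torch.cuda.is_available():
...     x = torch.empty(1024, 1024, device='cuda')
...     del x                              # memory returns to the cache
...     before = torch.cuda.memory_cached()
...     torch.cuda.empty_cache()           # release cached blocks to the driver
...     after = torch.cuda.memory_cached() # typically <= before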
torch.cuda.get_device_capability(device)
Gets the CUDA capability of a device.
Parameters: device (torch.device or int, optional) – device for which to return the device capability. This function is a no-op if this argument is a negative integer. Uses the current device, given by current_device(), if device is None (default).
Returns: the major and minor CUDA capability of the device
Return type: tuple(int, int)
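For instance, a sketch assuming at least one device:

>>> import torch
>>> if torch.cuda.is_available():
...     major, minor = torch.cuda.get_device_capability(0)
...     print('compute capability %d.%d' % (major, minor))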
torch.cuda.get_device_name(device)
Gets the name of a device.
Parameters: device (torch.device or int, optional) – device for which to return the name. This function is a no-op if this argument is a negative integer. Uses the current device, given by current_device(), if device is None (default).
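For instance, a sketch assuming at least one device:

>>> import torch
>>> if torch.cuda.is_available():
...     print(torch.cuda.get_device_name(0))  # e.g. the device's model string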
torch.cuda.init()
Initialize PyTorch's CUDA state. You may need to call this explicitly if you are interacting with PyTorch via its C API, as Python bindings for CUDA functionality will not be available until this initialization takes place. Ordinary users should not need this, as all of PyTorch's CUDA methods automatically initialize CUDA state on demand.
Does nothing if the CUDA state is already initialized.
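A minimal sketch; the call is idempotent and rarely needed from Python:

>>> import torch
>>> if torch.cuda.is_available():
...     torch.cuda.init()  # force eager initialization of the CUDA state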
torch.cuda.max_memory_allocated(device=None)
Returns the maximum GPU memory usage by tensors in bytes for a given device.
Parameters: device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
Note: See Memory management for more details about GPU memory management.
torch.cuda.max_memory_cached(device=None)
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
Parameters: device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
Note: See Memory management for more details about GPU memory management.
torch.cuda.memory_allocated(device=None)
Returns the current GPU memory usage by tensors in bytes for a given device.
Parameters: device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
Note: This is likely less than the amount shown in nvidia-smi, since some unused memory can be held by the caching allocator and some context needs to be created on the GPU. See Memory management for more details about GPU memory management.
torch.cuda.memory_cached(device=None)
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
Parameters: device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
Note: See Memory management for more details about GPU memory management.
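A sketch of inspecting all four allocator statistics together (assumes a CUDA device):

>>> import torch
>>> if torch.cuda.is_available():
...     x = torch.randn(1024, 1024, device='cuda')
...     torch.cuda.memory_allocated()      # bytes currently used by tensors
...     torch.cuda.max_memory_allocated()  # peak tensor usage so far
...     torch.cuda.memory_cached()         # bytes held by the caching allocator
...     torch.cuda.max_memory_cached()     # peak cache size so far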
torch.cuda.set_device(device)
Sets the current device.
Usage of this function is discouraged in favor of device. In most cases it's better to use the CUDA_VISIBLE_DEVICES environment variable.
Parameters: device (torch.device or int) – selected device. This function is a no-op if this argument is negative.
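A sketch contrasting the discouraged call with the preferred, scoped context manager (assumes two GPUs):

>>> import torch
>>> if torch.cuda.device_count() >= 2:
...     torch.cuda.set_device(1)        # discouraged: global side effect
...     with torch.cuda.device(0):      # preferred: scoped and self-restoring
...         pass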
torch.cuda.stream(stream)
Context-manager that selects a given stream.
All CUDA kernels queued within its context will be enqueued on the selected stream.
Parameters: stream (Stream) – selected stream. This manager is a no-op if it's None.
Note: Streams are per-device, and this function changes the "current stream" only for the currently selected device. It is illegal to select a stream that belongs to a different device.
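For instance, a sketch assuming a CUDA device; the kernel inside the block is enqueued on s rather than the default stream:

>>> import torch
>>> if torch.cuda.is_available():
...     s = torch.cuda.Stream()          # new stream on the current device
...     with torch.cuda.stream(s):
...         y = torch.randn(4, device='cuda')  # enqueued on s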
torch.cuda.synchronize()
Waits for all kernels in all streams on the current device to complete.
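Because kernel launches are asynchronous, synchronize() is needed for accurate host-side timing. A sketch assuming a CUDA device:

>>> import time
>>> import torch
>>> if torch.cuda.is_available():
...     a = torch.randn(1000, 1000, device='cuda')
...     start = time.time()
...     b = a @ a                 # launched asynchronously
...     torch.cuda.synchronize()  # wait for the kernel to finish
...     elapsed = time.time() - start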
Random Number Generator
torch.cuda.get_rng_state(device=-1)
Returns the random number generator state of the current GPU as a ByteTensor.
Parameters: device (int, optional) – The device to return the RNG state of. Default: -1 (i.e., use the current device).
Warning: This function eagerly initializes CUDA.
torch.cuda.set_rng_state(new_state, device=-1)
Sets the random number generator state of the current GPU.
Parameters: new_state (torch.ByteTensor) – The desired state.
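A sketch of checkpointing and rewinding the GPU RNG state (assumes a CUDA device); b reproduces a:

>>> import torch
>>> if torch.cuda.is_available():
...     state = torch.cuda.get_rng_state()   # snapshot
...     a = torch.randn(3, device='cuda')
...     torch.cuda.set_rng_state(state)      # rewind
...     b = torch.randn(3, device='cuda')    # equal to a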
torch.cuda.manual_seed(seed)
Sets the seed for generating random numbers for the current GPU. It's safe to call this function if CUDA is not available; in that case, it is silently ignored.
Parameters: seed (int) – The desired seed.
Warning: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all().
torch.cuda.manual_seed_all(seed)
Sets the seed for generating random numbers on all GPUs. It's safe to call this function if CUDA is not available; in that case, it is silently ignored.
Parameters: seed (int) – The desired seed.
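A common sketch for reproducibility across CPU and all GPUs; both calls are safe on CPU-only machines:

>>> import torch
>>> torch.manual_seed(42)           # CPU RNG
>>> torch.cuda.manual_seed_all(42)  # every GPU's RNG; no-op without CUDA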
torch.cuda.seed()
Sets the seed for generating random numbers to a random number for the current GPU. It's safe to call this function if CUDA is not available; in that case, it is silently ignored.
Warning: If you are working with a multi-GPU model, this function will only initialize the seed on one GPU. To initialize all GPUs, use seed_all().
Communication collectives
torch.cuda.comm.broadcast(tensor, devices)
Broadcasts a tensor to a number of GPUs.
Parameters:
- tensor (Tensor) – tensor to broadcast.
- devices (Iterable) – an iterable of devices among which to broadcast. Note that it should be like (src, dst1, dst2, …), the first element of which is the source device to broadcast from.
Returns: A tuple containing copies of the tensor, placed on devices corresponding to indices from devices.
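A sketch assuming two GPUs, with device 0 as the source:

>>> import torch
>>> import torch.cuda.comm
>>> if torch.cuda.device_count() >= 2:
...     t = torch.randn(4, device='cuda:0')
...     copies = torch.cuda.comm.broadcast(t, [0, 1])  # (copy on GPU 0, copy on GPU 1)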
torch.cuda.comm.broadcast_coalesced(tensors, devices, buffer_size=10485760)
Broadcasts a sequence of tensors to the specified GPUs. Small tensors are first coalesced into a buffer to reduce the number of synchronizations.
Parameters:
- tensors (sequence) – tensors to broadcast.
- devices (Iterable) – an iterable of devices among which to broadcast. Note that it should be like (src, dst1, dst2, …), the first element of which is the source device to broadcast from.
- buffer_size (int) – maximum size of the buffer used for coalescing.
Returns: A tuple containing copies of the tensors, placed on devices corresponding to indices from devices.
torch.cuda.comm.reduce_add(inputs, destination=None)
Sums tensors from multiple GPUs.
All inputs should have matching shapes.
Parameters:
- inputs (Iterable[Tensor]) – an iterable of tensors to add.
- destination (int, optional) – a device on which the output will be placed (default: current device).
Returns: A tensor containing an elementwise sum of all inputs, placed on the destination device.
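A sketch assuming two GPUs holding same-shaped tensors:

>>> import torch
>>> import torch.cuda.comm
>>> if torch.cuda.device_count() >= 2:
...     xs = [torch.ones(4, device='cuda:0'), torch.ones(4, device='cuda:1')]
...     total = torch.cuda.comm.reduce_add(xs, destination=0)  # tensor of 2s on GPU 0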
torch.cuda.comm.scatter(tensor, devices, chunk_sizes=None, dim=0, streams=None)
Scatters tensor across multiple GPUs.
Parameters:
- tensor (Tensor) – tensor to scatter.
- devices (Iterable[int]) – iterable of ints, specifying among which devices the tensor should be scattered.
- chunk_sizes (Iterable[int], optional) – sizes of chunks to be placed on each device. It should match devices in length and sum to tensor.size(dim). If not specified, the tensor will be divided into equal chunks.
- dim (int, optional) – A dimension along which to chunk the tensor.
Returns: A tuple containing chunks of the tensor, spread across given devices.
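A sketch assuming two GPUs; the first dimension is split into equal halves:

>>> import torch
>>> import torch.cuda.comm
>>> if torch.cuda.device_count() >= 2:
...     t = torch.randn(8, 3)
...     chunks = torch.cuda.comm.scatter(t, [0, 1])  # a (4, 3) chunk on each GPU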
Streams and events
class torch.cuda.Stream
Wrapper around a CUDA stream.
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See CUDA semantics for details.
Parameters:
- device (torch.device or int, optional) – a device on which to allocate the stream. If device is None (default) or a negative integer, this will use the current device.
- priority (int, optional) – priority of the stream. Lower numbers represent higher priorities.
query()
Checks if all the work submitted has been completed.
Returns: A boolean indicating if all kernels in this stream are completed.
record_event(event=None)
Records an event.
Parameters: event (Event, optional) – event to record. If not given, a new one will be allocated.
Returns: Recorded event.
synchronize()
Waits for all the kernels in this stream to complete.
Note: This is a wrapper around cudaStreamSynchronize(): see CUDA documentation for more info.
wait_event(event)
Makes all future work submitted to the stream wait for an event.
Parameters: event (Event) – an event to wait for.
Note: This is a wrapper around cudaStreamWaitEvent(): see CUDA documentation for more info. This function returns without waiting for event: only future operations are affected.
wait_stream(stream)
Synchronizes with another stream.
All future work submitted to this stream will wait until all kernels submitted to a given stream at the time of call complete.
Parameters: stream (Stream) – a stream to synchronize.
Note: This function returns without waiting for currently enqueued kernels in stream: only future operations are affected.
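A sketch of ordering work across two streams with wait_stream and blocking the host with synchronize (assumes a CUDA device):

>>> import torch
>>> if torch.cuda.is_available():
...     s1 = torch.cuda.Stream()
...     s2 = torch.cuda.Stream()
...     with torch.cuda.stream(s1):
...         a = torch.randn(1000, 1000, device='cuda')
...     s2.wait_stream(s1)   # s2 work runs after s1's currently queued kernels
...     with torch.cuda.stream(s2):
...         b = a @ a
...     s2.synchronize()     # block the host until s2 finishes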
Memory management
See empty_cache(), memory_allocated(), max_memory_allocated(), memory_cached(), and max_memory_cached(), documented above.
NVIDIA Tools Extension (NVTX)
torch.cuda.nvtx.mark(msg)
Describe an instantaneous event that occurred at some point.
Parameters: msg (string) – ASCII message to associate with the event.
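A sketch of dropping an instantaneous marker into a profile, visible in tools such as nvprof or the NVIDIA Visual Profiler (assumes a CUDA device):

>>> import torch
>>> if torch.cuda.is_available():
...     torch.cuda.nvtx.mark('start of eval step')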