# Source code for torch.nn.modules.rnn

import math
import torch
import warnings
import itertools
import numbers

from .module import Module
from ..parameter import Parameter
from ..utils.rnn import PackedSequence
from .. import init

# Low-level functional ops exposed by ``torch._C``; both the multi-layer RNN
# functions and the single-step cell functions used in this module come from it.
_VF = torch._C._VariableFunctions

# Dispatch table: ``RNNBase.mode`` -> the function ``RNNBase.forward`` calls.
_rnn_impls = dict(
    LSTM=_VF.lstm,
    GRU=_VF.gru,
    RNN_TANH=_VF.rnn_tanh,
    RNN_RELU=_VF.rnn_relu,
)


class RNNBase(Module):
    r"""Shared implementation behind the multi-layer RNN, LSTM and GRU modules.

    For every layer and direction this registers ``weight_ih_l{k}[_reverse]``
    and ``weight_hh_l{k}[_reverse]`` (plus ``bias_ih_l{k}[_reverse]`` and
    ``bias_hh_l{k}[_reverse]`` when ``bias`` is true), and :meth:`forward`
    dispatches to the function registered for ``mode`` in ``_rnn_impls``.

    Args:
        mode: ``'LSTM'``, ``'GRU'``, ``'RNN_TANH'`` or ``'RNN_RELU'``. Sets how
            many gate blocks are stacked in each weight (4, 3, 1 and 1
            respectively) and which function :meth:`forward` calls.
        input_size: number of features in each input element.
        hidden_size: number of features in the hidden state.
        num_layers: number of stacked recurrent layers. Default: 1
        bias: whether bias parameters are registered and used. Default: ``True``
        batch_first: if ``True``, an unpacked input is laid out as
            ``(batch, seq, feature)`` instead of ``(seq, batch, feature)``.
            Default: ``False``
        dropout: probability in ``[0, 1]`` handed to the kernel as the
            between-layer dropout. Default: 0
        bidirectional: if ``True``, every layer also gets a ``_reverse`` set of
            weights. Default: ``False``

    Raises:
        ValueError: if ``dropout`` is not a number in ``[0, 1]`` (booleans and
            unordered numbers such as ``complex`` are rejected) or if ``mode``
            is not one of the recognised modes.
    """

    def __init__(self, mode, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0, bidirectional=False):
        super(RNNBase, self).__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1

        # ``bool`` is a ``numbers.Number`` subclass but is rejected explicitly.
        # Some Numbers (e.g. ``complex``) are unordered, so the range comparison
        # would raise TypeError; report those with the same ValueError instead.
        if isinstance(dropout, bool) or not isinstance(dropout, numbers.Number):
            dropout_is_valid = False
        else:
            try:
                dropout_is_valid = 0 <= dropout <= 1
            except TypeError:
                dropout_is_valid = False
        if not dropout_is_valid:
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0 and num_layers == 1:
            warnings.warn("dropout option adds dropout after all but last "
                          "recurrent layer, so non-zero dropout expects "
                          "num_layers greater than 1, but got dropout={} and "
                          "num_layers={}".format(dropout, num_layers))

        # Each weight stacks one ``hidden_size``-row block per gate.
        if mode == 'LSTM':
            gate_size = 4 * hidden_size
        elif mode == 'GRU':
            gate_size = 3 * hidden_size
        elif mode == 'RNN_TANH':
            gate_size = hidden_size
        elif mode == 'RNN_RELU':
            gate_size = hidden_size
        else:
            # format() instead of ``+`` so a non-str mode still gives ValueError.
            raise ValueError("Unrecognized RNN mode: {}".format(mode))

        # Per (layer, direction): the parameter NAMES in the order they are
        # handed to the kernel, i.e. w_ih, w_hh[, b_ih, b_hh].
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                # Layers above the first consume the concatenated outputs of
                # every direction of the layer below.
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
                w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
                b_ih = Parameter(torch.Tensor(gate_size))
                b_hh = Parameter(torch.Tensor(gate_size))
                layer_params = (w_ih, w_hh, b_ih, b_hh)

                suffix = '_reverse' if direction == 1 else ''
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]

                # Without bias only two names exist, so zip() stops after the
                # weights and the bias tensors are never registered.
                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._all_weights.append(param_names)

        self.flatten_parameters()
        self.reset_parameters()

    def flatten_parameters(self):
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this works only if the module is on the GPU and cuDNN is enabled.
        Otherwise, it's a no-op.
        """
        any_param = next(self.parameters()).data
        if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):
            return

        # If any parameters alias, we fall back to the slower, copying code path. This is
        # a sufficient check, because overlapping parameter buffers that don't completely
        # alias would break the assumptions of the uniqueness check in
        # Module.named_parameters().
        all_weights = self._flat_weights
        unique_data_ptrs = set(p.data_ptr() for p in all_weights)
        if len(unique_data_ptrs) != len(all_weights):
            return

        with torch.cuda.device_of(any_param):
            import torch.backends.cudnn.rnn as rnn

            # NB: This is a temporary hack while we still don't have Tensor
            # bindings for ATen functions
            with torch.no_grad():
                # NB: this is an INPLACE function on all_weights, that's why the
                # no_grad() is necessary.
                torch._cudnn_rnn_flatten_weight(
                    all_weights, (4 if self.bias else 2),
                    self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size, self.num_layers,
                    self.batch_first, bool(self.bidirectional))

    def _apply(self, fn):
        """Apply ``fn`` through ``Module._apply``, then re-flatten the weights.

        ``fn`` may replace parameter storage, so the flattened layout set up by
        :meth:`flatten_parameters` has to be rebuilt afterwards.
        """
        ret = super(RNNBase, self)._apply(fn)
        self.flatten_parameters()
        return ret

    def reset_parameters(self):
        """Initialise every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def check_forward_args(self, input, hidden, batch_sizes):
        """Validate ``input`` and ``hidden`` against the module configuration.

        Args:
            input: the (possibly packed-data) input tensor. It must be 2-D when
                ``batch_sizes`` is given and 3-D otherwise.
            hidden: the hidden-state tensor, or the ``(h, c)`` pair for LSTM.
            batch_sizes: the ``PackedSequence`` batch sizes, or ``None``.

        Raises:
            RuntimeError: on a wrong input rank, feature size, or hidden shape.
        """
        is_input_packed = batch_sizes is not None
        expected_input_dim = 2 if is_input_packed else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

        # For packed input the first batch size is the largest one.
        if is_input_packed:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)

        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)

        def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
            if tuple(hx.size()) != expected_hidden_size:
                raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

        if self.mode == 'LSTM':
            check_hidden_size(hidden[0], expected_hidden_size,
                              'Expected hidden[0] size {}, got {}')
            check_hidden_size(hidden[1], expected_hidden_size,
                              'Expected hidden[1] size {}, got {}')
        else:
            check_hidden_size(hidden, expected_hidden_size)

    def forward(self, input, hx=None):
        """Run the function registered for ``self.mode`` over ``input``.

        Args:
            input: a tensor, or a ``PackedSequence`` whose output is re-packed.
            hx: initial hidden state (an ``(h_0, c_0)`` pair for LSTM). When
                ``None``, zeros of shape
                ``(num_layers * num_directions, batch, hidden_size)`` are used.

        Returns:
            ``(output, hidden)``. For LSTM ``hidden`` is the tuple of the
            remaining kernel results; otherwise it is a single tensor.
        """
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = input.new_zeros(self.num_layers * num_directions,
                                 max_batch_size, self.hidden_size,
                                 requires_grad=False)
            if self.mode == 'LSTM':
                hx = (hx, hx)

        self.check_forward_args(input, hx, batch_sizes)
        _impl = _rnn_impls[self.mode]
        # Packed and padded inputs use different argument orders.
        if batch_sizes is None:
            result = _impl(input, hx, self._flat_weights, self.bias, self.num_layers,
                           self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _impl(input, batch_sizes, hx, self._flat_weights, self.bias,
                           self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1:] if self.mode == 'LSTM' else result[1]

        if is_packed:
            output = PackedSequence(output, batch_sizes)
        return output, hidden

    def extra_repr(self):
        """Summarise the constructor arguments that differ from their defaults."""
        s = '{input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        return s.format(**self.__dict__)

    def __setstate__(self, d):
        """Restore pickled state and normalise ``_all_weights`` to lists of names.

        State that stores ``all_weights`` (rather than ``_all_weights``), or
        whose ``_all_weights`` entries are not strings, is rebuilt from the
        layer/direction naming scheme used in ``__init__``.
        """
        super(RNNBase, self).__setstate__(d)
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        if isinstance(self._all_weights[0][0], str):
            return
        num_layers = self.num_layers
        num_directions = 2 if self.bidirectional else 1
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(layer, suffix) for x in weights]
                if self.bias:
                    self._all_weights += [weights]
                else:
                    self._all_weights += [weights[:2]]

    @property
    def _flat_weights(self):
        """All RNN weights as one flat list, in the order the kernels expect.

        The list is built from the names in ``_all_weights``, not from
        ``self._parameters``. Extra parameters registered on the module, or
        a changed insertion order, therefore cannot leak into the list passed
        to the kernel.
        """
        return [w for layer_weights in self.all_weights for w in layer_weights]

    @property
    def all_weights(self):
        """Per (layer, direction) lists of the current weight tensors."""
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]


class RNN(RNNBase):
    r"""Applies a multi-layer Elman RNN with :math:`tanh` or :math:`ReLU` non-linearity to an
    input sequence.


    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        h_t = \text{tanh}(w_{ih} x_t + b_{ih} + w_{hh} h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    same layer at time `t-1` or the initial hidden state at time `0`.
    If :attr:`nonlinearity` is `'relu'`, then `ReLU` is used instead of `tanh`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two RNNs together to form a `stacked RNN`,
            with the second RNN taking in outputs of the first RNN and
            computing the final results. Default: 1
        nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'.
            May be given by keyword or as the fourth positional argument.
            Default: 'tanh'
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)`. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            RNN layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence. The input can also be a packed variable length
          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
          or :func:`torch.nn.utils.rnn.pack_sequence`
          for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.

    Outputs: output, h_n
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features (`h_t`) from the last layer of the RNN,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has
          been given as the input, the output will also be a packed sequence.

          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.
          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

    Attributes:
        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
            of shape `(hidden_size x input_size)` for `k = 0`. Otherwise, the shape is
            `(hidden_size x num_directions * hidden_size)`
        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
            of shape `(hidden_size x hidden_size)`
        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
            of shape `(hidden_size)`
        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
            of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.RNN(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    def __init__(self, *args, **kwargs):
        # ``nonlinearity`` is documented as the 4th argument, but the 4th
        # positional slot of RNNBase (after mode) is ``bias``. A string there
        # was never a meaningful bias, so treat it as the nonlinearity; any
        # other value (e.g. a positional ``False``) still goes to ``bias``.
        if len(args) > 3 and isinstance(args[3], str):
            if 'nonlinearity' in kwargs:
                raise TypeError(
                    "__init__() got multiple values for argument 'nonlinearity'")
            kwargs['nonlinearity'] = args[3]
            args = args[:3] + args[4:]

        nonlinearity = kwargs.pop('nonlinearity', 'tanh')
        if nonlinearity == 'tanh':
            mode = 'RNN_TANH'
        elif nonlinearity == 'relu':
            mode = 'RNN_RELU'
        else:
            raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity))
        super(RNN, self).__init__(mode, *args, **kwargs)
class LSTM(RNNBase):
    r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
    sequence.


    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
            c_t = f_t \odot c_{(t-1)} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
    state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{(t-1)}`
    is the hidden state of the layer at time `t-1` or the initial hidden
    state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
    :math:`o_t` are the input, forget, cell, and output gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the
    element-wise (Hadamard) product.

    In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)`. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence.
          The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          If the LSTM is bidirectional, num_directions should be 2, else it should be 1.
        - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial cell state for each element in the batch.

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.

    Outputs: output, (h_n, c_n)
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features `(h_t)` from the last layer of the LSTM,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.

          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.
          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*.
        - **c_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the cell state for `t = seq_len`.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size x input_size)` for `k = 0`.
            Otherwise, the shape is `(4*hidden_size x num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size x hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded unchanged; only the mode is fixed.
        super(LSTM, self).__init__('LSTM', *args, **kwargs)
class GRU(RNNBase):
    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.


    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t (W_{hn} h_{(t-1)}+ b_{hn})) \\
            h_t = (1 - z_t) n_t + z_t h_{(t-1)}
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
    :math:`\sigma` is the sigmoid function.

    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two GRUs together to form a `stacked GRU`,
            with the second GRU taking in outputs of the first GRU and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)`. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            GRU layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence. The input can also be a packed variable length
          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          Defaults to zero if not provided. If the GRU is bidirectional,
          num_directions should be 2, else it should be 1.

    Outputs: output, h_n
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features `h_t` from the last layer of the GRU,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.
          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.

          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            (W_ir|W_iz|W_in), of shape `(3*hidden_size x input_size)` for `k = 0`.
            Otherwise, the shape is `(3*hidden_size x num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            (W_hr|W_hz|W_hn), of shape `(3*hidden_size x hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.GRU(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded unchanged; only the mode is fixed.
        super(GRU, self).__init__('GRU', *args, **kwargs)
class RNNCellBase(Module):
    r"""Parameter storage and input checks shared by the single-step cells.

    Every cell owns ``weight_ih`` of shape ``(num_chunks * hidden_size, input_size)``
    and ``weight_hh`` of shape ``(num_chunks * hidden_size, hidden_size)``. When
    ``bias`` is true it also owns ``bias_ih`` and ``bias_hh`` of length
    ``num_chunks * hidden_size``; otherwise both names are registered as ``None``.

    Args:
        input_size: number of features in the input.
        hidden_size: number of features in the hidden state.
        bias: whether the bias parameters are created.
        num_chunks: number of gate blocks stacked along the first weight dimension.
    """

    def __init__(self, input_size, hidden_size, bias, num_chunks):
        super(RNNCellBase, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        stacked_rows = num_chunks * hidden_size
        self.weight_ih = Parameter(torch.Tensor(stacked_rows, input_size))
        self.weight_hh = Parameter(torch.Tensor(stacked_rows, hidden_size))
        if not bias:
            # Register explicit ``None`` slots so the attributes always exist.
            for missing in ('bias_ih', 'bias_hh'):
                self.register_parameter(missing, None)
        else:
            self.bias_ih = Parameter(torch.Tensor(stacked_rows))
            self.bias_hh = Parameter(torch.Tensor(stacked_rows))
        self.reset_parameters()

    def extra_repr(self):
        """Summarise sizes, plus ``bias`` / ``nonlinearity`` when non-default."""
        attrs = self.__dict__
        template = '{input_size}, {hidden_size}'
        if 'bias' in attrs and self.bias is not True:
            template += ', bias={bias}'
        if 'nonlinearity' in attrs and self.nonlinearity != "tanh":
            template += ', nonlinearity={nonlinearity}'
        return template.format(**attrs)

    def check_forward_input(self, input):
        """Raise RuntimeError unless ``input.size(1)`` equals ``input_size``."""
        got = input.size(1)
        if got != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    got, self.input_size))

    def check_forward_hidden(self, input, hx, hidden_label=''):
        """Raise RuntimeError if ``hx`` disagrees with ``input`` on batch size
        or with the cell on ``hidden_size``.

        ``hidden_label`` is spliced into the messages (e.g. ``'[0]'``).
        """
        input_batch, hidden_batch = input.size(0), hx.size(0)
        if input_batch != hidden_batch:
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
                    input_batch, hidden_label, hidden_batch))

        got_hidden = hx.size(1)
        if got_hidden != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, got_hidden, self.hidden_size))

    def reset_parameters(self):
        """Initialise every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            init.uniform_(param, -bound, bound)
class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.

    .. math::

        h' = \tanh(w_{ih} x + b_{ih}  +  w_{hh} h + b_{hh})

    If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'.
            Any other value is only rejected (with ``RuntimeError``) when the
            cell is called. Default: 'tanh'

    Inputs: input, hidden
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
          Defaults to zero if not provided.

    Outputs: h'
        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(hidden_size x input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(hidden_size x hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Examples::

        >>> rnn = nn.RNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
                hx = rnn(input[i], hx)
                output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
        super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
        self.nonlinearity = nonlinearity

    def forward(self, input, hx=None):
        """Advance one step; ``hx`` defaults to zeros of shape (batch, hidden_size)."""
        self.check_forward_input(input)
        if hx is None:
            batch = input.size(0)
            hx = input.new_zeros(batch, self.hidden_size, requires_grad=False)
        self.check_forward_hidden(input, hx)

        # Pick the step function only after the shape checks have passed.
        if self.nonlinearity == "tanh":
            step = _VF.rnn_tanh_cell
        elif self.nonlinearity == "relu":
            step = _VF.rnn_relu_cell
        else:
            raise RuntimeError(
                "Unknown nonlinearity: {}".format(self.nonlinearity))
        params = (self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
        return step(input, hx, *params)
class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    .. math::

        \begin{array}{ll}
        i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
        c' = f * c + i * g \\
        h' = o \tanh(c') \\
        \end{array}

    where :math:`\sigma` is the sigmoid function.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and
            `b_hh`. Default: ``True``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **h_0** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
        - **c_0** of shape `(batch, hidden_size)`: tensor containing the initial cell state
          for each element in the batch.

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.

    Outputs: h_1, c_1
        - **h_1** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch
        - **c_1** of shape `(batch, hidden_size)`: tensor containing the next cell state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(4*hidden_size x input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(4*hidden_size x hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Examples::

        >>> rnn = nn.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
                hx, cx = rnn(input[i], (hx, cx))
                output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)

    def forward(self, input, hx=None):
        """Advance one step; ``hx`` is the ``(h, c)`` pair, zeros when omitted."""
        self.check_forward_input(input)
        if hx is None:
            zeros = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
            hx = (zeros, zeros)
        # Validate h (index 0) and then c (index 1) with labelled messages.
        for idx in (0, 1):
            self.check_forward_hidden(input, hx[idx], '[{}]'.format(idx))
        params = (self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
        return _VF.lstm_cell(input, hx, *params)
class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell

    .. math::

        \begin{array}{ll}
        r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
        z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
        n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
        h' = (1 - z) * n + z * h
        \end{array}

    where :math:`\sigma` is the sigmoid function.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and
            `b_hh`. Default: ``True``

    Inputs: input, hidden
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
          Defaults to zero if not provided.

    Outputs: h'
        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(3*hidden_size x input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(3*hidden_size x hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Examples::

        >>> rnn = nn.GRUCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
                hx = rnn(input[i], hx)
                output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)

    def forward(self, input, hx=None):
        """Advance one step; ``hx`` defaults to zeros of shape (batch, hidden_size)."""
        self.check_forward_input(input)
        if hx is None:
            batch = input.size(0)
            hx = input.new_zeros(batch, self.hidden_size, requires_grad=False)
        self.check_forward_hidden(input, hx)
        params = (self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
        return _VF.gru_cell(input, hx, *params)

# Docs: Access comprehensive developer documentation for PyTorch (View Docs)
# Tutorials: Get in-depth tutorials for beginners and advanced developers (View Tutorials)
# Resources: Find development resources and get your questions answered (View Resources)