chainer.functions.tree_lstm

chainer.functions.tree_lstm(*inputs)

TreeLSTM unit as an activation function.
This function implements TreeLSTM units for both N-ary TreeLSTM and Child-Sum TreeLSTM. Let the children's cell states be $c_1, c_2, \ldots, c_N$, and let the incoming signal be $x$.
First, the incoming signal $x$ is split along the second axis into (3 + N) arrays $a, i, o, f_1, f_2, \ldots, f_N$ of the same shape. This means that the second axis of $x$ must be (3 + N) times as long as that of each $c_n$.
The split input signals correspond to:
- $a$: sources of the cell input
- $i$: sources of the input gate
- $o$: sources of the output gate
- $f_n$: sources of the forget gate for the $n$-th child
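The split step can be sketched in NumPy as follows. This is an illustration, not Chainer's code: the helper name `split_tree_lstm_input` is hypothetical, and it assumes the (3 + N) chunks are contiguous along the second axis and ordered a, i, o, f_1, ..., f_N, as described. It also enforces the length rule on the second axis of x.

```python
import numpy as np


def split_tree_lstm_input(x, n_children):
    """Split x of shape (batch, (3 + N) * units) into a, i, o and [f_1, ..., f_N].

    Hypothetical helper; assumes contiguous chunks in the order a, i, o, f_1..f_N.
    """
    n_split = 3 + n_children
    # The second axis must be exactly (3 + N) times the cell size.
    if x.shape[1] % n_split != 0:
        raise ValueError(
            "second axis of x must be (3 + N) times the size of each cell")
    # np.split cuts axis 1 into n_split equal, contiguous pieces.
    chunks = np.split(x, n_split, axis=1)
    a, i, o = chunks[:3]
    fs = chunks[3:]
    return a, i, o, fs


# Usage: a binary TreeLSTM (N = 2) with 3 units per cell needs 5 * 3 = 15 columns.
x = np.arange(30, dtype=np.float32).reshape(2, 15)
a, i, o, fs = split_tree_lstm_input(x, n_children=2)
```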
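Second, it computes the outputs as:

$$
\begin{aligned}
c &= \tanh(a)\,\mathrm{sigmoid}(i) + c_1\,\mathrm{sigmoid}(f_1) + c_2\,\mathrm{sigmoid}(f_2) + \cdots + c_N\,\mathrm{sigmoid}(f_N), \\
h &= \tanh(c)\,\mathrm{sigmoid}(o),
\end{aligned}
$$

where all products are element-wise. These are returned as a tuple of two variables, $c$ and $h$.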
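A NumPy reference for these two formulas might look like the sketch below. It is not Chainer's implementation: `tree_lstm_reference` is a hypothetical name, it takes the cell states first and x last (matching the call order `F.tree_lstm(c1, c2, x)`), and it assumes the contiguous a, i, o, f_1, ..., f_N layout along the second axis.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def tree_lstm_reference(*inputs):
    """Hypothetical NumPy sketch of the TreeLSTM forward pass.

    inputs: c_1, ..., c_N, x, where each c_n has shape (batch, units)
    and x has shape (batch, (3 + N) * units).
    """
    *cs, x = inputs
    n_children = len(cs)
    # Assumed layout: contiguous chunks a, i, o, f_1, ..., f_N along axis 1.
    a, i, o, *fs = np.split(x, 3 + n_children, axis=1)
    # c = tanh(a) * sigmoid(i) + sum_n c_n * sigmoid(f_n)   (element-wise)
    c = np.tanh(a) * sigmoid(i)
    for c_n, f_n in zip(cs, fs):
        c = c + c_n * sigmoid(f_n)
    # h = tanh(c) * sigmoid(o)
    h = np.tanh(c) * sigmoid(o)
    return c, h


# Usage: binary TreeLSTM, batch of 4, 10 units per cell.
rng = np.random.default_rng(0)
c1 = rng.uniform(-1, 1, (4, 10))
c2 = rng.uniform(-1, 1, (4, 10))
x = rng.uniform(-1, 1, (4, 50))
c, h = tree_lstm_reference(c1, c2, x)
```

A quick sanity check of the formulas: with x set to zeros, tanh(a) = 0 and every sigmoid equals 0.5, so c reduces to the average of the children's cells and h to 0.5 * tanh(c).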
- Parameters
inputs (list of Variable) – Variable arguments that include all cell vectors from the child nodes, followed by an input vector (for example, c1, c2, x). Each of the cell vectors and the input vector is a Variable or an N-dimensional array. The size of the input vector's second dimension must be (N + 3) times that of each cell, where N denotes the total number of cells.
- Returns
Two Variable objects, c and h. c is the updated cell state, and h is the outgoing signal.
- Return type
tuple
See the papers for details: "Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks" (Tai et al.) and "A Fast Unified Model for Parsing and Sentence Understanding" (Bowman et al.).
Tai et al.'s N-ary TreeLSTM is slightly extended in Bowman et al., and this function is based on Bowman et al.'s variant. Specifically, eq. 10 in Tai et al. applies a single W matrix to x, shared by all children. In contrast, Bowman et al.'s model has multiple matrices, each of which affects the forget gate of one child's cell individually.
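The difference can be illustrated by isolating only the contribution of the node's raw input (called y here, as in the Example) to each child's forget-gate pre-activation. This is a simplified sketch: both helper names are hypothetical, biases are omitted, and the U^(f) h terms through which Tai et al.'s forget gates still differ per child are deliberately left out.

```python
import numpy as np


def forget_x_terms_shared(y, W_f, n_children):
    """Tai et al. (eq. 10) style: every child's forget gate gets the same W_f y."""
    return [y @ W_f for _ in range(n_children)]


def forget_x_terms_per_child(y, W_fs):
    """Bowman et al. style: child n's forget gate gets its own W_fs[n] y."""
    return [y @ W_f for W_f in W_fs]


rng = np.random.default_rng(0)
batch, in_size, units, n_children = 3, 6, 4, 2
# Row-vector convention: y has shape (batch, in_size), so "W y" becomes y @ W.
y = rng.standard_normal((batch, in_size))

shared = forget_x_terms_shared(
    y, rng.standard_normal((in_size, units)), n_children)
per_child = forget_x_terms_per_child(
    y, [rng.standard_normal((in_size, units)) for _ in range(n_children)])
```

With a shared matrix, the y-terms are identical for every child; with per-child matrices they generally differ. When x is produced by a linear layer, as in the Example, the f_n chunks of its output come from different rows of that layer's weight matrix, so each child's forget gate effectively gets its own matrix.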
Example
Assume that y is the current input signal, c1 and c2 are the children's cell states, and h1 and h2 are the children's output signals, each produced by an earlier tree_lstm() call. Each of y, c1, c2, h1 and h2 has n_units channels (n_units = 10 here). Using a 2-ary (binary) TreeLSTM, the most typical preparation of x is:

>>> import numpy as np
>>> import chainer
>>> import chainer.functions as F
>>> import chainer.links as L
>>> # Output size 5 * 10 = (3 + N) * n_units for a binary TreeLSTM (N = 2).
>>> model = chainer.Chain()
>>> with model.init_scope():
...     model.w = L.Linear(10, 5 * 10)
...     model.v1 = L.Linear(10, 5 * 10)
...     model.v2 = L.Linear(10, 5 * 10)
>>> y = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> h1 = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> h2 = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> c1 = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> c2 = np.random.uniform(-1, 1, (4, 10)).astype(np.float32)
>>> x = model.w(y) + model.v1(h1) + model.v2(h2)
>>> c, h = F.tree_lstm(c1, c2, x)
This corresponds to calculating the input sources a, i, o, f1 and f2 from the current input y and the children's outputs h1 and h2. Different parameters are used for different kinds of input sources: model.w for y, and model.v1 and model.v2 for h1 and h2.
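Because the three linear maps in the Example are simply summed, the preparation of x is equivalent to one affine map applied to the concatenation [y, h1, h2], whose weight matrix is [W V1 V2] and whose bias is the sum of the three biases. The NumPy sketch below checks this with random stand-in weights; W, V1, V2 and the biases are hypothetical placeholders for the parameters of model.w, model.v1 and model.v2 (chainer.links.Linear stores W with shape (out_size, in_size) and has a bias by default). Each column block of the combined matrix still acts on only one input source.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_units = 4, 10
out = 5 * n_units  # (3 + N) * n_units with N = 2

y, h1, h2 = (rng.uniform(-1, 1, (batch, n_units)) for _ in range(3))

# Stand-in parameters, laid out like Linear: W has shape (out_size, in_size).
W, V1, V2 = (rng.standard_normal((out, n_units)) for _ in range(3))
b_w, b_v1, b_v2 = (rng.standard_normal(out) for _ in range(3))

# Sum of three separate affine maps, as in the Example.
x_sum = (y @ W.T + b_w) + (h1 @ V1.T + b_v1) + (h2 @ V2.T + b_v2)

# One affine map on the concatenated input [y, h1, h2].
W_big = np.concatenate([W, V1, V2], axis=1)  # shape (out, 3 * n_units)
b_big = b_w + b_v1 + b_v2
x_concat = np.concatenate([y, h1, h2], axis=1) @ W_big.T + b_big
```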