chainer.functions.sigmoid_cross_entropy
chainer.functions.sigmoid_cross_entropy(x, t, normalize=True, reduce='mean')

Computes cross entropy loss for pre-sigmoid activations.
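For reference, here is a sketch of the quantity being computed, using the standard definition of binary cross entropy on logits. The symbols $\ell_{ij}$, $L$ and $N$ are introduced here for illustration and do not come from the Chainer source. With $\sigma(x) = 1/(1 + e^{-x})$ and a label $t_{ij} \in \{0, 1\}$, the per-element loss simplifies as follows:

$$
\ell_{ij} = -\bigl[t_{ij}\log\sigma(x_{ij}) + (1 - t_{ij})\log\bigl(1 - \sigma(x_{ij})\bigr)\bigr] = \log\bigl(1 + e^{x_{ij}}\bigr) - t_{ij}\,x_{ij}
$$

$$
L = \frac{1}{N}\sum_{(i,j):\,t_{ij} \neq -1} \ell_{ij},
\qquad
N = \begin{cases} \bigl|\{(i,j) : t_{ij} \neq -1\}\bigr| & \text{if normalize=True} \\ \text{batch size} & \text{if normalize=False} \end{cases}
$$

This reading of reduce='mean' is inferred from the parameter descriptions and agrees with the example values. With reduce='no', the array of $\ell_{ij}$ is returned as is, with ignored entries set to 0.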
Parameters

x (Variable or N-dimensional array) – A variable object holding a matrix whose (i, j)-th element indicates the unnormalized log probability of the j-th unit at the i-th example.

t (Variable or N-dimensional array) – A variable object holding a matrix of signed integer ground truth labels, whose (i, j)-th element is 0 or 1. If t[i, j] == -1, the corresponding x[i, j] is ignored. The loss is zero if all ground truth labels are -1.

normalize (bool) – A boolean value that determines the normalization constant. If True, this function normalizes the cross entropy loss across all non-ignored instances. If False, it normalizes only by the batch size.
reduce (str) – A string that determines whether to reduce the shape of the input. If it is 'mean', this function computes the sum of the cross entropy and normalizes it according to the normalize option. If it is 'no', this function computes the cross entropy for each instance and does not normalize it (the normalize option is ignored). In this case, the loss value of each ignored instance, i.e. one whose target value is -1, is set to 0.
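The parameter semantics above can be cross-checked against the example values with a plain NumPy sketch. The helper name sigmoid_cross_entropy_ref is hypothetical and not part of Chainer. The ignore label -1 and the choice of divisor follow the descriptions above. The numerically stable rewriting of log(1 + exp(x)) is a standard identity, not necessarily how Chainer computes it internally.

```python
import numpy as np


def sigmoid_cross_entropy_ref(x, t, normalize=True, reduce='mean'):
    """Hypothetical NumPy reference for the documented semantics (not a Chainer API)."""
    x = np.asarray(x, dtype=np.float64)
    t = np.asarray(t)
    mask = t != -1  # entries labelled -1 are ignored

    # log(1 + exp(x)) - x * t, written in a numerically stable form
    loss = np.maximum(x, 0) - x * t + np.log1p(np.exp(-np.abs(x)))
    loss = loss * mask  # ignored entries contribute 0

    if reduce == 'no':
        return loss  # same shape as x and t; normalize is ignored
    if reduce != 'mean':
        raise ValueError("reduce must be 'mean' or 'no'")

    if normalize:
        count = int(mask.sum())  # number of non-ignored elements
    else:
        count = x.shape[0]  # batch size
    # Assumption: divide by at least 1, so an all-ignored input gives 0
    return loss.sum() / max(1, count)


x = np.array([[-2.0, 3.0, 0.5], [5.0, 2.0, -0.5]], dtype=np.float32)
t = np.array([[0, 1, 0], [1, 1, -1]], dtype=np.int32)
print(sigmoid_cross_entropy_ref(x, t))                   # sum of 5 losses / 5
print(sigmoid_cross_entropy_ref(x, t, normalize=False))  # sum of 5 losses / 2
print(sigmoid_cross_entropy_ref(x, t, reduce='no'))      # shape (2, 3)
```

With the example inputs, the five non-ignored losses sum to about 1.2832. Dividing by the five non-ignored labels (normalize=True) or by the batch size of two (normalize=False) therefore reproduces the 0.25664714 and 0.64161783 shown in the example, up to float32 rounding.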
Returns

A variable object holding an array of the cross entropy. If reduce is 'mean', it is a scalar array. If reduce is 'no', its shape is the same as that of x and t.

Return type

Variable
Note

This function is differentiable only with respect to x.

Example
>>> import numpy as np
>>> import chainer.functions as F
>>> x = np.array([[-2.0, 3.0, 0.5], [5.0, 2.0, -0.5]]).astype(np.float32)
>>> x
array([[-2. ,  3. ,  0.5],
       [ 5. ,  2. , -0.5]], dtype=float32)
>>> t = np.array([[0, 1, 0], [1, 1, -1]]).astype(np.int32)
>>> t
array([[ 0,  1,  0],
       [ 1,  1, -1]], dtype=int32)
>>> F.sigmoid_cross_entropy(x, t)
variable(0.25664714)
>>> F.sigmoid_cross_entropy(x, t, normalize=False)
variable(0.64161783)
>>> y = F.sigmoid_cross_entropy(x, t, reduce='no')
>>> y.shape
(2, 3)
>>> y.array
array([[ 0.126928  ,  0.04858735,  0.974077  ],
       [ 0.00671535,  0.126928  , -0.        ]], dtype=float32)
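The Note states that only x receives a gradient, which is expected because t holds integer labels. For the 'mean' reduction with normalize=True, the derivative of log(1 + e^x) - t·x with respect to x is σ(x) - t. Each non-ignored entry of the gradient should therefore be (σ(x_ij) - t_ij) / N, and ignored entries should get zero. The following sketch checks this with central finite differences in plain NumPy. The helpers loss_mean and grad_x are hypothetical illustrations, not Chainer APIs, and it does not call Chainer's own backward pass.

```python
import numpy as np


def loss_mean(x, t):
    # Hypothetical helper: 'mean' reduction with normalize=True, ignore label -1
    mask = t != -1
    per_elem = np.maximum(x, 0) - x * t + np.log1p(np.exp(-np.abs(x)))
    return (per_elem * mask).sum() / max(1, int(mask.sum()))


def grad_x(x, t):
    # Analytic gradient: (sigmoid(x) - t) / N on non-ignored entries, 0 elsewhere
    mask = t != -1
    sig = 1.0 / (1.0 + np.exp(-x))
    return (sig - t) * mask / max(1, int(mask.sum()))


x = np.array([[-2.0, 3.0, 0.5], [5.0, 2.0, -0.5]])
t = np.array([[0, 1, 0], [1, 1, -1]])

# Central finite differences with respect to each element of x
eps = 1e-6
num = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += eps
    xm[idx] -= eps
    num[idx] = (loss_mean(xp, t) - loss_mean(xm, t)) / (2 * eps)

# The maximum discrepancy should be close to zero
print(np.max(np.abs(num - grad_x(x, t))))
```

The ignored entry at position (1, 2) has a zero gradient in both the analytic and the numerical result, because the mask removes it from the sum.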