chainer.gradient_check.check_backward

chainer.gradient_check.check_backward(func, x_data, y_grad, params=(), eps=0.001, atol=1e-05, rtol=0.0001, no_grads=None, dtype=None, detect_nondifferentiable=False)

Test backward procedure of a given function.
This function automatically checks the backward process of a given function to ensure that the computed gradients are approximately correct. For example, assuming you have defined a FunctionNode class MyFunc that takes two arguments and returns one value, you can wrap it in an ordinary function and check its gradient computations as follows:

    def func(xs):
        y, = MyFunc().apply(xs)
        return y

    x1_data = xp.array(...)
    x2_data = xp.array(...)
    gy_data = xp.array(...)
    check_backward(func, (x1_data, x2_data), gy_data)
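For concreteness, here is a minimal, hypothetical stand-in for MyFunc (an elementwise square written as a FunctionNode, taking one argument) that makes the snippet above runnable end to end; the class, data shapes, and tolerances are illustrative only:

    import numpy as np
    import chainer
    from chainer import gradient_check

    class Square(chainer.FunctionNode):
        """Elementwise square, y = x * x (a hypothetical example function)."""

        def forward(self, inputs):
            x, = inputs
            self.retain_inputs((0,))  # keep x for use in backward
            return x * x,

        def backward(self, indexes, grad_outputs):
            x, = self.get_retained_inputs()
            gy, = grad_outputs
            return 2.0 * x * gy,      # dy/dx = 2x, chained with gy

    def func(x):
        y, = Square().apply((x,))
        return y

    x_data = np.random.randn(3, 4).astype(np.float32)
    gy_data = np.random.randn(3, 4).astype(np.float32)
    # Raises an AssertionError if backprop and numerical gradients disagree.
    gradient_check.check_backward(func, x_data, gy_data, atol=1e-4, rtol=1e-4)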
This function creates Variable objects from x_data and calls func with those Variables to get its result as a Variable. Then, it sets the y_grad array to the grad attribute of the result and calls the backward method to get the gradients of the inputs. To check the correctness of these gradients, the function calls numerical_grad() to compute the gradients numerically and compares them against the backprop results with chainer.testing.assert_allclose().

To reduce computation time, it uses a directional derivative along a random vector. A function \(g: \mathbb{R} \rightarrow \mathbb{R}^m\) is defined as \(g(\delta) = f(x + \delta r)\), where \(\delta \in \mathbb{R}\), \(r \in \mathbb{R}^n\) is a random vector, and \(f: \mathbb{R}^n \rightarrow \mathbb{R}^m\) is the function you want to test. Its derivative is

\[g'(\delta) = f'(x + \delta r) \cdot r.\]

Therefore, \(g'(0) = f'(x) \cdot r\). We can thus check the correctness of backpropagation through \(f\) indirectly by comparing the numerically computed derivative of \(g\) at zero with \(f'(x) \cdot r\) computed by backprop. If \(r\) is drawn from a uniform distribution, agreement of the two lets us conclude with high probability that the gradient of \(f\) itself is correct.
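The following self-contained NumPy sketch (not part of Chainer's API; all names here are illustrative) demonstrates the identity above for a toy \(f(x) = x^2\), whose Jacobian-vector product is known in closed form:

    import numpy as np

    def f(x):
        return x ** 2

    def jacobian_vector_product(x, r):
        return 2 * x * r  # f'(x) . r for the elementwise square

    x = np.random.randn(5)
    r = np.random.randn(5)   # random direction
    eps = 1e-3

    # Numerical g'(0) via central differences, with g(d) = f(x + d*r)
    g_prime = (f(x + eps * r) - f(x - eps * r)) / (2 * eps)

    # Agrees with the backprop-style product f'(x) . r
    assert np.allclose(g_prime, jacobian_vector_product(x, r), atol=1e-4)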
If the function is non-differentiable with respect to some input objects, you can still check its backprop for the other inputs via the no_grads argument. gradient_check computes numerical gradients for the inputs whose entry in no_grads is False, and asserts that backprop leaves the gradient None for the inputs whose entry is True. By default, no_grads is the tuple of truth values indicating whether each input object (x1_data and/or x2_data in this example) represents an integer variable, as in the sketch below.
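As an illustration (with arbitrary shapes and tolerances), the following checks chainer.functions.softmax_cross_entropy, whose second input is an integer label array that must receive no gradient:

    import numpy as np
    import chainer.functions as F
    from chainer import gradient_check

    x = np.random.randn(4, 3).astype(np.float32)
    t = np.random.randint(0, 3, size=(4,)).astype(np.int32)  # integer labels

    # t is an integer array, so backprop must leave its gradient None;
    # the loss is zero-dimensional, so y_grad is None (see below).
    gradient_check.check_backward(
        F.softmax_cross_entropy, (x, t), None,
        no_grads=(False, True), atol=1e-4, rtol=1e-4)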
You can simplify the test when MyFunc takes only one argument:

    check_backward(func, x1_data, gy_data)
If MyFunc is a loss function that returns a zero-dimensional array, pass None as gy_data. In this case, check_backward sets 1 as the grad attribute of the result:

    check_backward(my_loss_func, (x1_data, x2_data), None)
If MyFunc returns multiple outputs, pass the gradients for all outputs as a tuple:

    gy1_data = xp.array(...)
    gy2_data = xp.array(...)
    check_backward(func, x1_data, (gy1_data, gy2_data))
You can also test a Link. To check the gradients of the link's parameters, pass a tuple of the parameters as the params argument:

    check_backward(my_link, (x1_data, x2_data), gy_data,
                   (my_link.W, my_link.b))
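For example, a runnable sketch with the built-in chainer.links.Linear might look as follows; the shapes and tolerances are arbitrary choices for illustration:

    import numpy as np
    import chainer.links as L
    from chainer import gradient_check

    link = L.Linear(3, 2)  # parameters link.W and link.b are checked too
    x_data = np.random.randn(5, 3).astype(np.float32)
    gy_data = np.random.randn(5, 2).astype(np.float32)
    gradient_check.check_backward(
        link, x_data, gy_data, (link.W, link.b), atol=1e-4, rtol=1e-4)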
Note that params must be Variables, not ndarrays.

Function objects are also acceptable as the func argument:

    check_backward(lambda x1, x2: f(x1, x2), (x1_data, x2_data), gy_data)
Note

func is called many times to compute the numerical gradients for all inputs. This function therefore does not work correctly when func behaves randomly, since repeated calls would return different values and yield inconsistent gradients. One workaround is sketched below.
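A hypothetical workaround (not part of Chainer's API) is to pin the randomness inside func, so that every evaluation made by check_backward sees exactly the same "random" behavior:

    import numpy as np

    def deterministic_func(x):
        # Fixed seed on every call: the mask is identical across all of the
        # repeated evaluations performed during numerical differentiation.
        rng = np.random.RandomState(0)
        mask = (rng.uniform(size=x.shape) > 0.5).astype(np.float32)
        return x * mask  # Variable * ndarray, differentiable w.r.t. x

    # check_backward(deterministic_func, x_data, gy_data) now compares
    # consistent numerical and backprop gradients.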
Parameters

- func (callable) – A function that takes Variables and returns Variables. func must return a tuple of Variables or a single Variable. You can use a Function, FunctionNode, or Link object, or any other callable satisfying this condition.
- x_data (ndarray or tuple of ndarrays) – A set of ndarrays to be passed to func. If x_data is a single ndarray object, it is treated as (x_data,).
- y_grad (ndarray or tuple of ndarrays or None) – A set of ndarrays representing the gradients of the return values of func. If y_grad is a single ndarray object, it is treated as (y_grad,). If func is a loss function, y_grad should be set to None.
- params (Variable or tuple of Variables) – A set of Variables whose gradients are checked. When func is a Link object, set its parameters as params. If params is a single Variable object, it is treated as (params,).
- eps (float) – Epsilon value to be passed to numerical_grad().
- atol (float) – Absolute tolerance to be passed to chainer.testing.assert_allclose().
- rtol (float) – Relative tolerance to be passed to chainer.testing.assert_allclose().
- no_grads (list of bool) – Flags indicating which variables to skip in the gradient assertion. It should have the same length as x_data.
- dtype (dtype) – x_data, y_grad, and params are cast to this dtype when calculating numerical gradients. Only float types and None are allowed.
- detect_nondifferentiable (bool) – If True, a check for non-differentiable inputs is enabled. If func is non-differentiable at x_data, check_backward raises NondifferentiableError.
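As an illustrative sketch of the last parameter, F.relu has a kink at zero, so checking it at an all-zero input with detect_nondifferentiable=True is expected to raise NondifferentiableError (the shapes here are arbitrary):

    import numpy as np
    import chainer.functions as F
    from chainer import gradient_check

    x_data = np.zeros((2, 3), dtype=np.float32)
    gy_data = np.ones((2, 3), dtype=np.float32)
    try:
        gradient_check.check_backward(
            F.relu, x_data, gy_data, detect_nondifferentiable=True)
    except gradient_check.NondifferentiableError:
        print('non-differentiable input detected')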
See also numerical_grad()