pwtools.rbf.core.Rbf.deriv_jax

Rbf.deriv_jax(points)

Partial derivatives computed with JAX automatic differentiation.

Same API as deriv(): uses grad for a single point (1d input) and vmap(grad) for several points (2d input).

>>> x.shape
(N,)
>>> grad(self)(x).shape
(N,)
>>> X.shape
(L,N)
>>> vmap(grad(self))(X).shape
(L,N)

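
The grad / vmap(grad) dispatch described above can be sketched with plain JAX. This is a minimal illustration under stated assumptions, not pwtools' implementation: the scalar function f is a hypothetical sum of Gaussian radial basis functions with made-up centers, weights and width r, and deriv_jax_sketch is a hypothetical name. It only shows how a 1d input of shape (N,) yields a gradient of shape (N,) while a 2d input of shape (L,N) yields one gradient row per point.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy interpolant (NOT pwtools' Rbf internals):
# f(x) = sum_j w_j * exp(-||x - c_j||^2 / (2 r^2)), with made-up parameters.
centers = jnp.array([[0.0, 0.0], [1.0, 0.5], [-0.5, 1.0]])  # (K, N), N = 2
weights = jnp.array([1.0, -2.0, 0.5])                       # (K,)
r = 0.8

def f(x):
    """Scalar-valued model evaluated at a single point x of shape (N,)."""
    sq_dist = jnp.sum((x - centers) ** 2, axis=-1)  # squared distances, (K,)
    return jnp.sum(weights * jnp.exp(-sq_dist / (2.0 * r**2)))

def deriv_jax_sketch(points):
    """Hypothetical helper: grad for (N,) input, vmap(grad) for (L,N) input."""
    # jax.grad needs floating-point inputs, so cast explicitly.
    points = jnp.asarray(points, dtype=jnp.float32)
    if points.ndim == 1:
        return jax.grad(f)(points)            # one point -> shape (N,)
    elif points.ndim == 2:
        return jax.vmap(jax.grad(f))(points)  # L points -> shape (L,N)
    raise ValueError("points must be 1d (N,) or 2d (L,N)")

# Usage: a single point and a batch of L = 4 points in N = 2 dimensions.
x = jnp.array([0.2, 0.3])
X = jnp.array([[0.2, 0.3], [1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])
print(deriv_jax_sketch(x).shape)  # (2,)
print(deriv_jax_sketch(X).shape)  # (4, 2)
```

vmap maps the single-point gradient over the leading axis, so each row of the (L,N) result equals grad(f) applied to the corresponding row of X.
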
Parameters:

points (2d array (L,N) or 1d array (N,)) – L points in N-dimensional space, or a single point of shape (N,)

Returns:

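grads – partial derivatives with respect to each of the N coordinates, evaluated at points (same shape as points)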

Return type:

2d array (L,N) or 1d array (N,)

See also

deriv()