pwtools.rbf.core.Rbf
- class pwtools.rbf.core.Rbf(points, values, rbf='gauss', r=0, p='mean', fit=True)
Bases: object
Radial basis function network interpolation and regression.
Notes
Array shape API is as in PolyFit.
\(\texttt{f}: \mathbb R^{m\times n} \rightarrow \mathbb R^m\) if points.ndim=2, which is the setting for training and vectorized evaluation.
>>> X.shape
(m,n)
>>> y.shape
(m,)
>>> f=Rbf(X,y)
>>> f(X).shape
(m,)
\(\texttt{f}: \mathbb R^n \rightarrow \mathbb R\) if points.ndim=1.
>>> x.shape
(n,)
>>> f(x).shape
()
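The shape convention in these notes can be tried out without pwtools installed. The following is a minimal sketch, not the library's implementation: `TinyRbf` is a hypothetical stand-in class, and the Gaussian form exp(-d²/(2p²)), the fixed `p=0.5` and the regularized solve are assumptions made only to get a callable with the same input and output shapes.

```python
import numpy as np

class TinyRbf:
    """Hypothetical stand-in for Rbf, used only to illustrate array shapes."""

    def __init__(self, points, values, p=0.5, r=1e-10):
        self.points = np.asarray(points, dtype=float)        # (m, n)
        self.p = p
        # squared distances between all training points, shape (m, m)
        d2 = ((self.points[:, None, :] - self.points[None, :, :]) ** 2).sum(-1)
        G = np.exp(-0.5 * d2 / p**2)
        # regularized linear system for the weights, shape (m,)
        self.w = np.linalg.solve(G + r * np.eye(len(G)),
                                 np.asarray(values, dtype=float))

    def __call__(self, points):
        points = np.asarray(points, dtype=float)
        single = points.ndim == 1                             # one point of shape (n,)
        X = np.atleast_2d(points)                             # (L, n)
        d2 = ((X[:, None, :] - self.points[None, :, :]) ** 2).sum(-1)
        vals = np.exp(-0.5 * d2 / self.p**2) @ self.w         # (L,)
        return vals[0] if single else vals

rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(15, 2))       # m=15 points in n=2 dimensions
y = np.sin(X[:, 0]) + np.cos(X[:, 1])      # (m,)
f = TinyRbf(X, y)
print(f(X).shape)           # (15,)
print(np.shape(f(X[0])))    # ()
```

A 2d input returns one value per row, while a single point of shape (n,) returns a scalar.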
Examples
>>> from pwtools import mpl, rbf
>>> import numpy as np
>>> # 1D example w/ derivatives. For 1d, we need to use points[:,None]:
>>> # input arrays containing training (dd.XY) and interpolation
>>> # (ddi.XY) points must be 2d.
>>> fig,ax = mpl.fig_ax()
>>> x=np.linspace(0,10,20)     # shape (M,), M=20 points
>>> z=np.sin(x)                # shape (M,)
>>> rbfi=rbf.Rbf(x[:,None], z, r=1e-10)
>>> xi=np.linspace(0,10,100)   # shape (M,), M=100 points
>>> ax.plot(x,z,'o', label='data')
>>> ax.plot(xi, np.sin(xi), label='sin(x)')
>>> ax.plot(xi, rbfi(xi[:,None]), label='rbf')
>>> ax.plot(xi, np.cos(xi), label='cos(x)')
>>> ax.plot(xi, rbfi(xi[:,None],der=1)[:,0], label='d(rbf)/dx')
>>> ax.legend()
>>> # 2D example
>>> x = np.linspace(-3,3,10)
>>> dd = mpl.Data2D(x=x, y=x)
>>> dd.update(Z=np.sin(dd.X)+np.cos(dd.Y))
>>> rbfi=rbf.Rbf(dd.XY, dd.zz, r=1e-10)
>>> xi=np.linspace(-3,3,50)
>>> ddi = mpl.Data2D(x=xi, y=xi)
>>> fig1,ax1 = mpl.fig_ax3d()
>>> ax1.scatter(dd.xx, dd.yy, dd.zz, label='data', color='r')
>>> ax1.plot_wireframe(ddi.X, ddi.Y, rbfi(ddi.XY).reshape(50,50),
...                    label='rbf')
>>> ax1.set_xlabel('x'); ax1.set_ylabel('y');
>>> ax1.legend()
>>> fig2,ax2 = mpl.fig_ax3d()
>>> offset=2
>>> ax2.plot_wireframe(ddi.X, ddi.Y, rbfi(ddi.XY).reshape(50,50),
...                    label='rbf', color='b')
>>> ax2.plot_wireframe(ddi.X, ddi.Y,
...                    rbfi(ddi.XY, der=1)[:,0].reshape(50,50)+offset,
...                    color='g', label='d(rbf)/dx')
>>> ax2.plot_wireframe(ddi.X, ddi.Y,
...                    rbfi(ddi.XY, der=1)[:,1].reshape(50,50)+2*offset,
...                    color='r', label='d(rbf)/dy')
>>> ax2.set_xlabel('x'); ax2.set_ylabel('y');
>>> ax2.legend()
- __init__(points, values, rbf='gauss', r=0, p='mean', fit=True)
- Parameters:
points (2d array, (M,N)) – data points : M points in N-dim space, training set points
values (1d array, (M,)) – function values at training points
rbf (str (see rbf_dct.keys()) or callable rbf(r**2, p)) – RBF definition
r (float or None) – regularization parameter, if None then we use a least squares solver
p (‘mean’ or ‘scipy’ (see estimate_p()) or float) – the RBF’s free parameter
fit (bool) – call fit() in __init__()
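To make the roles of `rbf`, `r` and `p` concrete, here is a rough sketch of how the weights could be obtained. It is an illustration under stated assumptions, not the pwtools code: `fit_weights` is a hypothetical helper, the Gaussian is assumed to have the form exp(-r²/(2p²)), and a float `r` is assumed to be added to the diagonal of the RBF matrix before a direct solve, while `r=None` falls back to a least squares solver as the parameter description says.

```python
import numpy as np

def gauss(rsq, p):
    # Gaussian RBF in the callable form rbf(r**2, p); the exact form is an assumption
    return np.exp(-0.5 * rsq / p**2)

def fit_weights(points, values, rbf=gauss, r=0.0, p=1.0):
    """Hypothetical helper: solve for RBF weights given (M,N) points, (M,) values."""
    # (M, M) matrix of squared distances between training points
    d2 = ((points[:, None, :] - points[None, :, :]) ** 2).sum(axis=-1)
    G = rbf(d2, p)
    if r is None:
        # least squares solution of G w = values
        w, *_ = np.linalg.lstsq(G, values, rcond=None)
    else:
        # regularized direct solve of (G + r*I) w = values
        w = np.linalg.solve(G + r * np.eye(G.shape[0]), values)
    return w

x = np.linspace(0, 10, 20)[:, None]    # (M, 1): 1d data must be passed as 2d
z = np.sin(x[:, 0])                    # (M,)
w_reg = fit_weights(x, z, r=1e-10, p=0.5)
w_lsq = fit_weights(x, z, r=None, p=0.5)
print(np.abs(w_reg - w_lsq).max())     # small: both solves agree for this well-conditioned case
```

With `p` close to the point spacing the matrix is well conditioned, so a tiny `r` barely changes the weights. A larger `r` trades exact interpolation for a smoother regression fit.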
- __call__(*args, **kwargs)
- Parameters:
points (2d array (L,N) or (N,)) – L N-dim points to evaluate the model on.
der (int) – If 1 return (matrix of) partial derivatives (see deriv() and deriv_jax()), else model prediction values (default).
- Returns:
vals (1d array (L,) or scalar) – Interpolated values.
or
derivs (2d array (L,N) or (N,)) – 1st partial derivatives.
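The difference between the value and derivative return shapes can be sketched for a Gaussian RBF, where the first partial derivatives have a closed form. The helper names `distsq`, `predict` and `deriv` below are hypothetical stand-ins, and the Gaussian form exp(-d²/(2p²)) is an assumption. The analytic result is compared to central finite differences as a sanity check.

```python
import numpy as np

def gauss(rsq, p):
    return np.exp(-0.5 * rsq / p**2)

def distsq(X, C):
    # (L, M) matrix of squared distances between rows of X and rows of C
    return ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=-1)

def predict(X, C, w, p):
    return gauss(distsq(X, C), p) @ w                  # (L,)

def deriv(X, C, w, p):
    # d/dx_k of sum_j w_j * exp(-|x - c_j|^2 / (2 p^2))
    #   = -sum_j w_j * phi_j(x) * (x_k - c_jk) / p^2
    phi = gauss(distsq(X, C), p)                       # (L, M)
    diff = X[:, None, :] - C[None, :, :]               # (L, M, N)
    return -np.einsum("lm,lmn,m->ln", phi, diff, w) / p**2   # (L, N)

# training data on a 10x10 grid, as in the 2D example
g = np.linspace(-3, 3, 10)
C = np.array([(a, b) for a in g for b in g])           # (100, 2)
z = np.sin(C[:, 0]) + np.cos(C[:, 1])
p = 0.7
w = np.linalg.solve(gauss(distsq(C, C), p) + 1e-10 * np.eye(len(C)), z)

Xi = np.array([[0.3, -1.2], [1.5, 0.4], [-2.0, 2.2]])  # (L, N) with L=3
D = deriv(Xi, C, w, p)                                 # (L, N)

# central finite differences of the model, one column per coordinate
h = 1e-5
fd = np.stack([(predict(Xi + h * e, C, w, p) - predict(Xi - h * e, C, w, p)) / (2 * h)
               for e in np.eye(2)], axis=1)
print(D.shape, np.abs(D - fd).max())
```

Each row of `D` holds the gradient at one evaluation point, which matches the (L,N) shape listed for `derivs`.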
Methods
deriv(points) – Analytic first partial derivatives.
deriv_jax(points) – Partial derivs from jax.
fit() – Solve linear system for the weights.
fit_error(points, values) – Sum of squared fit errors with penalty on negative p.
get_distsq([points]) – Matrix of distance values \(R_{ij} = |\mathbf x_i - \mathbf c_j|\).
Return (p,r).
predict(points) – Evaluate model at points.
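As a small illustration of the distance matrix that the model is built on, the sketch below computes all pairwise distances between points \(\mathbf x_i\) and centers \(\mathbf c_j\) with NumPy broadcasting. `pairwise_distsq` is a hypothetical helper, not the library method. It returns squared distances on the assumption that a callable `rbf(r**2, p)` consumes \(R_{ij}^2\), and the plain norm is obtained with a square root.

```python
import numpy as np

def pairwise_distsq(points, centers):
    """Hypothetical helper: (L, M) matrix of squared Euclidean distances."""
    diff = points[:, None, :] - centers[None, :, :]   # (L, M, N)
    return (diff ** 2).sum(axis=-1)

C = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])    # three centers in 2d
Rsq = pairwise_distsq(C, C)                           # squared distances
R = np.sqrt(Rsq)                                      # R_ij = |x_i - c_j|
print(R[0, 1], R[0, 2])                               # 5.0 10.0
```

When the points and centers are the same set, the matrix is symmetric with a zero diagonal.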