pwtools.rbf.core.Rbf

class pwtools.rbf.core.Rbf(points, values, rbf='gauss', r=0, p='mean', fit=True)

Bases: object

Radial basis function network interpolation and regression.

Notes

The array shape API is the same as in PolyFit.

\(\texttt{f}: \mathbb R^{m\times n} \rightarrow \mathbb R^m\) if points.ndim=2. This is the setting used for training and for vectorized evaluation.

>>> X.shape
(m,n)
>>> y.shape
(m,)
>>> f=Rbf(X,y)
>>> f(X).shape
(m,)

\(\texttt{f}: \mathbb R^n \rightarrow \mathbb R\) if points.ndim=1.

>>> x.shape
(n,)
>>> f(x).shape
()
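The sketch below illustrates this shape convention. It wraps an arbitrary vectorized function so that an (m,n) array returns an (m,) array and a single (n,) point returns a scalar. The helper accept_single_point and the stand-in model are hypothetical and not part of pwtools. This is only one possible way to implement such dispatch.

```python
import functools

import numpy as np


def accept_single_point(vectorized):
    """Hypothetical helper: lift a vectorized (L, N) -> (L, ...) function so a
    single point of shape (N,) returns the un-batched result, as in the Rbf API."""

    @functools.wraps(vectorized)
    def wrapper(points, *args, **kwargs):
        pts = np.asarray(points, dtype=float)
        if pts.ndim == 1:
            # promote (N,) -> (1, N), evaluate, then drop the batch axis again
            return vectorized(pts[None, :], *args, **kwargs)[0]
        return vectorized(pts, *args, **kwargs)

    return wrapper


@accept_single_point
def model(points):
    # stand-in for a fitted model: one value per row, (L, N) -> (L,)
    return np.sin(points).sum(axis=1)


X = np.zeros((4, 3))             # L=4 points in N=3 dimensions
print(model(X).shape)            # (4,)
print(np.shape(model(X[0])))     # ()
```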

Examples

>>> from pwtools import mpl, rbf
>>> import numpy as np
>>> # 1D example with derivatives. For 1D data we need points[:,None],
>>> # because the input arrays containing training (x) and
>>> # interpolation (xi) points must be 2D.
>>> fig,ax = mpl.fig_ax()
>>> x=np.linspace(0,10,20)     # shape (M,), M=20 points
>>> z=np.sin(x)                # shape (M,)
>>> rbfi=rbf.Rbf(x[:,None], z, r=1e-10)
>>> xi=np.linspace(0,10,100)   # shape (L,), L=100 points
>>> ax.plot(x,z,'o', label='data')
>>> ax.plot(xi, np.sin(xi), label='sin(x)')
>>> ax.plot(xi, rbfi(xi[:,None]), label='rbf')
>>> ax.plot(xi, np.cos(xi), label='cos(x)')
>>> ax.plot(xi, rbfi(xi[:,None],der=1)[:,0], label='d(rbf)/dx')
>>> ax.legend()
>>> # 2D example
>>> x = np.linspace(-3,3,10)
>>> dd = mpl.Data2D(x=x, y=x)
>>> dd.update(Z=np.sin(dd.X)+np.cos(dd.Y))
>>> rbfi=rbf.Rbf(dd.XY, dd.zz, r=1e-10)
>>> xi=np.linspace(-3,3,50)
>>> ddi = mpl.Data2D(x=xi, y=xi)
>>> fig1,ax1 = mpl.fig_ax3d()
>>> ax1.scatter(dd.xx, dd.yy, dd.zz, label='data', color='r')
>>> ax1.plot_wireframe(ddi.X, ddi.Y, rbfi(ddi.XY).reshape(50,50),
...                    label='rbf')
>>> ax1.set_xlabel('x'); ax1.set_ylabel('y');
>>> ax1.legend()
>>> fig2,ax2 = mpl.fig_ax3d()
>>> offset=2
>>> ax2.plot_wireframe(ddi.X, ddi.Y, rbfi(ddi.XY).reshape(50,50),
...                    label='rbf', color='b')
>>> ax2.plot_wireframe(ddi.X, ddi.Y,
...                    rbfi(ddi.XY, der=1)[:,0].reshape(50,50)+offset,
...                    color='g', label='d(rbf)/dx')
>>> ax2.plot_wireframe(ddi.X, ddi.Y,
...                    rbfi(ddi.XY, der=1)[:,1].reshape(50,50)+2*offset,
...                    color='r', label='d(rbf)/dy')
>>> ax2.set_xlabel('x'); ax2.set_ylabel('y');
>>> ax2.legend()
__init__(points, values, rbf='gauss', r=0, p='mean', fit=True)
Parameters:
  • points (2d array, (M,N)) – training points: M points in N-dimensional space

  • values (1d array, (M,)) – function values at training points

  • rbf (str (see rbf_dct.keys()) or callable rbf(r**2, p)) – RBF definition

  • r (float or None) – regularization parameter; if None, a least-squares solver is used

  • p (‘mean’ or ‘scipy’ (see estimate_p()) or float) – the RBF’s free parameter

  • fit (bool) – whether to call fit() in __init__()
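The rbf parameter accepts a callable with the documented signature rbf(r**2, p). Its first argument is therefore a squared distance, not a distance. The definitions below are a sketch of such callables. The dictionary name my_rbfs and the keys "multi" and "inv_multi" are hypothetical and may not match the keys of rbf_dct. The Gaussian and (inverse) multiquadric forms are standard textbook choices, and we assume here that they correspond to what pwtools uses.

```python
import numpy as np

# Hypothetical RBF definitions using the documented signature rbf(r**2, p):
# the first argument is the *squared* distance, the second the free parameter p.
my_rbfs = {
    "gauss": lambda rsq, p: np.exp(-0.5 * rsq / p**2),
    "multi": lambda rsq, p: np.sqrt(rsq + p**2),
    "inv_multi": lambda rsq, p: 1.0 / np.sqrt(rsq + p**2),
}

# M=3 training points in N=2 dimensions and their (M, M) squared distances
points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
rsq = ((points[:, None, :] - points[None, :, :]) ** 2).sum(axis=-1)

for name, phi in my_rbfs.items():
    G = phi(rsq, 1.0)  # applied element-wise, so G also has shape (M, M)
    print(name, G.shape)
```

A callable like this could then be passed directly, e.g. rbf=my_rbfs['multi'].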

__call__(*args, **kwargs)

Call predict() or deriv().

Parameters:
  • points (2d array (L,N) or (N,)) – L N-dim points to evaluate the model on.

  • der (int) – If 1, return the (matrix of) first partial derivatives (see deriv() and deriv_jax()); otherwise return the model prediction values (default).

Returns:

  • vals (1d array (L,) or scalar) – Interpolated values, or

  • derivs (2d array (L,N) or (N,)) – First partial derivatives.
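The sketch below shows how a der switch can dispatch between prediction and derivatives while keeping the return shapes listed above. TinyGaussRbf is a hypothetical class and not the pwtools implementation. Its Gaussian RBF, fixed p and plain regularized solve are all assumptions made for this illustration.

```python
import numpy as np


class TinyGaussRbf:
    """Hypothetical minimal Gaussian RBF model illustrating the __call__ API.

    Not pwtools' code: the Gaussian form exp(-rsq / (2 p**2)) and the
    regularized solve (G + r*I) w = values are assumptions of this sketch.
    """

    def __init__(self, points, values, p=0.5, r=0.0):
        self.centers = np.asarray(points, dtype=float)  # (M, N) training points
        self.p = p
        G = self._phi(self._distsq(self.centers))       # (M, M)
        self.weights = np.linalg.solve(G + r * np.eye(len(G)), values)

    def _distsq(self, pts):
        # squared distances between pts (L, N) and the centers (M, N) -> (L, M)
        return ((pts[:, None, :] - self.centers[None, :, :]) ** 2).sum(axis=-1)

    def _phi(self, rsq):
        return np.exp(-0.5 * rsq / self.p**2)

    def predict(self, pts):
        return self._phi(self._distsq(pts)) @ self.weights          # (L,)

    def deriv(self, pts):
        # df/dx_k = -1/p**2 * sum_j w_j phi(|x - c_j|**2) (x_k - c_jk)
        diff = pts[:, None, :] - self.centers[None, :, :]          # (L, M, N)
        phi = self._phi((diff**2).sum(axis=-1))                    # (L, M)
        return -np.einsum("lm,m,lmn->ln", phi, self.weights, diff) / self.p**2

    def __call__(self, points, der=0):
        pts = np.asarray(points, dtype=float)
        single = pts.ndim == 1
        pts2d = pts[None, :] if single else pts       # always evaluate on (L, N)
        out = self.deriv(pts2d) if der == 1 else self.predict(pts2d)
        return out[0] if single else out              # scalar or (N,) for one point


X = np.random.default_rng(1).random((8, 2))
f = TinyGaussRbf(X, np.sin(X[:, 0]) + X[:, 1])
print(f(X).shape, f(X, der=1).shape)               # (8,) (8, 2)
print(np.shape(f(X[0])), f(X[0], der=1).shape)     # () (2,)
```

The single-point case is handled only in __call__, so predict() and deriv() can stay purely vectorized.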

Methods

deriv(points)

Analytic first partial derivatives.
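Assume the model has the form \(f(\mathbf x) = \sum_{j=1}^{M} w_j\,\phi(|\mathbf x - \mathbf c_j|^2, p)\), with the training points as centers \(\mathbf c_j\) and the weights \(w_j\) obtained from fit(). This is an assumption consistent with the rbf(r**2, p) signature, not something stated on this page. The chain rule then gives the analytic derivatives. The Gaussian form \(\phi(s, p) = e^{-s/(2p^2)}\) used in the second line is likewise assumed:

\[
\frac{\partial f}{\partial x_k}(\mathbf x)
  = \sum_{j=1}^{M} w_j\,\frac{\partial \phi}{\partial s}\!\left(|\mathbf x - \mathbf c_j|^2, p\right)\,2\,(x_k - c_{jk})
\]

\[
\frac{\partial \phi}{\partial s} = -\frac{\phi}{2p^2}
\;\Rightarrow\;
\frac{\partial f}{\partial x_k}(\mathbf x) = -\frac{1}{p^2}\sum_{j=1}^{M} w_j\, e^{-|\mathbf x - \mathbf c_j|^2/(2p^2)}\,(x_k - c_{jk})
\]

Evaluating this at L points for \(k = 1,\dots,N\) gives an (L,N) array, or an (N,) array for a single point, matching the shapes returned by __call__ with der=1.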

deriv_jax(points)

Partial derivatives computed with jax.

fit()

Solve linear system for the weights.
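The sketch below shows one way to solve for the weights. It follows the documented meaning of r: a float acts as a regularization parameter, and None selects a least-squares solver. The function fit_weights is hypothetical. Adding r to the diagonal of the RBF matrix is an assumption about how the regularization enters, as is the Gaussian with p=0.5.

```python
import numpy as np


def fit_weights(G, values, r=0.0):
    """Solve for RBF weights, given G[i, j] = phi(|x_i - c_j|**2, p).

    Assumed behaviour, following the documented `r` parameter:
      r is a float -> solve (G + r*I) w = values  (r > 0 regularizes)
      r is None    -> least-squares solution of G w = values
    """
    if r is None:
        w, *_ = np.linalg.lstsq(G, values, rcond=None)
        return w
    return np.linalg.solve(G + r * np.eye(G.shape[0]), values)


p = 0.5                                       # assumed Gaussian width
x = np.linspace(0, 10, 20)                    # M=20 1D training points
y = np.sin(x)
G = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / p**2)   # (M, M)

for r in (0.0, 1e-2, None):
    w = fit_weights(G, y, r)
    # a larger r trades exact interpolation for smaller, more stable weights
    print(r, np.abs(G @ w - y).max())
```

With r > 0 the training residual is exactly -r*w, so the model no longer interpolates the data. This is the usual trade-off between interpolation and regression.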

fit_error(points, values)

Sum of squared fit errors with penalty on negative p.
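A fit-error function like this is typically minimized over p, for example on held-out data. The sketch below does this with a coarse grid scan. Note that the pwtools method fit_error(points, values) evaluates an already fitted model, whereas this hypothetical fit_error(p, ...) refits for every p. The exact penalty pwtools applies to negative p is not documented here, so the linear penalty used below is an assumption.

```python
import numpy as np


def gauss_fit_predict(x_train, y_train, x_eval, p, r=1e-10):
    # fit a 1D Gaussian RBF model (assumed form) and evaluate it at x_eval
    def phi(a, b):
        return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / p**2)

    G = phi(x_train, x_train) + r * np.eye(len(x_train))
    w = np.linalg.solve(G, y_train)
    return phi(x_eval, x_train) @ w


def fit_error(p, x_train, y_train, x_test, y_test, penalty=1e6):
    """Sum of squared errors on held-out data, plus a hypothetical penalty for p < 0.

    p enters only as p**2, so -p gives the same model; the penalty steers an
    optimizer towards the positive branch. Its form here is an assumption.
    """
    res = y_test - gauss_fit_predict(x_train, y_train, x_test, p)
    err = float(res @ res)
    return err + (penalty * abs(p) if p < 0 else 0.0)


x = np.linspace(0, 10, 40)
y = np.sin(x)
train, test = slice(0, None, 2), slice(1, None, 2)    # interleaved split
ps = np.linspace(0.2, 2.0, 10)
errs = [fit_error(p, x[train], y[train], x[test], y[test]) for p in ps]
print("best p on this grid:", ps[int(np.argmin(errs))])
```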

get_distsq([points])

Matrix of squared distances \(R_{ij} = |\mathbf x_i - \mathbf c_j|^2\).
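The method name get_distsq and the rbf(r**2, p) signature suggest that this matrix holds squared distances between K evaluation points \(\mathbf x_i\) and the M centers \(\mathbf c_j\). The NumPy sketch below computes such a matrix with the expansion \(|\mathbf x - \mathbf c|^2 = |\mathbf x|^2 + |\mathbf c|^2 - 2\,\mathbf x\cdot\mathbf c\). The standalone function is hypothetical. scipy.spatial.distance.cdist(points, centers, 'sqeuclidean') is a ready-made alternative.

```python
import numpy as np


def get_distsq(points, centers):
    """Squared distances R[i, j] = |x_i - c_j|**2.

    points : (K, N), centers : (M, N)  ->  (K, M)
    Uses |x - c|**2 = |x|**2 + |c|**2 - 2 x.c, which avoids building the
    (K, M, N) array that direct broadcasting would need.
    """
    x2 = (points**2).sum(axis=1)[:, None]      # (K, 1)
    c2 = (centers**2).sum(axis=1)[None, :]     # (1, M)
    d = x2 + c2 - 2.0 * points @ centers.T     # (K, M)
    return np.maximum(d, 0.0)                  # clip tiny negatives from round-off


rng = np.random.default_rng(0)
pts, ctr = rng.random((4, 3)), rng.random((6, 3))
R = get_distsq(pts, ctr)
brute = ((pts[:, None, :] - ctr[None, :, :]) ** 2).sum(axis=-1)
print(R.shape, np.allclose(R, brute))          # (4, 6) True
```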

get_params()

Return (p,r).

predict(points)

Evaluate model at points.