"""
Radial Basis Function regression. See :ref:`rbf` for details.
"""
from pprint import pformat
##from functools import partial
import math
from pwtools import config
JAX_MODE = config.use_jax
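
# NOTE: JAX_MODE is read once at import time, so the config flag (see
# pwtools.config above) must be set *before* importing this module, roughly:
#
#   from pwtools import config
#   config.use_jax = True
#   from pwtools.rbf import core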
if JAX_MODE:
import jax.numpy as np
from jax.config import config as jax_config
    # Need double precision, else the analytic derivs in Rbf.deriv() as well
    # as the autodiff version Rbf.deriv_jax() are rubbish.
jax_config.update("jax_enable_x64", True)
from jax import grad, vmap, jit
import jax.scipy.linalg as jax_linalg
else:
import numpy as np
from scipy.spatial.distance import cdist
# no-op decorator in case we want to sprinkle @jit around instead of using
# the pattern
#
# if JAX_MODE:
# jax_func = jit(func)
#
def jit(func, *args, **kwds):
return func
import scipy.linalg as linalg
def rbf_gauss(rsq, p):
r"""Gaussian RBF :math:`\exp\left(-\frac{r^2}{2\,p^2}\right)`
Parameters
----------
rsq : float
squared distance :math:`r^2`
p : float
width
"""
return np.exp(-0.5 * rsq / p**2.0)
def rbf_multi(rsq, p):
r"""Multiquadric RBF :math:`\sqrt{r^2 + p^2}`
Parameters
----------
rsq : float
squared distance :math:`r^2`
p : float
width
"""
return np.sqrt(rsq + p**2.0)
def rbf_inv_multi(rsq, p):
r"""Inverse Multiquadric RBF :math:`\frac{1}{\sqrt{r^2 + p^2}}`
Parameters
----------
rsq : float
squared distance :math:`r^2`
p : float
width
"""
return 1 / rbf_multi(rsq, p)
rbf_dct = {
"gauss": rbf_gauss,
"multi": rbf_multi,
"inv_multi": rbf_inv_multi,
}
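
# Besides the keys above, Rbf also accepts any callable rbf(rsq, p). A sketch
# (hypothetical, not part of rbf_dct): an inverse quadratic RBF
#
#   rbf_inv_quad = lambda rsq, p: 1.0 / (1.0 + rsq / p**2.0)
#   Rbf(X, y, rbf=rbf_inv_quad)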
# Consider using jax-md space.py here
def _np_distsq(aa, bb):
"""(Slow) pure numpy squared distance matrix.
    Need that in JAX_MODE b/c we cannot diff thru
    scipy.spatial.distance.cdist(). BUT: jax.jit() is crazy good (factor ~40
    faster than the plain numpy expression, almost factor 4 faster than cdist,
    with 64 bit)!
>>> f=lambda aa, bb: ((aa[:,None,:] - bb[None,...])**2.0).sum(-1)
>>> jf=jit(f)
>>> %timeit f(x,x)
39.3 ms ± 438 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit scipy.spatial.distance.cdist(x,x)
3.43 ms ± 9.25 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit jf(x,x)
1.03 ms ± 53.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
"""
return ((aa[:, None, :] - bb[None, ...]) ** 2.0).sum(-1)
if JAX_MODE:
_jax_np_distsq = jit(_np_distsq)
_jax_np_dist = jit(lambda aa, bb: np.sqrt(_np_distsq(aa, bb)))
def squared_dists(aa, bb):
    """Squared distance matrix ``D_ij = |aa[i,:] - bb[j,:]|**2``."""
    if JAX_MODE:
        return _jax_np_distsq(aa, bb)
    else:
        return cdist(aa, bb, metric="sqeuclidean")
def euclidean_dists(aa, bb):
    """Euclidean distance matrix ``D_ij = |aa[i,:] - bb[j,:]|``."""
    if JAX_MODE:
        return _jax_np_dist(aa, bb)
    else:
        return cdist(aa, bb, metric="euclidean")
# I think estimate_p() doesn't need to use jax.numpy, but due to circular
# deps core <-> hyperopt, we need it here. Once we switch to jax-only and
# drop JAX_MODE, we can use np (=numpy) here and jnp (=jax.numpy) elsewhere.
def estimate_p(points, method="mean"):
r"""Estimate :math:`p`.
Parameters
----------
method : str
| 'mean' : :math:`1/M^2\,\sum_{ij} R_{ij}; M=\texttt{points.shape[0]}`
| 'scipy' : mean nearest neighbor distance
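
    Examples
    --------
    A minimal sketch, assuming random 2d training points as used by
    :class:`Rbf`::

        >>> import numpy as np
        >>> points = np.random.rand(100, 3)
        >>> p = estimate_p(points, method="mean")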
"""
if method == "mean":
return euclidean_dists(points, points).mean()
elif method == "scipy":
xi = points.T
ximax = np.amax(xi, axis=1)
ximin = np.amin(xi, axis=1)
edges = ximax - ximin
edges = edges[np.nonzero(edges)]
return np.power(np.prod(edges) / xi.shape[-1], 1.0 / edges.size)
else:
        raise ValueError(f"illegal method: {method}")
class Rbf:
r"""Radial basis function network interpolation and regression.
Notes
-----
Array shape API is as in :class:`~pwtools.num.PolyFit`.
:math:`\texttt{f}: \mathbb R^{m\times n} \rightarrow \mathbb R^m` if
    ``points.ndim=2``, which is the setting for training and vectorized
    evaluation.
>>> X.shape
(m,n)
>>> y.shape
(m,)
>>> f=Rbf(X,y)
>>> f(X).shape
(m,)
:math:`\texttt{f}: \mathbb R^n \rightarrow \mathbb R` if
``points.ndim=1``.
>>> x.shape
(n,)
>>> f(x).shape
()
Examples
--------
>>> from pwtools import mpl, rbf
>>> import numpy as np
>>> # 1D example w/ derivatives. For 1d, we need to use points[:,None]
>>> # input array containing training (dd.DY) and interpolation
>>> # (ddi.XY) points must be 2d.
>>> fig,ax = mpl.fig_ax()
>>> x=linspace(0,10,20) # shape (M,), M=20 points
>>> z=sin(x) # shape (M,)
>>> rbfi=rbf.Rbf(x[:,None], z, r=1e-10)
>>> xi=linspace(0,10,100) # shape (M,), M=100 points
>>> ax.plot(x,z,'o', label='data')
>>> ax.plot(xi, sin(xi), label='sin(x)')
>>> ax.plot(xi, rbfi(xi[:,None]), label='rbf')
>>> ax.plot(xi, cos(xi), label='cos(x)')
>>> ax.plot(xi, rbfi(xi[:,None],der=1)[:,0], label='d(rbf)/dx')
>>> ax.legend()
>>> # 2D example
>>> x = np.linspace(-3,3,10)
>>> dd = mpl.Data2D(x=x, y=x)
>>> dd.update(Z=np.sin(dd.X)+np.cos(dd.Y))
>>> rbfi=rbf.Rbf(dd.XY, dd.zz, r=1e-10)
>>> xi=linspace(-3,3,50)
>>> ddi = mpl.Data2D(x=xi, y=xi)
>>> fig1,ax1 = mpl.fig_ax3d()
>>> ax1.scatter(dd.xx, dd.yy, dd.zz, label='data', color='r')
>>> ax1.plot_wireframe(ddi.X, ddi.Y, rbfi(ddi.XY).reshape(50,50),
... label='rbf')
>>> ax1.set_xlabel('x'); ax1.set_ylabel('y');
>>> ax1.legend()
>>> fig2,ax2 = mpl.fig_ax3d()
>>> offset=2
>>> ax2.plot_wireframe(ddi.X, ddi.Y, rbfi(ddi.XY).reshape(50,50),
... label='rbf', color='b')
>>> ax2.plot_wireframe(ddi.X, ddi.Y,
... rbfi(ddi.XY, der=1)[:,0].reshape(50,50)+offset,
... color='g', label='d(rbf)/dx')
>>> ax2.plot_wireframe(ddi.X, ddi.Y,
... rbfi(ddi.XY, der=1)[:,1].reshape(50,50)+2*offset,
... color='r', label='d(rbf)/dy')
>>> ax2.set_xlabel('x'); ax2.set_ylabel('y');
>>> ax2.legend()
"""
def __init__(self, points, values, rbf="gauss", r=0, p="mean", fit=True):
r"""
Parameters
----------
points : 2d array, (M,N)
data points : M points in N-dim space, training set points
values : 1d array, (M,)
function values at training points
rbf : str (see rbf_dct.keys()) or callable rbf(r**2, p)
RBF definition
r : float or None
regularization parameter, if None then we use a least squares
solver
p : 'mean' or 'scipy' (see :func:`estimate_p`) or float
the RBF's free parameter
fit : bool
call :meth:`fit` in :meth:`__init__`
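
        Examples
        --------
        Sketches of typical constructions (``X``: (M,N) array, ``y``: (M,)
        array)::

            >>> f = Rbf(X, y)                         # interpolation, r=0
            >>> f = Rbf(X, y, r=1e-10)                # regularized fit
            >>> f = Rbf(X, y, r=None)                 # least squares solver
            >>> f = Rbf(X, y, rbf="inv_multi", p=2.0) # other RBF, fixed p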
"""
assert points.ndim == 2, "points must be 2d array"
assert values.ndim == 1, "values must be 1d array"
self.npoints = points.shape[0]
self.ndim = points.shape[1]
assert (
len(values) == self.npoints
), f"{len(values)=} != {self.npoints=}"
self.points = points
self.values = values
self.rbf = rbf_dct[rbf] if isinstance(rbf, str) else rbf
self.distsq = None
if isinstance(p, str):
if p == "mean":
# re-implement the 'mean' case here again since we can re-use
# distsq later (training data distance matrix)
self.distsq = self.get_distsq()
self.p = np.sqrt(self.distsq).mean()
elif p == "scipy":
self.p = estimate_p(points, "scipy")
else:
raise ValueError("p is not 'mean' or 'scipy'")
else:
self.p = p
self.r = r
if fit:
self.fit()
def __repr__(self):
attrs = ["p", "r", "rbf", "ndim"]
return "Rbf\n" + pformat(
dict([(kk, getattr(self, kk)) for kk in attrs])
)
def _rectify_points_shape(self, points):
ret = np.atleast_2d(points)
# fmt: off
assert (p_ndim := ret.shape[1]) == (f_ndim := self.ndim), (
f"points ndim doesn't match: got {p_ndim}, expect {f_ndim}")
# fmt: on
return ret
def get_distsq(self, points=None):
r"""Matrix of distance values :math:`R_{ij} = |\mathbf x_i - \mathbf
c_j|`.
| :math:`\mathbf x_i` : ``points[i,:]`` (points)
| :math:`\mathbf c_j` : ``self.points[j,:]`` (centers)
Parameters
----------
points : array (K,N) with N-dim points, optional
If None then ``self.points`` is used (training points).
        Returns
        -------
        distsq : 2d array (K,M), where K = M for training (``points=None``)
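
        Examples
        --------
        Shape sketch (``X``: (M,N) training points, ``Xi``: (K,N) points)::

            >>> f = Rbf(X, y)
            >>> f.get_distsq().shape   # training, cached if p='mean'
            (M, M)
            >>> f.get_distsq(Xi).shape
            (K, M)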
"""
# training:
# If points == centers, we could also use
# scipy.spatial.distance.pdist(points), which would give us a 1d
# array of all distances. But we need the redundant square matrix
# form for G=rbf(distsq) anyway, so there is no real point in
# special-casing that. These two are the same:
# spatial.squareform(spatial.distances.pdist(points,
# metric="sqeuclidean"))
# spatial.distances.cdist(points, points, metric="sqeuclidean")
# speed: see examples/benchmarks/distmat_speed.py
#
if points is None:
if self.distsq is None:
return squared_dists(self.points, self.points)
else:
return self.distsq
else:
return squared_dists(points, self.points)
def get_params(self):
"""Return ``(p,r)``."""
return self.p, self.r
def fit(self):
r"""Solve linear system for the weights.
The weights `self.w` (:math:`\mathbf w`) are found from: :math:`\mathbf
G\,\mathbf w = \mathbf z` or if :math:`r` is given :math:`(\mathbf G +
r\,\mathbf I)\,\mathbf w = \mathbf z`.
        with centers == points (center vectors are all data points), in which
        case :math:`\mathbf G` is square. Updates ``self.w``.
Notes
-----
``self.r != None`` : linear system solver
For :math:`r=0`, this always yields perfect interpolation at the
data points. May be numerically unstable in that case. Use
:math:`r>0` to increase stability (try small values such as
``1e-10`` first) or create smooth fitting (generate more stiff
            functions with higher `r`). Behaves similarly to ``lstsq`` but
            appears to be numerically more stable (no small noise in the
            solution) .. but `r` is another parameter that needs to be tuned.
``self.r = None`` : least squares solver
Use :func:`scipy.linalg.lstsq`. Numerically more stable than
direct solver w/o regularization. Will mostly be the same as
the interpolation result, but will not go thru all points for
very noisy data. May create small noise in solution (plot fit
            with high point density). Much (up to 10x) slower than the normal
            linear solver used when ``self.r != None``.
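
        Examples
        --------
        A sketch: with ``r=0`` and a well conditioned :math:`\mathbf G`, the
        model interpolates the training data::

            >>> f = Rbf(X, y, r=0)
            >>> np.allclose(f(X), y)
            True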
"""
G = self.rbf(self.get_distsq(), self.p)
assert G.shape == (self.npoints,) * 2
        # The lstsq solver is ~10x slower than jax's solver and ~4x slower
        # than the symmetric scipy solver.
if self.r is None:
self.w = self._solve_lstsq(G)
else:
if self.r == 0:
self.w = self._solve(G)
else:
self.w = self._solve(G + np.eye(G.shape[0]) * self.r)
def _solve_lstsq(self, G):
x, res, rnk, svs = linalg.lstsq(G, self.values)
return x
def _solve(self, Gr):
if JAX_MODE:
la = jax_linalg
            # jax.scipy.linalg.solve() doesn't have the `assume_a` kwd, only
            # `sym_pos`, which is deprecated in scipy. According to scipy's
            # docs, sym_pos=True is equal to assume_a="pos". However that is
            # 2x slower than assume_a="sym", so we use the latter when using
            # scipy. With that, scipy and jax are equally fast (CPU, single
            # core).
#
# >>> import pwtools.rbf.core
# >>> X=rand(1000,3); y=rand(1000)
#
# >>> pwtools.config.use_jax=True; c=reload(pwtools.rbf.core)
# >>> c.Rbf(X, y, r=1e-10)
# Rbf
# {'ndim': 3,
# 'p': DeviceArray(0.67460831, dtype=float64),
# 'r': 1e-10,
# 'rbf': <function rbf_gauss at 0x7fe77c5ba820>}
# >>> %timeit c.Rbf(X, y, r=1e-10)
# 60.7 ms ± 684 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#
# >>> pwtools.config.use_jax=False; c=reload(pwtools.rbf.core)
# >>> c.Rbf(X, y, r=1e-10)
# Rbf
# {'ndim': 3,
# 'p': 0.6746083082988632,
# 'r': 1e-10,
# 'rbf': <function rbf_gauss at 0x7fe7918f9550>}
# >>> %timeit c.Rbf(X, y, r=1e-10)
# 59.7 ms ± 170 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
kwds = dict(sym_pos=True)
else:
la = linalg
kwds = dict(assume_a="sym")
return la.solve(Gr, self.values, **kwds)
# XXX the jit here causes a 3x speed up over numpy but a huge performance
# regression in test_rbf.py::test_opt_api, not sure why. Basically, the
# test never finishes. W/o the jit, the jax.numpy predict is 4x slower than
# numpy.
##@partial(jit, static_argnums=0)
def predict(self, points):
"""Evaluate model at `points`.
Parameters
----------
points : 2d array (L,N) or (N,)
Returns
-------
vals : 1d array (L,) or scalar
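
        Examples
        --------
        Shape behavior, in the symbolic style of the class docstring::

            >>> f = Rbf(X, y)          # X: (m,n), y: (m,)
            >>> f.predict(X).shape     # same as f(X).shape
            (m,)
            >>> f.predict(X[0]).shape  # single point -> scalar
            ()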
"""
_got_single_point = points.ndim == 1
points = self._rectify_points_shape(points)
assert points.shape[1] == self.ndim, "wrong ndim"
G = self.rbf(self.get_distsq(points=points), self.p)
assert G.shape[0] == points.shape[0]
assert G.shape[1] == len(self.w), (
"shape mismatch between g_ij: %s and w_j: %s, 2nd dim of "
"g_ij must match length of w_j" % (str(G.shape), str(self.w.shape))
)
        # normalize w: scale the weights to order 1 before the dot product to
        # avoid float over-/underflow; mathematically a no-op
maxw = np.abs(self.w).max() * 1.0
values = np.dot(G, self.w / maxw) * maxw
return values[0] if _got_single_point else values
def deriv_jax(self, points):
"""Partial derivs from jax.
Same API as :meth:`deriv`: ``grad`` for 1d input or ``vmap(grad)`` for
2d input.
>>> x.shape
(n,)
>>> grad(f)(x).shape
(n,)
>>> X.shape
(m,n)
>>> vmap(grad(self))(X).shape
(m,n)
Parameters
----------
points : 2d array (L,N) or (N,)
Returns
-------
grads : 2d array (L,N) or (N,)
See Also
--------
:func:`deriv`
"""
if JAX_MODE:
if points.ndim == 1:
assert len(points) == self.ndim, "wrong ndim"
return jit(grad(self))(points)
elif points.ndim == 2:
assert points.shape[1] == self.ndim, "wrong ndim"
return jit(vmap(grad(self)))(points)
else:
raise Exception("points has wrong shape")
else:
raise NotImplementedError
def deriv(self, points):
r"""Analytic first partial derivatives.
Analytic reference implementation of ``jax`` ``grad`` for 1d input or
``vmap(grad)`` for 2d input.
>>> x.shape
(n,)
>>> grad(f)(x).shape
(n,)
>>> X.shape
(m,n)
>>> vmap(grad(self))(X).shape
(m,n)
Parameters
----------
points : 2d array (L,N) or (N,)
Returns
-------
2d array (L,N) or (N,)
            Each row holds the gradient vector :math:`\partial f/\partial
            \mathbf x_i` where :math:`\mathbf x_i = \texttt{points[i,:] =
            [xi_0, ..., xi_N-1]}`. For all L points (shape (L,N)) we get the
            matrix::

                [[df/dx0_0,   df/dx0_1,   ..., df/dx0_N-1],
                 [...],
                 [df/dxL-1_0, df/dxL-1_1, ..., df/dxL-1_N-1]]
See Also
--------
:func:`deriv_jax`
"""
# For the implemented RBF types, the derivatives w.r.t. to the point
# coords simplify to nice dot products, which can be evaluated
# reasonably fast w/ numpy. We don't need to change the RBF's
# implementations to provide a deriv() method. For that, they would
# need to take points and centers explicitly as args instead of squared
# distances, which are calculated fast by cdist().
# Also, this implementation of analytic derivs is nice, but we want to
# play with the cool kids and also use jax. This method here serves as
# reference mostly.
# Speed:
# We have one python loop over the L points (points.shape=(L,N)) left, so
# this gets slow for many points.
# Loop versions (for RBFMultiquadric):
#
# # 3 loops:
# D = np.zeros((L,N), dtype=float)
# for ll in range(L):
# for kk in range(N):
# for jj in range(len(self.w)):
# D[ll,kk] += (points[ll,kk] - centers[jj,kk]) / G[ll,jj] * \
# self.w[jj]
#
# # 2 loops:
# D = np.zeros((L,N), dtype=float)
# for ll in range(L):
# for kk in range(N):
# vec = -1.0 * (centers[:,kk] - points[ll,kk]) / G[ll,:]
# D[ll,kk] = np.dot(vec, self.w)
_got_single_point = points.ndim == 1
points = self._rectify_points_shape(points)
assert points.shape[1] == self.ndim, "wrong ndim"
L, N = points.shape
centers = self.points
G = self.rbf(self.get_distsq(points=points), self.p)
maxw = np.abs(self.w).max() * 1.0
fname = self.rbf.__name__
# fmt: off
# analytic deriv funcs for the inner loop
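        # Example derivation (rbf_gauss): with
        #   G_ij = exp(-|x_i - c_j|^2 / (2 p^2))
        # the gradient of f(x_i) = sum_j w_j G_ij w.r.t. the point x_i is
        #   df/dx_i = (1/p^2) sum_j w_j G_ij (c_j - x_i)
        # i.e. the dot product below, with (c_j - x_i) = centers - points[zz,:].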
D_zz = dict(
rbf_multi=lambda zz: -np.dot(
((centers - points[zz, :]) / G[zz, :][:, None]).T,
self.w / maxw,
) * maxw,
rbf_inv_multi=lambda zz: np.dot(
((centers - points[zz, :]) * (G[zz, :] ** 3.0)[:, None]).T,
self.w / maxw,
) * maxw,
rbf_gauss=lambda zz: 1.0 / self.p ** 2.0 * np.dot(
((centers - points[zz, :]) * G[zz, :][:, None]).T,
self.w / maxw,
) * maxw,
)
# fmt: on
assert fname in D_zz.keys(), f"{fname} not in {D_zz.keys()}"
func = D_zz[fname]
if JAX_MODE:
# Of course don't call this method when doing jax autodiff. Still
# when in JAX_MODE, np = jax.numpy and thus its limitations apply.
# In this case, we use the code here only as reference
# implementation, but it must anyway work under jax.numpy .
#
# Because of that, we must use slow list comp b/c jax' functional
            # workaround for in-place ops,
            #   D.at[zz,:].set(func(zz)),
            # still doesn't operate in place (despite its name, but in sync
            # w/ the docs) unless we jax.jit stuff, so the docs say. Instead
            # it returns a copy, in this case the full D matrix with just one
            # row changed. Then it is cheaper to just use the list comp.
#
# We tried to jit this method but D.at[zz,:].set(func(zz)) still
# doesn't update D and returns D all zero. The in-place state of
# mind has no place in jax land.
#
##D = np.empty((L,N), dtype=float)
##for zz in range(L):
## D.at[zz,:].set(func(zz))
D = np.array([func(zz) for zz in range(L)])
else:
D = np.empty((L, N), dtype=float)
for zz in range(L):
D[zz, :] = func(zz)
return D[0] if _got_single_point else D
def fit_error(self, points, values):
"""Sum of squared fit errors with penalty on negative `p`."""
res = values - self(points)
err = np.dot(res, res) / len(res)
return math.exp(abs(err)) if self.p < 0 else err
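
    # fit_error() is the natural objective when tuning (p, r); a hypothetical
    # sketch (actual helpers live in the hyperopt module mentioned above):
    #
    #   from scipy.optimize import minimize
    #   def obj(params):
    #       p, r = params
    #       return Rbf(X_train, y_train, p=p, r=r).fit_error(X_test, y_test)
    #   minimize(obj, x0=[1.0, 1e-8], method="Nelder-Mead")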
def __call__(self, *args, **kwargs):
"""
Call :meth:`predict` or :meth:`deriv`.
Parameters
----------
points : 2d array (L,N) or (N,)
L N-dim points to evaluate the model on.
der : int
If 1 return (matrix of) partial derivatives (see :meth:`deriv` and
:meth:`deriv_jax`), else model prediction values (default).
Returns
-------
vals : 1d array (L,) or scalar
Interpolated values.
or
derivs : 2d array (L,N) or (N,)
1st partial derivatives.
"""
# We hard-code only the 1st deriv using jax, mainly as an example, and
# for easy comparison to our analytic derivs.
#
# Higher order derivs can be implemented by the user outside, e.g.
#
# >>> X=rand(100,2); y=rand(100)
# >>> f=Rbf(X,y)
# >>> jax.hessian(f)(rand(2)).shape
# (4,4)
if "der" in list(kwargs.keys()):
if kwargs["der"] != 1:
raise Exception("only der=1 supported")
kwargs.pop("der")
if JAX_MODE:
return self.deriv_jax(*args, **kwargs)
else:
return self.deriv(*args, **kwargs)
else:
return self.predict(*args, **kwargs)