pwtools.rbf.core._np_distsq

pwtools.rbf.core._np_distsq(aa, bb)

(Slow) pure NumPy squared Euclidean distance matrix between the rows of aa (shape (M, D)) and bb (shape (N, D)), returned as an (M, N) array.
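A minimal sketch of what the broadcasting expression computes. It assumes aa and bb are 2D arrays of shape (M, D) and (N, D). The names np_distsq and loop_distsq are hypothetical and chosen for illustration only; they are not pwtools API. The intermediate difference array has shape (M, N, D), which is one reason this pure NumPy version is slow and memory-hungry for large inputs.

```python
import numpy as np


def np_distsq(aa, bb):
    """Squared Euclidean distances between rows: (M, D), (N, D) -> (M, N)."""
    # aa[:, None, :] is (M, 1, D) and bb[None, ...] is (1, N, D);
    # their difference broadcasts to (M, N, D), and summing over the
    # last axis leaves the (M, N) matrix of squared distances.
    return ((aa[:, None, :] - bb[None, ...]) ** 2.0).sum(-1)


def loop_distsq(aa, bb):
    # Reference implementation: explicit double loop over all point pairs.
    out = np.empty((aa.shape[0], bb.shape[0]))
    for i, a in enumerate(aa):
        for j, b in enumerate(bb):
            out[i, j] = np.sum((a - b) ** 2)
    return out


rng = np.random.default_rng(0)
aa = rng.normal(size=(5, 3))
bb = rng.normal(size=(4, 3))
dsq = np_distsq(aa, bb)
print(dsq.shape)  # (5, 4)
print(np.allclose(dsq, loop_distsq(aa, bb)))  # True
```

Comparing against an explicit double loop is a cheap way to check that the broadcasting axes are the intended ones.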

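This is needed in JAX_MODE because we cannot differentiate through scipy.spatial.distance.cdist(). However, jax.jit() makes the expression very fast: in the timings below (with JAX running in 64-bit mode), the jitted version is about 40 times faster than the plain NumPy expression and more than 3 times faster than cdist().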
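A minimal sketch of why a plain array expression is useful in JAX mode, assuming JAX is installed. Written with jax.numpy, the expression can be compiled with jax.jit() and differentiated with jax.grad(), which is not possible through scipy.spatial.distance.cdist(). The names distsq, jdistsq and loss are hypothetical; this does not reproduce how pwtools itself switches into JAX_MODE.

```python
import jax
import jax.numpy as jnp


def distsq(aa, bb):
    # Same broadcasting expression as f in the timing example,
    # here applied to jax.numpy arrays.
    return ((aa[:, None, :] - bb[None, ...]) ** 2.0).sum(-1)


# jax.jit compiles the whole expression with XLA.
jdistsq = jax.jit(distsq)

aa = jnp.array([[0.0, 0.0], [1.0, 2.0]])
bb = jnp.array([[3.0, 4.0]])
print(jdistsq(aa, bb))  # expected values: [[25.], [8.]]


def loss(a):
    # Scalar function of the distance matrix, so jax.grad can be applied.
    return jdistsq(a, bb).sum()


# Analytically: d/da_i sum_j |a_i - b_j|^2 = 2 * sum_j (a_i - b_j).
print(jax.grad(loss)(aa))  # expected values: [[-6., -8.], [-4., -4.]]
```

JAX defaults to 32-bit floats; to reproduce 64-bit timings, jax.config.update("jax_enable_x64", True) can be called before any arrays are created.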

>>> import scipy.spatial.distance
>>> from jax import jit
>>> f=lambda aa, bb: ((aa[:,None,:] - bb[None,...])**2.0).sum(-1)
>>> jf=jit(f)
>>> %timeit f(x,x)
39.3 ms ± 438 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit scipy.spatial.distance.cdist(x,x)
3.43 ms ± 9.25 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit jf(x,x)
1.03 ms ± 53.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)