pwtools.rbf.core._np_distsq
- pwtools.rbf.core._np_distsq(aa, bb)
(Slow) pure numpy squared distance matrix.
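Needed in JAX_MODE because we cannot differentiate through scipy.spatial.distance.cdist(). However, the same expression compiled with jax.jit() is very fast: in the timings below it is roughly a factor of 40 faster than the plain numpy expression and more than a factor of 3 faster than cdist (with 64 bit).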
>>> f = lambda aa, bb: ((aa[:,None,:] - bb[None,...])**2.0).sum(-1)
>>> jf = jit(f)
>>> %timeit f(x,x)
39.3 ms ± 438 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit scipy.spatial.distance.cdist(x,x)
3.43 ms ± 9.25 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit jf(x,x)
1.03 ms ± 53.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
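To make the broadcasting in the expression above easier to follow, here is a minimal sketch that wraps it in a function and checks it against an explicit double loop. The name `np_distsq`, the array shapes, and the random test data are assumptions chosen for illustration; this is not the library's actual implementation, only the lambda from the timing example with comments added.

```python
import numpy as np


def np_distsq(aa, bb):
    """Squared Euclidean distance matrix via broadcasting (illustrative stand-in).

    aa has shape (M, D), bb has shape (N, D); the result has shape (M, N).
    """
    # aa[:, None, :] has shape (M, 1, D) and bb[None, ...] has shape (1, N, D).
    # Their difference broadcasts to (M, N, D); summing over the last axis
    # leaves one squared distance per (row of aa, row of bb) pair.
    return ((aa[:, None, :] - bb[None, ...]) ** 2.0).sum(-1)


# Hypothetical test data: 5 and 4 points in 3 dimensions.
rng = np.random.default_rng(0)
aa = rng.random((5, 3))
bb = rng.random((4, 3))

dsq = np_distsq(aa, bb)

# Reference result from an explicit double loop over all point pairs.
ref = np.empty((aa.shape[0], bb.shape[0]))
for i in range(aa.shape[0]):
    for j in range(bb.shape[0]):
        ref[i, j] = np.sum((aa[i] - bb[j]) ** 2)

print(dsq.shape)
print(np.allclose(dsq, ref))
```

The intermediate (M, N, D) array is what makes the pure numpy version memory-hungry and slow for large inputs. Because the expression uses only elementwise arithmetic and a sum, the same body should also work with `jax.numpy` arrays, which is presumably what allows it to be differentiated and jit-compiled in JAX_MODE.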