scico.scipy.special

BlockArray-compatible jax.scipy.special functions.

This module is a wrapper for jax.scipy.special in which some functions have been extended to automatically map over the blocks of a BlockArray, as described in NumPy and SciPy Functions.

Functions

bernoulli(n)

Generate the first N Bernoulli numbers.

beta(a, b)

The beta function

betainc(a, b, x)

The regularized incomplete beta function.

betaln(a, b)

Natural log of the absolute value of the beta function

digamma(x)

The digamma function

entr(x)

The entropy function

erf(x)

The error function

erfc(x)

The complement of the error function

erfcx(x)

Scaled complementary error function.

erfinv(x)

The inverse of the error function

exp1(x)

Exponential integral function.

expit(x)

The logistic sigmoid (expit) function

factorial(n[, exact])

Factorial function

gamma(x)

The gamma function.

gammainc(a, x)

The regularized lower incomplete gamma function.

gammaincc(a, x)

The regularized upper incomplete gamma function.

gammaln(x)

Natural log of the absolute value of the gamma function.

i0(x)

Modified Bessel function of zeroth order.

i0e(x)

Exponentially scaled modified Bessel function of zeroth order.

i1(x)

Modified Bessel function of first order.

i1e(x)

Exponentially scaled modified Bessel function of first order.

kl_div(p, q)

The Kullback-Leibler divergence.

log_ndtr(x[, series_order])

Log Normal distribution function.

log_softmax(x, /, *[, axis])

Log-Softmax function.

loggamma(x)

Principal branch of the logarithm of the gamma function.

logit(x)

The logit function

multigammaln(a, d)

The natural log of the multivariate gamma function.

ndtr(x)

Normal distribution function.

ndtri(p)

The inverse of the CDF of the Normal distribution function.

owens_t(h, a)

Owen's T function.

polygamma(n, x)

The polygamma function.

rel_entr(p, q)

The relative entropy function.

softmax(x, /, *[, axis])

Softmax function.

spence(x)

Spence's function, also known as the dilogarithm for real values.

sph_harm_y(n, m, theta, phi[, diff_n, n_max])

Computes the spherical harmonics.

xlog1py(x, y)

Compute x*log(1 + y), returning 0 for x=0.

xlogy(x, y)

Compute x*log(y), returning 0 for x=0.

zeta(x[, q])

The Hurwitz zeta function.

scico.scipy.special.bernoulli(n)

Generate the first N Bernoulli numbers.

JAX implementation of scipy.special.bernoulli.

Parameters:

n (int) – integer, the number of Bernoulli terms to generate.

Return type:

Array

Returns:

Array containing the first n Bernoulli numbers.

Notes

bernoulli generates numbers using the \(B_n^-\) convention, such that \(B_1=-1/2\).
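To illustrate the \(B_n^-\) convention (with \(B_1=-1/2\)), the following pure-Python sketch generates exact Bernoulli numbers from the classical recurrence \(\sum_{k=0}^{m}\binom{m+1}{k}B_k = 0\) for \(m \ge 1\). The helper name bernoulli_minus is hypothetical. It returns the n values \(B_0, \ldots, B_{n-1}\) and is not meant to reproduce the exact output length or dtype of bernoulli.

```python
from fractions import Fraction
from math import comb

def bernoulli_minus(n):
    """First n Bernoulli numbers B_0..B_{n-1} in the B_n^- convention."""
    b = []
    for m in range(n):
        if m == 0:
            b.append(Fraction(1))
        else:
            # Solve sum_{k=0}^{m} C(m+1, k) B_k = 0 for B_m (note C(m+1, m) = m + 1).
            acc = sum(comb(m + 1, k) * b[k] for k in range(m))
            b.append(-acc / (m + 1))
    return b

print(bernoulli_minus(5))  # B_0 .. B_4 as exact fractions
```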

scico.scipy.special.beta(a, b)

The beta function

JAX implementation of scipy.special.beta.

\[\mathrm{beta}(a, b) = B(a, b) = \frac{\Gamma(a)\Gamma(b)}{\Gamma(a + b)}\]

where \(\Gamma\) is the gamma function.

Parameters:
Return type:

Array

Returns:

array containing the values of the beta function.

scico.scipy.special.betainc(a, b, x)

The regularized incomplete beta function.

JAX implementation of scipy.special.betainc.

\[\mathrm{betainc}(a, b, x) = \frac{1}{B(a, b)}\int_0^x t^{a-1}(1-t)^{b-1}\mathrm{d}t\]

where \(B(a, b)\) is the beta function.
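The definition can be checked numerically. The sketch below evaluates the integral with composite Simpson's rule and normalizes by \(B(a, b)\) computed from log-gamma values. This is a swapped-in textbook technique for illustration, not the algorithm used by betainc, and it assumes \(a \ge 1\) and \(b \ge 1\) so that the integrand is finite at the endpoints. The helper names are hypothetical.

```python
import math

def beta_fn(a, b):
    # B(a, b) = Gamma(a) Gamma(b) / Gamma(a + b), via logs; assumes a, b > 0.
    return math.exp(math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b))

def betainc_simpson(a, b, x, n=2000):
    """Regularized incomplete beta by Simpson's rule on [0, x] (n even)."""
    def f(t):
        return t ** (a - 1) * (1 - t) ** (b - 1)

    h = x / n
    s = f(0.0) + f(x)
    for i in range(1, n):
        s += (4 if i % 2 else 2) * f(i * h)  # Simpson weights 4, 2, 4, ...
    return s * h / 3 / beta_fn(a, b)
```

For \(a = 2, b = 3\) the integrand is a cubic, so Simpson's rule is exact up to rounding, and the known value \(I_{1/2}(2, 3) = 11/16\) is recovered.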

Parameters:
Return type:

Array

Returns:

array containing values of the betainc function

scico.scipy.special.betaln(a, b)

Natural log of the absolute value of the beta function

JAX implementation of scipy.special.betaln.

\[\mathrm{betaln}(a, b) = \log B(a, b)\]

where \(B\) is the beta function.

Parameters:
Return type:

Array

Returns:

array containing the values of the log-beta function

scico.scipy.special.digamma(x)

The digamma function

JAX implementation of scipy.special.digamma.

\[\mathrm{digamma}(z) = \psi(z) = \frac{\mathrm{d}}{\mathrm{d}z}\log \Gamma(z)\]

where \(\Gamma(z)\) is the gamma function.
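For real \(x > 0\), a common way to evaluate \(\psi\) is to apply the recurrence \(\psi(x) = \psi(x+1) - 1/x\) until the argument is large and then use the asymptotic expansion \(\psi(y) \approx \log y - \frac{1}{2y} - \frac{1}{12y^2} + \frac{1}{120y^4} - \frac{1}{252y^6}\). This is a sketch of that standard technique, not the JAX implementation; the helper name digamma_ref is hypothetical.

```python
import math

def digamma_ref(x, shift=10):
    """Reference digamma for real x > 0 (recurrence + asymptotic series)."""
    acc = 0.0
    # psi(x) = psi(x + shift) - sum_{k=0}^{shift-1} 1 / (x + k)
    for k in range(shift):
        acc -= 1.0 / (x + k)
    y = x + shift
    y2 = 1.0 / (y * y)
    # log y - 1/(2y) - 1/(12 y^2) + 1/(120 y^4) - 1/(252 y^6)
    asym = math.log(y) - 0.5 / y - y2 * (1 / 12 - y2 * (1 / 120 - y2 / 252))
    return acc + asym
```

Two known values serve as checks: \(\psi(1) = -\gamma\) (the Euler-Mascheroni constant) and \(\psi(1/2) = -\gamma - 2\log 2\).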

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real-valued.

Return type:

Array

Returns:

array containing values of the digamma function.

Notes

The JAX version of digamma only accepts real-valued inputs.

scico.scipy.special.entr(x)

The entropy function

JAX implementation of scipy.special.entr.

\[\begin{split}\mathrm{entr}(x) = \begin{cases} -x\log(x) & x > 0 \\ 0 & x = 0\\ -\infty & \mathrm{otherwise} \end{cases}\end{split}\]
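Summing entr over a probability vector gives the Shannon entropy in nats, with zero-probability entries contributing nothing. A scalar pure-Python sketch of the piecewise definition (the helper name is hypothetical):

```python
import math

def entr(x):
    """Scalar version of the piecewise entropy function."""
    if math.isnan(x):
        return x  # propagate NaN
    if x > 0:
        return -x * math.log(x)
    if x == 0:
        return 0.0
    return -math.inf  # negative inputs

p = [0.5, 0.25, 0.25, 0.0]
H = sum(entr(v) for v in p)  # 0.5*log(2) + 2*0.25*log(4) = 1.5*log(2)
```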
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real-valued.

Return type:

Array

Returns:

array containing entropy values.

scico.scipy.special.erf(x)

The error function

JAX implementation of scipy.special.erf.

\[\mathrm{erf}(x) = \frac{2}{\sqrt\pi} \int_{0}^x e^{-t^2} \mathrm{d}t\]
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real-valued.

Return type:

Array

Returns:

array containing values of the error function.

Notes

The JAX version only supports real-valued inputs.

scico.scipy.special.erfc(x)

The complement of the error function

JAX implementation of scipy.special.erfc.

\[\mathrm{erfc}(x) = \frac{2}{\sqrt\pi} \int_{x}^\infty e^{-t^2} \mathrm{d}t\]

This is the complement of the error function erf, erfc(x) = 1 - erf(x).
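Although erfc(x) equals 1 - erf(x) mathematically, the subtraction cancels catastrophically in floating point once erf(x) rounds to 1. Python's math module is used here only to illustrate the effect in float64:

```python
import math

x = 10.0
naive = 1.0 - math.erf(x)  # erf(10) rounds to exactly 1.0, so this is 0.0
direct = math.erfc(x)      # about 2.1e-45, computed without cancellation
print(naive, direct)
```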

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real-valued.

Return type:

Array

Returns:

array containing values of the complement of the error function.

Notes

The JAX version only supports real-valued inputs.

scico.scipy.special.erfcx(x)

Scaled complementary error function.

JAX implementation of scipy.special.erfcx.

\[\mathrm{erfcx}(x) = e^{x^2} \mathrm{erfc}(x)\]

This is numerically stable for large positive x, unlike the naive formula, in which \(e^{x^2}\) overflows while \(\mathrm{erfc}(x)\) underflows to zero.
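The failure of the naive formula and one stable alternative can be sketched in float64 with Python's math module. For large x, erfcx has the asymptotic expansion \(\mathrm{erfcx}(x) \approx \frac{1}{x\sqrt{\pi}}\left(1 - \frac{1}{2x^2} + \frac{3}{4x^4}\right)\); this truncated series is an illustration only, not the method used by erfcx, and the helper names are hypothetical.

```python
import math

def erfcx_naive(x):
    # exp(x*x) raises OverflowError for x above about 26.6; erfc(x) underflows even earlier.
    return math.exp(x * x) * math.erfc(x)

def erfcx_asymptotic(x):
    # First three terms of the large-x expansion; accuracy improves as x grows.
    t = 1.0 / (x * x)
    return (1.0 - 0.5 * t + 0.75 * t * t) / (x * math.sqrt(math.pi))
```

At x = 5 the two agree to about 1e-5; at x = 30, erfc(x) has already underflowed to 0.0, so the naive product is useless while the asymptotic form remains finite and positive.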

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real-valued.

Return type:

Array

Returns:

array containing values of the scaled complementary error function.

scico.scipy.special.erfinv(x)

The inverse of the error function

JAX implementation of scipy.special.erfinv.

Returns the inverse of erf.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real-valued.

Return type:

Array

Returns:

array containing values of the inverse error function.

Notes

The JAX version only supports real-valued inputs.

scico.scipy.special.exp1(x)

Exponential integral function.

JAX implementation of scipy.special.exp1.

\[\mathrm{exp1}(x) = E_1(x) = \int_x^\infty\frac{e^{-t}}{t}\mathrm{d}t\]
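For moderate \(x > 0\), \(E_1(x) = \int_x^\infty e^{-t}/t\,\mathrm{d}t\) can be evaluated with the convergent series \(E_1(x) = -\gamma - \log x - \sum_{k=1}^\infty \frac{(-x)^k}{k\,k!}\), where \(\gamma\) is the Euler-Mascheroni constant. The sketch below is a reference check only; the series suffers from cancellation for large x, and the helper name is hypothetical.

```python
import math

EULER_GAMMA = 0.5772156649015329

def exp1_series(x, terms=60):
    """E_1(x) for moderate x > 0 via its power series."""
    s = 0.0
    term = 1.0
    for k in range(1, terms + 1):
        term *= -x / k  # term == (-x)**k / k!
        s += term / k
    return -EULER_GAMMA - math.log(x) - s
```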
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real-valued

Return type:

Array

Returns:

array of exp1 values

See also

  • jax.scipy.special.expi

  • jax.scipy.special.expn

scico.scipy.special.expit(x)

The logistic sigmoid (expit) function

JAX implementation of scipy.special.expit.

\[\mathrm{expit}(x) = \frac{1}{1 + e^{-x}}\]
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real-valued.

Return type:

Array

Returns:

array containing values of the expit function.

scico.scipy.special.factorial(n, exact=False)

Factorial function

JAX implementation of scipy.special.factorial.

\[\mathrm{factorial}(n) = n! = \prod_{k=1}^n k\]
Parameters:
Return type:

Array

Returns:

array containing values of the factorial.

Notes

This computes the float-valued factorial via the gamma function. JAX does not support exact factorials, because it is not particularly useful: above n=20, the exact result cannot be represented by 64-bit integers, which are the largest integers available to JAX.

scico.scipy.special.gamma(x)

The gamma function.

JAX implementation of scipy.special.gamma.

The gamma function is defined for \(\Re(z)>0\) as

\[\mathrm{gamma}(z) = \Gamma(z) = \int_0^\infty t^{z-1}e^{-t}\mathrm{d}t\]

and is extended by analytic continuation to arbitrary complex values z. For positive integers n, the gamma function is related to the factorial function via the following identity:

\[\Gamma(n) = (n - 1)!\]

For real inputs:

  • if \(x = -\infty\), NaN is returned.

  • if \(x = \pm 0\), \(\pm \infty\) is returned.

  • if \(x\) is a negative integer, NaN is returned. The sign of gamma near a negative integer depends on the side from which the pole is approached.

  • if \(x = \infty\), \(\infty\) is returned.

  • if \(x\) is NaN, NaN is returned.

For complex inputs:

  • at non-positive integers (poles), nan+nanj is returned, matching SciPy.

  • if either real or imaginary component is NaN, nan+nanj is returned.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real or complex valued. Complex inputs use a Lanczos approximation with reflection formula.

Return type:

Array

Returns:

array containing the values of the gamma function. For complex inputs, the output is complex-valued.

See also

Notes

For complex inputs, the implementation uses the Lanczos approximation (g=7, N=9 coefficients) with the reflection formula for Re(z) < 0.5.
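The Lanczos-plus-reflection scheme described above can be sketched in pure Python. The coefficient table below is the widely published g = 7, N = 9 set; it is an assumption that scico/JAX uses exactly these values. The sketch does not handle the poles at non-positive integers (it would divide by zero there), and the helper name is hypothetical.

```python
import cmath
import math

# Widely published Lanczos coefficients for g = 7, N = 9 (assumed, not confirmed
# to be the exact table used by the library).
G = 7
P = [
    0.99999999999980993,
    676.5203681218851,
    -1259.1392167224028,
    771.32342877765313,
    -176.61502916214059,
    12.507343278686905,
    -0.13857109526572012,
    9.9843695780195716e-6,
    1.5056327351493116e-7,
]

def lanczos_gamma(z):
    """Complex gamma via the Lanczos approximation (sketch, no pole handling)."""
    z = complex(z)
    if z.real < 0.5:
        # Reflection formula: Gamma(z) * Gamma(1 - z) = pi / sin(pi z)
        return cmath.pi / (cmath.sin(cmath.pi * z) * lanczos_gamma(1 - z))
    z -= 1
    x = P[0]
    for i in range(1, len(P)):
        x += P[i] / (z + i)
    t = z + G + 0.5
    return cmath.sqrt(2 * cmath.pi) * t ** (z + 0.5) * cmath.exp(-t) * x
```

A quick consistency check is the recurrence \(\Gamma(z+1) = z\,\Gamma(z)\), which exercises the reflection branch for \(\Re(z) < 0.5\).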

scico.scipy.special.gammainc(a, x)

The regularized lower incomplete gamma function.

JAX implementation of scipy.special.gammainc.

\[\mathrm{gammainc}(a, x) = \frac{1}{\Gamma(a)}\int_0^x t^{a-1}e^{-t}\mathrm{d}t\]

where \(\Gamma(a)\) is the gamma function.
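A standard way to evaluate the regularized lower incomplete gamma function for moderate x is the power series \(P(a, x) = \frac{x^a e^{-x}}{\Gamma(a)}\sum_{n=0}^\infty \frac{x^n}{a(a+1)\cdots(a+n)}\). This sketch is a reference check under that assumption, not the library's algorithm; the helper name is hypothetical.

```python
import math

def gammainc_series(a, x, terms=200):
    """Regularized lower incomplete gamma P(a, x) for a > 0, x >= 0 (series)."""
    if x == 0:
        return 0.0
    term = 1.0 / a
    total = term
    for n in range(1, terms):
        term *= x / (a + n)  # term == x**n / (a (a+1) ... (a+n))
        total += term
    # Prefactor x**a * exp(-x) / Gamma(a), computed in log space.
    return total * math.exp(-x + a * math.log(x) - math.lgamma(a))
```

For integer a there is a closed form, for example \(P(1, x) = 1 - e^{-x}\) and \(P(3, 2) = 1 - 5e^{-2}\).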

Parameters:
Return type:

Array

Returns:

array containing values of the gammainc function.

scico.scipy.special.gammaincc(a, x)

The regularized upper incomplete gamma function.

JAX implementation of scipy.special.gammaincc.

\[\mathrm{gammaincc}(a, x) = \frac{1}{\Gamma(a)}\int_x^\infty t^{a-1}e^{-t}\mathrm{d}t\]

where \(\Gamma(a)\) is the gamma function.

Parameters:
Return type:

Array

Returns:

array containing values of the gammaincc function.

scico.scipy.special.gammaln(x)

Natural log of the absolute value of the gamma function.

JAX implementation of scipy.special.gammaln.

\[\mathrm{gammaln}(x) = \log(|\Gamma(x)|)\]

where \(\Gamma\) is the gamma function.
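Two consequences of taking the log of the absolute value can be seen with Python's math.lgamma, which has the same \(\log|\Gamma(x)|\) semantics: the sign of \(\Gamma\) is lost for negative non-integer x, and the result stays finite where \(\Gamma\) itself overflows float64.

```python
import math

x = -0.5
g = math.gamma(x)    # Gamma(-0.5) = -2*sqrt(pi), a negative number
lg = math.lgamma(x)  # log|Gamma(-0.5)|: the sign is discarded

# Gamma(200) exceeds the float64 range, but its logarithm is modest.
try:
    math.gamma(200.0)
    overflowed = False
except OverflowError:
    overflowed = True
big = math.lgamma(200.0)
```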

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real valued.

Return type:

Array

Returns:

array containing the values of the log-gamma function

See also

Notes

gammaln does not support complex-valued inputs.

scico.scipy.special.i0(x)

Modified Bessel function of zeroth order.

JAX implementation of scipy.special.i0.

\[\mathrm{i0}(x) = I_0(x) = \sum_{k=0}^\infty \frac{(x^2/4)^k}{(k!)^2}\]
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array, real-valued

Return type:

Array

Returns:

array of Bessel function values.

scico.scipy.special.i0e(x)

Exponentially scaled modified Bessel function of zeroth order.

JAX implementation of scipy.special.i0e.

\[\mathrm{i0e}(x) = e^{-|x|} I_0(x)\]

where \(I_0(x)\) is the modified Bessel function i0.
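Why the scaled form is useful can be sketched with the integral representation \(I_0(x) = \frac{1}{\pi}\int_0^\pi e^{x\cos\theta}\,\mathrm{d}\theta\). Multiplying by \(e^{-|x|}\) gives an integrand \(e^{|x|(\cos\theta - 1)} \le 1\), so i0e can be evaluated for large x where \(I_0(x)\) itself would overflow. The trapezoidal rule used here is an illustrative choice, not the library's method, and the helper name is hypothetical.

```python
import math

def i0e_ref(x, n=2000):
    """e^{-|x|} I_0(x) from the integral representation (trapezoidal rule)."""
    ax = abs(x)  # I_0 is an even function
    h = math.pi / n
    # Endpoints theta = 0 and theta = pi carry half weight.
    total = 0.5 * (1.0 + math.exp(-2.0 * ax))
    for k in range(1, n):
        # exp(ax * (cos(theta) - 1)) never exceeds 1, so nothing overflows.
        total += math.exp(ax * (math.cos(k * h) - 1.0))
    return total * h / math.pi
```

For large x the result approaches \(1/\sqrt{2\pi x}\), while \(I_0(1000)\) alone is far beyond the float64 range.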

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array, real-valued

Return type:

Array

Returns:

array of Bessel function values.

scico.scipy.special.i1(x)

Modified Bessel function of first order.

JAX implementation of scipy.special.i1.

\[\mathrm{i1}(x) = I_1(x) = \frac{1}{2}x\sum_{k=0}^\infty\frac{(x^2/4)^k}{k!(k+1)!}\]
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array, real-valued

Return type:

Array

Returns:

array of Bessel function values.

scico.scipy.special.i1e(x)

Exponentially scaled modified Bessel function of first order.

JAX implementation of scipy.special.i1e.

\[\mathrm{i1e}(x) = e^{-|x|} I_1(x)\]

where \(I_1(x)\) is the modified Bessel function i1.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array, real-valued

Return type:

Array

Returns:

array of Bessel function values.

scico.scipy.special.kl_div(p, q)

The Kullback-Leibler divergence.

JAX implementation of scipy.special.kl_div.

\[\begin{split} \mathrm{kl\_div}(p, q) = \begin{cases} p\log(p/q)-p+q & p>0,q>0\\ q & p=0,q\ge 0\\ \infty & \mathrm{otherwise} \end{cases}\end{split}\]
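Because of the extra \(-p + q\) terms, every elementwise value of kl_div is non-negative even for unnormalized inputs, and when p and q each sum to 1 those terms cancel in the sum, leaving the usual Kullback-Leibler divergence. A scalar pure-Python sketch of the piecewise definition (helper name hypothetical):

```python
import math

def kl_div(p, q):
    """Scalar version of the piecewise kl_div definition."""
    if p > 0 and q > 0:
        return p * math.log(p / q) - p + q
    if p == 0 and q >= 0:
        return q
    return math.inf

# Elementwise values are non-negative.
vals = [kl_div(p, q) for p, q in [(1.0, 2.0), (3.0, 1.0), (0.0, 2.0), (2.0, 2.0)]]

# For normalized distributions the sum equals sum(p * log(p / q)).
p = [0.5, 0.5]
q = [0.9, 0.1]
total = sum(kl_div(a, b) for a, b in zip(p, q))
kl = sum(a * math.log(a / b) for a, b in zip(p, q))
```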
Parameters:
Return type:

Array

Returns:

array of KL-divergence values

scico.scipy.special.log_ndtr(x, series_order=3)

Log Normal distribution function.

JAX implementation of scipy.special.log_ndtr.

For details of the Normal distribution function see ndtr.

This function calculates \(\log(\mathrm{ndtr}(x))\) by either calling \(\log(\mathrm{ndtr}(x))\) or using an asymptotic series. Specifically:

  • For x > upper_segment, use the approximation -ndtr(-x) based on \(\log(1-x) \approx -x, x \ll 1\).

  • For lower_segment < x <= upper_segment, use the existing ndtr technique and take a log.

  • For x <= lower_segment, we use the series approximation of erf to compute the log CDF directly.

The lower_segment is set based on the precision of the input:

\[\begin{split}\begin{align} \mathit{lower\_segment} =& \ \begin{cases} -20 & x.\mathrm{dtype}=\mathit{float64} \\ -10 & x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \\ \mathit{upper\_segment} =& \ \begin{cases} 8& x.\mathrm{dtype}=\mathit{float64} \\ 5& x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \end{align}\end{split}\]

When x <= lower_segment, the ndtr asymptotic series approximation is:

\[\begin{split}\begin{align} \mathrm{ndtr}(x) =&\ \mathit{scale} * (1 + \mathit{sum}) + R_N \\ \mathit{scale} =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\ \mathit{sum} =&\ \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\ R_N =&\ O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3}) \end{align}\end{split}\]

where \((2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)\) is a double-factorial operator.
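The asymptotic branch can be sketched in plain Python for x well below zero by taking the log of the scale factor and using log1p for \(1 + \mathit{sum}\), with alternating terms \((-1)^n (2n-1)!!/x^{2n}\). The helper names are hypothetical, float64 is assumed, and this is not the JAX code. Comparing against the direct formula \(\log(\tfrac{1}{2}\mathrm{erfc}(-x/\sqrt{2}))\) shows agreement at moderate x, while the direct formula breaks down once erfc underflows.

```python
import math

def log_ndtr_asymptotic(x, series_order=3):
    """log(ndtr(x)) for large negative x via the truncated asymptotic series."""
    x2 = x * x
    total = 0.0
    double_fact = 1.0  # running value of (2n - 1)!!
    for n in range(1, series_order + 1):
        double_fact *= 2 * n - 1
        total += (-1) ** n * double_fact / x2 ** n
    # log(scale) = -x^2/2 - log(-x) - log(sqrt(2 pi))
    log_scale = -0.5 * x2 - math.log(-x) - 0.5 * math.log(2 * math.pi)
    return log_scale + math.log1p(total)

def log_ndtr_direct(x):
    return math.log(0.5 * math.erfc(-x / math.sqrt(2.0)))
```

At x = -40, math.erfc(40/sqrt(2)) underflows to 0.0, so log_ndtr_direct would fail, while the series still returns a finite value near -804.6.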

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – an array of type float32, float64.

  • series_order (int) – Positive Python integer. Maximum depth to evaluate the asymptotic expansion. This is the N above.

Return type:

Array

Returns:

an array with dtype=x.dtype.

Raises:
  • TypeError – if x.dtype is not handled.

  • TypeError – if series_order is not a Python integer.

  • ValueError – if series_order is not in [0, 30].

scico.scipy.special.log_softmax(x, /, *, axis=None)

Log-Softmax function.

JAX implementation of scipy.special.log_softmax.

Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0]\).

\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
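Evaluating this formula literally overflows for large inputs (exp(1000) is not representable in float64). The usual remedy is the log-sum-exp trick: subtract the maximum before exponentiating, which leaves the result unchanged mathematically. A one-dimensional pure-Python sketch with a hypothetical helper name:

```python
import math

def log_softmax(x):
    """log_softmax of a list of floats using the log-sum-exp trick."""
    m = max(x)
    # log(sum_j exp(x_j)) = m + log(sum_j exp(x_j - m)); every exponent is <= 0.
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]

out = log_softmax([1000.0, 1000.0])  # both entries equal -log(2)
```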
Parameters:
Return type:

Array

Returns:

An array of the same shape as x.

Note

If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.

See also

softmax

scico.scipy.special.loggamma(x)

Principal branch of the logarithm of the gamma function.

JAX implementation of scipy.special.loggamma.

Defined to be \(\log(\Gamma(x))\) for \(x > 0\) and extended to the complex plane by analytic continuation. The function has a single branch cut on the negative real axis.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real or complex valued.

Return type:

Array

Returns:

array containing the values of the loggamma function. For complex inputs, the output is complex-valued.

See also

scico.scipy.special.logit(x)

The logit function

JAX implementation of scipy.special.logit.

\[\mathrm{logit}(p) = \log\frac{p}{1 - p}\]
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – arraylike, real-valued.

Return type:

Array

Returns:

array containing values of the logit function.

scico.scipy.special.multigammaln(a, d)

The natural log of the multivariate gamma function.

JAX implementation of scipy.special.multigammaln.

\[\mathrm{multigammaln}(a, d) = \log\Gamma_d(a)\]

where

\[\Gamma_d(a) = \pi^{d(d-1)/4}\prod_{i=1}^d\Gamma(a-(i-1)/2)\]

and \(\Gamma(x)\) is the gamma function.
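Taking logs of the product above gives \(\log\Gamma_d(a) = \frac{d(d-1)}{4}\log\pi + \sum_{i=1}^d \log\Gamma\!\left(a - \frac{i-1}{2}\right)\), which can be sketched directly with math.lgamma. The sketch assumes \(a > (d-1)/2\) so that every gamma argument is positive; the helper name is hypothetical.

```python
import math

def multigammaln(a, d):
    """log of the multivariate gamma function, assuming a > (d - 1) / 2."""
    return d * (d - 1) / 4 * math.log(math.pi) + sum(
        math.lgamma(a - (i - 1) / 2) for i in range(1, d + 1)
    )
```

For d = 1 this reduces to \(\log\Gamma(a)\); for d = 2 and a = 2 it gives \(\log(\pi^{1/2}\,\Gamma(2)\,\Gamma(3/2)) = \log(\pi/2)\).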

Parameters:
Return type:

Array

Returns:

array containing values of the log-multigamma function.

scico.scipy.special.ndtr(x)

Normal distribution function.

JAX implementation of scipy.special.ndtr.

Returns the area under the Gaussian probability density function, integrated from minus infinity to x:

\[\begin{split}\begin{align} \mathrm{ndtr}(x) =& \ \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \\ =&\ \frac{1}{2} (1 + \mathrm{erf}(\frac{x}{\sqrt{2}})) \\ =&\ \frac{1}{2} \mathrm{erfc}(-\frac{x}{\sqrt{2}}) \end{align}\end{split}\]
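The two closed forms \(\frac{1}{2}(1 + \mathrm{erf}(x/\sqrt{2}))\) and \(\frac{1}{2}\mathrm{erfc}(-x/\sqrt{2})\) are equal mathematically but behave differently in float64: the erf form cancels to exactly zero in the lower tail, while the erfc form keeps relative accuracy. A pure-Python sketch using the math module (helper names hypothetical):

```python
import math

def ndtr(x):
    # erfc form: accurate in the lower tail
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def ndtr_erf(x):
    # erf form: 1 + erf(x/sqrt(2)) cancels once erf rounds to -1
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
```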
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – An array of type float32, float64.

Return type:

Array

Returns:

An array with dtype=x.dtype.

Raises:

TypeError – if x is not floating-type.

scico.scipy.special.ndtri(p)

The inverse of the CDF of the Normal distribution function.

JAX implementation of scipy.special.ndtri.

Returns x such that the area under the PDF from \(-\infty\) to x is equal to p.

The function is computed with a piecewise rational approximation based on the implementation in netlib.
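Because ndtr is strictly increasing, its inverse can also be found by simple bisection. The sketch below swaps in bisection as an easy-to-verify reference; it is not the netlib rational approximation described above, it is far slower, and the helper names are hypothetical.

```python
import math

def ndtr(x):
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def ndtri_bisect(p, lo=-40.0, hi=40.0, iters=200):
    """Inverse normal CDF by bisection on ndtr (reference sketch)."""
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie in (0, 1)")
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if ndtr(mid) < p:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)
```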

Parameters:

p (Union[Array, ndarray, bool, number, bool, int, float, complex]) – an array of type float32, float64.

Return type:

Array

Returns:

an array with dtype=p.dtype.

Raises:

TypeError – if p is not floating-type.

scico.scipy.special.owens_t(h, a)

Owen’s T function.

JAX implementation of scipy.special.owens_t.

Computes Owen’s T function:

\[T(h, a) = \frac{1}{2\pi} \int_0^a \frac{\exp\!\left(-\tfrac{1}{2}h^2(1+x^2)\right)}{1+x^2} \, dx\]

Computed via 13-point Gauss-type quadrature on the canonical integral form (Patefield & Tandy 2000 method T5). The full 18-region dispatch from Patefield & Tandy is intentionally avoided because XLA evaluates every branch of a where / select unconditionally, which turns per-region dispatch into added cost rather than savings.
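As a reference check of the integral definition, the sketch below applies composite Simpson's rule to the same integrand. This is a swapped-in technique chosen for simplicity, not the 13-point Gauss-type (T5) rule described above, and the helper name is hypothetical.

```python
import math

def owens_t_simpson(h, a, n=200):
    """Owen's T by composite Simpson's rule on [0, a] (n must be even)."""
    def integrand(x):
        return math.exp(-0.5 * h * h * (1.0 + x * x)) / (1.0 + x * x)

    step = a / n
    total = integrand(0.0) + integrand(a)
    for i in range(1, n):
        # Simpson weights alternate 4, 2, 4, ... on interior points.
        total += (4 if i % 2 else 2) * integrand(i * step)
    return total * step / 3.0 / (2.0 * math.pi)
```

Two standard identities serve as checks: \(T(0, a) = \arctan(a)/(2\pi)\) and \(T(h, 1) = \frac{1}{2}\Phi(h)(1 - \Phi(h))\), where \(\Phi\) is the normal CDF (ndtr).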

Parameters:
Return type:

Array

Returns:

Array of Owen’s T values with dtype matching the promoted inputs.

scico.scipy.special.polygamma(n, x)

The polygamma function.

JAX implementation of scipy.special.polygamma.

\[\mathrm{polygamma}(n, x) = \psi^{(n)}(x) = \frac{\mathrm{d}^{n+1}}{\mathrm{d}x^{n+1}} \log \Gamma(x)\]

where \(\psi\) is the digamma function and \(\Gamma\) is the gamma function.

Parameters:
Return type:

Array

Returns:

array containing values of the polygamma function.

scico.scipy.special.rel_entr(p, q)

The relative entropy function.

JAX implementation of scipy.special.rel_entr.

\[\begin{split} \mathrm{rel\_entr}(p, q) = \begin{cases} p\log(p/q) & p>0,q>0\\ 0 & p=0,q\ge 0\\ \infty & \mathrm{otherwise} \end{cases}\end{split}\]
Parameters:
Return type:

Array

Returns:

array of relative entropy values.

scico.scipy.special.softmax(x, /, *, axis=None)

Softmax function.

JAX implementation of scipy.special.softmax.

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).

\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
Parameters:
Return type:

Array

Returns:

An array of the same shape as x.

Note

If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.
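A one-dimensional pure-Python sketch of the usual stable evaluation subtracts the maximum before exponentiating. Under IEEE float64 arithmetic it also reproduces the behavior in the note: with a +inf input the shift computes inf - inf = nan, which propagates to every entry. Whether the JAX implementation reaches NaN by exactly this path is not claimed here; the helper name is hypothetical.

```python
import math

def softmax(x):
    """softmax of a list of floats, shifted by the maximum for stability."""
    m = max(x)
    e = [math.exp(v - m) for v in x]  # every exponent <= 0, so no overflow
    s = sum(e)
    return [v / s for v in e]

probs = softmax([1.0, 2.0, 3.0])
inf_case = softmax([math.inf, 0.0])  # inf - inf = nan spreads to all entries
```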

See also

log_softmax

scico.scipy.special.spence(x)

Spence’s function, also known as the dilogarithm for real values.

JAX implementation of scipy.special.spence.

It is defined to be:

\[\mathrm{spence}(x) = \int_1^x \frac{\log(t)}{1 - t}\,\mathrm{d}t\]

Unlike the SciPy implementation, this is only defined for positive real values of x. For negative values, NaN is returned.

Parameters:

x – An array of type float32, float64.

Return type:

Array

Returns:

An array with dtype=x.dtype containing the computed values of Spence’s function.

Raises:

TypeError – if the dtype of x is not float32 or float64.

Notes: There is a different convention which defines Spence’s function by the integral:

\[-\int_0^z \frac{\log(1 - t)}{t}\,\mathrm{d}t\]

This is our spence(1 - z).
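The second convention is the dilogarithm \(\mathrm{Li}_2(z) = \sum_{k \ge 1} z^k/k^2\), so spence(x) equals \(\mathrm{Li}_2(1 - x)\). The sketch below uses that series, which converges quickly for \(|1 - x|\) well below 1; it only covers \(0 < x < 2\), and the helper names are hypothetical.

```python
import math

def dilog_series(z, terms=200):
    # Li2(z) = sum_{k>=1} z^k / k^2, convergent for |z| <= 1.
    return sum(z ** k / (k * k) for k in range(1, terms + 1))

def spence_ref(x):
    """spence(x) = Li2(1 - x); this sketch only handles 0 < x < 2."""
    if not 0.0 < x < 2.0:
        raise ValueError("sketch only handles 0 < x < 2")
    return dilog_series(1.0 - x)
```

Known values: spence(1) = 0 and spence(1/2) = \(\mathrm{Li}_2(1/2) = \pi^2/12 - (\log 2)^2/2\).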

scico.scipy.special.sph_harm_y(n, m, theta, phi, diff_n=None, n_max=None)

Computes the spherical harmonics.

The JAX version has one extra argument n_max, the maximum value in n.

The spherical harmonic of degree n and order m can be written as \(Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \theta) * \exp(i m \phi)\), where \(N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!} {4 \pi \left(n+m\right)!}}\) is the normalization factor and \(\theta\) and \(\phi\) are the colatitude and longitude, respectively. \(N_n^m\) is chosen so that the spherical harmonics form an orthonormal basis of \(L^2(S^2)\).
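The orthonormality claim can be checked numerically for the zonal case m = 0, where \(Y_n^0 = N_n^0 P_n(\cos\theta)\) with \(N_n^0 = \sqrt{(2n+1)/(4\pi)}\) and no dependence on \(\phi\). Restricting to m = 0 sidesteps phase conventions (such as the Condon-Shortley phase) that this sketch does not model. Legendre polynomials come from Bonnet's recursion and the \(\theta\) integral uses Simpson's rule; all helper names are hypothetical.

```python
import math

def legendre(n, t):
    """Legendre polynomial P_n(t) via Bonnet's recursion."""
    if n == 0:
        return 1.0
    p_prev, p = 1.0, t
    for k in range(1, n):
        # (k + 1) P_{k+1} = (2k + 1) t P_k - k P_{k-1}
        p_prev, p = p, ((2 * k + 1) * t * p - k * p_prev) / (k + 1)
    return p

def y_n0(n, theta):
    # Zonal harmonic Y_n^0 = N_n^0 * P_n(cos theta)
    return math.sqrt((2 * n + 1) / (4 * math.pi)) * legendre(n, math.cos(theta))

def inner(n1, n2, steps=2000):
    """Integral of Y_n1^0 * Y_n2^0 over the sphere (phi contributes 2*pi)."""
    h = math.pi / steps
    f = lambda th: y_n0(n1, th) * y_n0(n2, th) * math.sin(th)
    s = f(0.0) + f(math.pi)
    for i in range(1, steps):
        s += (4 if i % 2 else 2) * f(i * h)
    return 2 * math.pi * s * h / 3
```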

Parameters:
  • n (Array) – The degree of the harmonic; must have n >= 0. The standard notation for degree in descriptions of spherical harmonics is l (lower case L). We use n here to be consistent with scipy.special.sph_harm_y. Return values for n < 0 are undefined.

  • m (Array) – The order of the harmonic; must have |m| <= n. Return values for |m| > n are undefined.

  • theta (Array) – The polar (colatitudinal) coordinate; must be in [0, pi].

  • phi (Array) – The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].

  • diff_n (int | None) – Unsupported by JAX.

  • n_max (int | None) – The maximum degree max(n). If the supplied n_max is not the true maximum value of n, the results are clipped to n_max. For example, sph_harm_y(jnp.array([10]), jnp.array([2]), theta, phi, n_max=6) actually returns sph_harm_y(jnp.array([6]), jnp.array([2]), theta, phi, n_max=6).

Return type:

Array

Returns:

A 1D array containing the spherical harmonics at (n, m, theta, phi).

scico.scipy.special.xlog1py(x, y)

Compute x*log(1 + y), returning 0 for x=0.

JAX implementation of scipy.special.xlog1py.

This is defined to return 0 when \((x, y) = (0, -1)\), with a custom derivative rule so that automatic differentiation is well-defined at this point.

Parameters:
Return type:

Array

Returns:

array containing xlog1py values.

See also

jax.scipy.special.xlogy

scico.scipy.special.xlogy(x, y)

Compute x*log(y), returning 0 for x=0.

JAX implementation of scipy.special.xlogy.

This is defined to return zero when \((x, y) = (0, 0)\), with a custom derivative rule so that automatic differentiation is well-defined at this point.
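The zero convention is what makes xlogy convenient for entropy-like sums over probability vectors that contain zeros: a literal x*log(y) would evaluate log(0). A scalar pure-Python sketch of the value convention only (the custom derivative rule is not modeled; the helper name is hypothetical):

```python
import math

def xlogy(x, y):
    """x * log(y), returning 0 whenever x == 0 (even for y == 0)."""
    if x == 0:
        return 0.0
    return x * math.log(y)

p = [0.5, 0.5, 0.0]
entropy = -sum(xlogy(v, v) for v in p)  # the zero entry contributes 0
```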

Parameters:
Return type:

Array

Returns:

array containing xlogy values.

See also

jax.scipy.special.xlog1py

scico.scipy.special.zeta(x, q=None)

The Hurwitz zeta function.

JAX implementation of scipy.special.zeta. JAX does not implement the Riemann zeta function, i.e. the case q = None.

\[\zeta(x, q) = \sum_{n=0}^\infty \frac{1}{(n + q)^x}\]
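The defining series converges slowly, so a practical reference evaluation sums the first terms directly and approximates the tail with the Euler-Maclaurin formula (integral, half endpoint term, and first Bernoulli correction). This is an illustrative technique for \(x > 1\) and \(q > 0\), not the library's algorithm; the helper name is hypothetical, and the exponent is called s to avoid clashing with the usual x argument of a summand.

```python
import math

def hurwitz_zeta(s, q, n=100):
    """Hurwitz zeta for s > 1, q > 0: direct sum plus Euler-Maclaurin tail."""
    if s <= 1.0 or q <= 0.0:
        raise ValueError("sketch assumes s > 1 and q > 0")
    head = sum((k + q) ** -s for k in range(n))
    a = n + q
    # sum_{k>=n} (k+q)^-s ~ a^(1-s)/(s-1) + a^-s/2 + s*a^(-s-1)/12
    tail = a ** (1.0 - s) / (s - 1.0) + 0.5 * a ** -s + s * a ** (-s - 1.0) / 12.0
    return head + tail
```

Known values: \(\zeta(2, 1) = \pi^2/6\), \(\zeta(4, 1) = \pi^4/90\), and \(\zeta(2, 1/2) = (2^2 - 1)\,\zeta(2) = \pi^2/2\).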
Parameters:
Return type:

Array

Returns:

array of zeta function values