scico.scipy.special¶
BlockArray-compatible jax.scipy.special
functions.
This module is a wrapper for jax.scipy.special where some
functions have been extended to automatically map over block array
blocks as described in NumPy and SciPy Functions.
Functions
bernoulli(n) | Generate the first N Bernoulli numbers. |
beta(a, b) | The beta function |
betainc(a, b, x) | The regularized incomplete beta function. |
betaln(a, b) | Natural log of the absolute value of the beta function |
digamma(x) | The digamma function |
entr(x) | The entropy function |
erf(x) | The error function |
erfc(x) | The complement of the error function |
erfcx(x) | Scaled complementary error function. |
erfinv(x) | The inverse of the error function |
exp1(x) | Exponential integral function. |
expit(x) | The logistic sigmoid (expit) function |
factorial(n, exact=False) | Factorial function |
gamma(x) | The gamma function. |
gammainc(a, x) | The regularized lower incomplete gamma function. |
gammaincc(a, x) | The regularized upper incomplete gamma function. |
gammaln(x) | Natural log of the absolute value of the gamma function. |
i0(x) | Modified Bessel function of zeroth order. |
i0e(x) | Exponentially scaled modified Bessel function of zeroth order. |
i1(x) | Modified Bessel function of first order. |
i1e(x) | Exponentially scaled modified Bessel function of first order. |
kl_div(p, q) | The Kullback-Leibler divergence. |
log_ndtr(x, series_order=3) | Log Normal distribution function. |
log_softmax(x, /, *, axis=None) | Log-Softmax function. |
loggamma(x) | Principal branch of the logarithm of the gamma function. |
logit(x) | The logit function |
multigammaln(a, d) | The natural log of the multivariate gamma function. |
ndtr(x) | Normal distribution function. |
ndtri(p) | The inverse of the CDF of the Normal distribution function. |
owens_t(h, a) | Owen's T function. |
polygamma(n, x) | The polygamma function. |
rel_entr(p, q) | The relative entropy function. |
softmax(x, /, *, axis=None) | Softmax function. |
spence(x) | Spence's function, also known as the dilogarithm for real values. |
sph_harm_y(n, m, theta, phi, diff_n=None, n_max=None) | Computes the spherical harmonics. |
xlog1py(x, y) | Compute x*log(1 + y), returning 0 for x=0. |
xlogy(x, y) | Compute x*log(y), returning 0 for x=0. |
zeta(x, q=None) | The Hurwitz zeta function. |
- scico.scipy.special.bernoulli(n)¶
Generate the first N Bernoulli numbers.
JAX implementation of scipy.special.bernoulli.
- Parameters:
n (int) – integer, the number of Bernoulli terms to generate.
- Return type:
- Returns:
Array containing the first n Bernoulli numbers.
Notes
bernoulli generates numbers using the \(B_n^-\) convention, such that \(B_1=-1/2\).
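The \(B_n^-\) convention can be illustrated with exact rational arithmetic. The following is a minimal standard-library sketch, not the JAX implementation: the helper name `bernoulli_minus` is hypothetical, and it assumes the textbook recurrence \(\sum_{k=0}^{m}\binom{m+1}{k}B_k = 0\) for \(m \ge 1\) with \(B_0 = 1\).

```python
from fractions import Fraction
from math import comb


def bernoulli_minus(n):
    """Return [B_0, ..., B_n] as exact fractions in the B_n^- convention.

    Hypothetical helper for illustration only; it is not part of scico or JAX.
    """
    numbers = [Fraction(1)]  # B_0 = 1
    for m in range(1, n + 1):
        # Recurrence sum_{k=0}^{m} C(m+1, k) B_k = 0, solved for B_m.
        acc = sum(comb(m + 1, k) * numbers[k] for k in range(m))
        numbers.append(-acc / (m + 1))
    return numbers


values = bernoulli_minus(4)
print(values[1])  # B_1, negative in this convention
```

Under the alternative \(B_n^+\) convention only the sign of \(B_1\) would differ.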
- scico.scipy.special.beta(a, b)¶
The beta function
JAX implementation of scipy.special.beta.
\[\mathrm{beta}(a, b) = B(a, b) = \frac{\Gamma(a)\Gamma(b)}{\Gamma(a + b)}\]
where \(\Gamma\) is the gamma function.
- Parameters:
- Return type:
- Returns:
array containing the values of the beta function.
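The defining identity can be checked numerically for positive real arguments. This is a scalar sketch using only the Python standard library, under the assumption that `math.gamma` and `math.lgamma` stand in for the array-valued functions documented here; the helper names are hypothetical.

```python
import math


def beta_via_gamma(a, b):
    """B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b), for positive real a, b."""
    return math.gamma(a) * math.gamma(b) / math.gamma(a + b)


def betaln_via_lgamma(a, b):
    """log B(a, b) computed in log space, which avoids overflow for large a, b."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


print(beta_via_gamma(2.0, 3.0))               # B(2, 3) = 1! * 2! / 4!
print(math.exp(betaln_via_lgamma(2.0, 3.0)))  # same value via the log form
print(betaln_via_lgamma(200.0, 200.0))        # finite even though Gamma(400) overflows
```

The log-space form is the reason a separate log-beta function is useful: the ratio of gamma functions overflows long before its logarithm does.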
- scico.scipy.special.betainc(a, b, x)¶
The regularized incomplete beta function.
JAX implementation of scipy.special.betainc.
\[\mathrm{betainc}(a, b, x) = \frac{1}{B(a, b)}\int_0^x t^{a-1}(1-t)^{b-1}\mathrm{d}t\]
where \(B(a, b)\) is the beta function.
- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real-valued. Parameter a of the beta distribution.
b (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real-valued. Parameter b of the beta distribution.
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real-valued. Upper limit of the integration.
- Return type:
- Returns:
array containing values of the betainc function
- scico.scipy.special.betaln(a, b)¶
Natural log of the absolute value of the beta function
JAX implementation of scipy.special.betaln.
\[\mathrm{betaln}(a, b) = \log |B(a, b)|\]
where \(B\) is the beta function.
- Parameters:
- Return type:
- Returns:
array containing the values of the log-beta function
- scico.scipy.special.digamma(x)¶
The digamma function
JAX implementation of scipy.special.digamma.
\[\mathrm{digamma}(z) = \psi(z) = \frac{\mathrm{d}}{\mathrm{d}z}\log \Gamma(z)\]
where \(\Gamma(z)\) is the gamma function.
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real-valued.
- Return type:
- Returns:
array containing values of the digamma function.
Notes
The JAX version of digamma only accepts real-valued inputs.
- scico.scipy.special.entr(x)¶
The entropy function
JAX implementation of
scipy.special.entr.\[\begin{split}\mathrm{entr}(x) = \begin{cases} -x\log(x) & x > 0 \\ 0 & x = 0\\ -\infty & \mathrm{otherwise} \end{cases}\end{split}\]
- scico.scipy.special.erf(x)¶
The error function
JAX implementation of scipy.special.erf.
\[\mathrm{erf}(x) = \frac{2}{\sqrt\pi} \int_{0}^x e^{-t^2} \mathrm{d}t\]
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real-valued.
- Return type:
- Returns:
array containing values of the error function.
Notes
The JAX version only supports real-valued inputs.
- scico.scipy.special.erfc(x)¶
The complement of the error function
JAX implementation of scipy.special.erfc.
\[\mathrm{erfc}(x) = \frac{2}{\sqrt\pi} \int_{x}^\infty e^{-t^2} \mathrm{d}t\]
This is the complement of the error function erf: erfc(x) = 1 - erf(x).
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real-valued.
- Return type:
- Returns:
array containing values of the complement of the error function.
Notes
The JAX version only supports real-valued inputs.
- scico.scipy.special.erfcx(x)¶
Scaled complementary error function.
JAX implementation of scipy.special.erfcx.
\[\mathrm{erfcx}(x) = e^{x^2} \mathrm{erfc}(x)\]
This is numerically stable for large positive x, unlike the naive formula which overflows.
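The overflow of the naive formula is easy to reproduce with scalar floats. The sketch below uses only the standard library and compares the naive product against the first few terms of the standard large-x asymptotic expansion of erfc. The helper names are hypothetical, and the expansion is an illustration of why a scaled function is useful, not the algorithm JAX uses.

```python
import math


def erfcx_naive(x):
    """exp(x**2) * erfc(x); breaks down once exp(x**2) overflows."""
    return math.exp(x * x) * math.erfc(x)


def erfcx_asymptotic(x, terms=4):
    """Leading terms of the large-x expansion of erfcx (hypothetical helper).

    erfcx(x) ~ 1 / (x * sqrt(pi)) * sum_n (-1)**n (2n-1)!! / (2 x**2)**n
    """
    total, term = 0.0, 1.0
    for n in range(terms):
        total += term
        # Ratio of consecutive terms is -(2n + 1) / (2 x**2).
        term *= -(2 * n + 1) / (2.0 * x * x)
    return total / (x * math.sqrt(math.pi))


print(erfcx_naive(5.0), erfcx_asymptotic(5.0))  # the two agree closely here

try:
    erfcx_naive(30.0)
except OverflowError:
    # exp(900) is not representable as a double, so the naive product fails.
    print("naive formula overflows at x = 30")
print(erfcx_asymptotic(30.0))
```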
- scico.scipy.special.erfinv(x)¶
The inverse of the error function
JAX implementation of scipy.special.erfinv.
Returns the inverse of erf.
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real-valued.
- Return type:
- Returns:
array containing values of the inverse error function.
Notes
The JAX version only supports real-valued inputs.
- scico.scipy.special.exp1(x)¶
Exponential integral function.
JAX implementation of scipy.special.exp1.
\[\mathrm{exp1}(x) = E_1(x) = \int_x^\infty\frac{e^{-t}}{t}\mathrm{d}t\]
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real-valued
- Return type:
- Returns:
array of exp1 values
See also
jax.scipy.special.expi
jax.scipy.special.expn
- scico.scipy.special.expit(x)¶
The logistic sigmoid (expit) function
JAX implementation of
scipy.special.expit.\[\mathrm{expit}(x) = \frac{1}{1 + e^{-x}}\]
- scico.scipy.special.factorial(n, exact=False)¶
Factorial function
JAX implementation of scipy.special.factorial.
\[\mathrm{factorial}(n) = n! = \prod_{k=1}^n k\]
- Parameters:
- Return type:
- Returns:
array containing values of the factorial.
Notes
This computes the float-valued factorial via the gamma function. JAX does not support exact factorials because they are of limited use: above n=20, the exact result cannot be represented by 64-bit integers, which are the largest integers available to JAX.
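The 64-bit limit mentioned in the note can be verified directly with Python's arbitrary-precision integers. This is a standard-library sketch; `factorial_float` is a hypothetical scalar stand-in for the float-valued computation via the gamma function.

```python
import math

INT64_MAX = 2**63 - 1

# 20! still fits in a signed 64-bit integer, 21! does not.
print(math.factorial(20) <= INT64_MAX)
print(math.factorial(21) <= INT64_MAX)


def factorial_float(n):
    """Float-valued factorial via the gamma function: n! = Gamma(n + 1)."""
    return math.gamma(n + 1.0)


print(factorial_float(5))
```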
- scico.scipy.special.gamma(x)¶
The gamma function.
JAX implementation of scipy.special.gamma.
The gamma function is defined for \(\Re(z)>0\) as
\[\mathrm{gamma}(z) = \Gamma(z) = \int_0^\infty t^{z-1}e^{-t}\mathrm{d}t\]
and is extended by analytic continuation to arbitrary complex values z. For positive integers n, the gamma function is related to the factorial function via the following identity:
\[\Gamma(n) = (n - 1)!\]
For real inputs:
if \(x = -\infty\), NaN is returned.
if \(x = \pm 0\), \(\pm \infty\) is returned.
if \(x\) is a negative integer, NaN is returned, because the sign of gamma at a negative integer depends on the side from which the pole is approached.
if \(x = \infty\), \(\infty\) is returned.
if \(x\) is NaN, NaN is returned.
For complex inputs:
at non-positive integers (poles), nan+nanj is returned, matching SciPy.
if either real or imaginary component is NaN, nan+nanj is returned.
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real or complex valued. Complex inputs use a Lanczos approximation with reflection formula.
- Return type:
- Returns:
array containing the values of the gamma function. For complex inputs, the output is complex-valued.
See also
jax.scipy.special.factorial: the factorial function.
jax.scipy.special.gammaln: the natural log of the gamma function.
jax.scipy.special.gammasgn: the sign of the gamma function.
Notes
For complex inputs, the implementation uses the Lanczos approximation (g=7, N=9 coefficients) with the reflection formula for Re(z) < 0.5.
- scico.scipy.special.gammainc(a, x)¶
The regularized lower incomplete gamma function.
JAX implementation of scipy.special.gammainc.
\[\mathrm{gammainc}(x; a) = \frac{1}{\Gamma(a)}\int_0^x t^{a-1}e^{-t}\mathrm{d}t\]
where \(\Gamma(a)\) is the gamma function.
- Parameters:
- Return type:
- Returns:
array containing values of the gammainc function.
- scico.scipy.special.gammaincc(a, x)¶
The regularized upper incomplete gamma function.
JAX implementation of scipy.special.gammaincc.
\[\mathrm{gammaincc}(x; a) = \frac{1}{\Gamma(a)}\int_x^\infty t^{a-1}e^{-t}\mathrm{d}t\]
where \(\Gamma(a)\) is the gamma function.
- Parameters:
- Return type:
- Returns:
array containing values of the gammaincc function.
- scico.scipy.special.gammaln(x)¶
Natural log of the absolute value of the gamma function.
JAX implementation of scipy.special.gammaln.
\[\mathrm{gammaln}(x) = \log(|\Gamma(x)|)\]
where \(\Gamma\) is the gamma function.
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real valued.
- Return type:
- Returns:
array containing the values of the log-gamma function
See also
jax.scipy.special.gammaln: the natural log of the gamma function
jax.scipy.special.gammasgn: the sign of the gamma function
jax.scipy.special.loggamma: the principal branch of the log-gamma function
Notes
gammaln does not support complex-valued inputs.
- scico.scipy.special.i0(x)¶
Modified Bessel function of zeroth order.
JAX implementation of
scipy.special.i0.\[\mathrm{i0}(x) = I_0(x) = \sum_{k=0}^\infty \frac{(x^2/4)^k}{(k!)^2}\]
- scico.scipy.special.i0e(x)¶
Exponentially scaled modified Bessel function of zeroth order.
JAX implementation of scipy.special.i0e.
\[\mathrm{i0e}(x) = e^{-|x|} I_0(x)\]
where \(I_0(x)\) is the modified Bessel function i0.
- scico.scipy.special.i1(x)¶
Modified Bessel function of first order.
JAX implementation of
scipy.special.i1.\[\mathrm{i1}(x) = I_1(x) = \frac{1}{2}x\sum_{k=0}^\infty\frac{(x^2/4)^k}{k!(k+1)!}\]
- scico.scipy.special.i1e(x)¶
Exponentially scaled modified Bessel function of first order.
JAX implementation of scipy.special.i1e.
\[\mathrm{i1e}(x) = e^{-|x|} I_1(x)\]
where \(I_1(x)\) is the modified Bessel function i1.
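The power series given for i0 and i1, together with the exponential scaling used by i0e and i1e, can be evaluated directly for moderate arguments. This is a scalar standard-library sketch with hypothetical helper names. A truncated series like this is only reasonable for small to moderate |x| and is not how the library evaluates these functions.

```python
import math


def i0_series(x, terms=40):
    """I_0(x) from the power series sum_k (x^2/4)^k / (k!)^2."""
    q = x * x / 4.0
    return sum(q**k / math.factorial(k) ** 2 for k in range(terms))


def i1_series(x, terms=40):
    """I_1(x) from the power series (x/2) * sum_k (x^2/4)^k / (k! (k+1)!)."""
    q = x * x / 4.0
    return 0.5 * x * sum(
        q**k / (math.factorial(k) * math.factorial(k + 1)) for k in range(terms)
    )


def i0e_series(x):
    """Exponentially scaled variant: exp(-|x|) * I_0(x)."""
    return math.exp(-abs(x)) * i0_series(x)


def i1e_series(x):
    """Exponentially scaled variant: exp(-|x|) * I_1(x)."""
    return math.exp(-abs(x)) * i1_series(x)


print(i0_series(1.0), i1_series(1.0))
print(i0e_series(1.0), i1e_series(1.0))
```

The scaled variants stay bounded as |x| grows, whereas \(I_0\) and \(I_1\) themselves grow roughly like \(e^{|x|}\).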
- scico.scipy.special.kl_div(p, q)¶
The Kullback-Leibler divergence.
JAX implementation of
scipy.special.kl_div.\[\begin{split} \mathrm{kl\_div}(p, q) = \begin{cases} p\log(p/q)-p+q & p>0,q>0\\ q & p=0,q\ge 0\\ \infty & \mathrm{otherwise} \end{cases}\end{split}\]
- scico.scipy.special.log_ndtr(x, series_order=3)¶
Log Normal distribution function.
JAX implementation of scipy.special.log_ndtr.
For details of the Normal distribution function see ndtr.
This function calculates \(\log(\mathrm{ndtr}(x))\) by either evaluating \(\log(\mathrm{ndtr}(x))\) directly or using an asymptotic series. Specifically:
For x > upper_segment, use the approximation -ndtr(-x) based on \(\log(1-x) \approx -x, x \ll 1\).
For lower_segment < x <= upper_segment, use the existing ndtr technique and take a log.
For x <= lower_segment, we use the series approximation of erf to compute the log CDF directly.
The lower_segment and upper_segment thresholds are set based on the precision of the input:
\[\begin{split}\begin{align} \mathit{lower\_segment} =& \ \begin{cases} -20 & x.\mathrm{dtype}=\mathit{float64} \\ -10 & x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \\ \mathit{upper\_segment} =& \ \begin{cases} 8& x.\mathrm{dtype}=\mathit{float64} \\ 5& x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \end{align}\end{split}\]
When x < lower_segment, the ndtr asymptotic series approximation is:
\[\begin{split}\begin{align} \mathrm{ndtr}(x) =&\ \mathit{scale} * (1 + \mathit{sum}) + R_N \\ \mathit{scale} =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\ \mathit{sum} =&\ \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\ R_N =&\ O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3}) \end{align}\end{split}\]
where \((2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)\) is a double-factorial operator.
- Parameters:
- Return type:
- Returns:
an array with dtype=x.dtype.
- Raises:
TypeError – if x.dtype is not handled.
TypeError – if series_order is not a Python integer.
ValueError – if series_order is not in [0, 30].
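The benefit of the asymptotic branch for very negative x can be sketched with scalar floats. The code below uses only the standard library. The helper names are hypothetical, the direct form assumes the identity ndtr(x) = erfc(-x/sqrt(2))/2, and the series terms alternate in sign as (-1)^n. It illustrates the idea rather than reproducing the JAX implementation.

```python
import math


def log_ndtr_direct(x):
    """log(ndtr(x)) by taking the log of the CDF; fails once ndtr(x) underflows."""
    return math.log(0.5 * math.erfc(-x / math.sqrt(2.0)))


def log_ndtr_series(x, series_order=3):
    """log(ndtr(x)) for very negative x from the asymptotic series."""
    if x >= 0:
        raise ValueError("the asymptotic series is only meant for very negative x")
    total = 0.0
    double_factorial = 1.0  # (2n - 1)!! built up incrementally
    for n in range(1, series_order + 1):
        double_factorial *= 2 * n - 1
        total += (-1) ** n * double_factorial / (x * x) ** n
    # log(scale) = -x^2 / 2 - log(-x) - log(sqrt(2 * pi))
    log_scale = -0.5 * x * x - math.log(-x) - 0.5 * math.log(2.0 * math.pi)
    return log_scale + math.log1p(total)


print(log_ndtr_direct(-10.0), log_ndtr_series(-10.0))  # both branches agree here

try:
    log_ndtr_direct(-40.0)
except ValueError:
    # ndtr(-40) underflows to 0.0 in double precision, so its log is undefined.
    print("direct evaluation fails at x = -40")
print(log_ndtr_series(-40.0))
```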
- scico.scipy.special.log_softmax(x, /, *, axis=None)¶
Log-Softmax function.
JAX implementation of scipy.special.log_softmax.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters:
- Return type:
- Returns:
An array of the same shape as x.
Note
If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.
- scico.scipy.special.loggamma(x)¶
Principal branch of the logarithm of the gamma function.
JAX implementation of scipy.special.loggamma.
Defined to be \(\log(\Gamma(x))\) for \(x > 0\) and extended to the complex plane by analytic continuation. The function has a single branch cut on the negative real axis.
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – arraylike, real or complex valued.
- Return type:
- Returns:
array containing the values of the loggamma function. For complex inputs, the output is complex-valued.
See also
jax.scipy.special.gamma: the gamma function.
jax.scipy.special.gammaln: the natural log of the absolute value of the gamma function.
- scico.scipy.special.logit(x)¶
The logit function
JAX implementation of
scipy.special.logit.\[\mathrm{logit}(p) = \log\frac{p}{1 - p}\]
- scico.scipy.special.multigammaln(a, d)¶
The natural log of the multivariate gamma function.
JAX implementation of scipy.special.multigammaln.
\[\mathrm{multigammaln}(a, d) = \log\Gamma_d(a)\]
where
\[\Gamma_d(a) = \pi^{d(d-1)/4}\prod_{i=1}^d\Gamma(a-(i-1)/2)\]
and \(\Gamma(x)\) is the gamma function.
- Parameters:
- Return type:
- Returns:
array containing values of the log-multigamma function.
- scico.scipy.special.ndtr(x)¶
Normal distribution function.
JAX implementation of scipy.special.ndtr.
Returns the area under the Gaussian probability density function, integrated from minus infinity to x:
\[\begin{split}\begin{align} \mathrm{ndtr}(x) =& \ \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \\ =&\ \frac{1}{2} (1 + \mathrm{erf}(\frac{x}{\sqrt{2}})) \\ =&\ \frac{1}{2} \mathrm{erfc}(-\frac{x}{\sqrt{2}}) \end{align}\end{split}\]
- scico.scipy.special.ndtri(p)¶
The inverse of the CDF of the Normal distribution function.
JAX implementation of scipy.special.ndtri.
Returns x such that the area under the PDF from \(-\infty\) to x is equal to p.
The function is evaluated with a piecewise rational approximation, based on the implementation in netlib.
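The inverse relationship between the Normal CDF and its inverse can be checked with a round trip. This sketch uses `statistics.NormalDist` from the Python standard library as a scalar stand-in for ndtr and ndtri. It makes no claim about the approximation used by JAX.

```python
from statistics import NormalDist

standard_normal = NormalDist(mu=0.0, sigma=1.0)

p = 0.975
x = standard_normal.inv_cdf(p)   # plays the role of ndtri(p)
print(x)
print(standard_normal.cdf(x))    # plays the role of ndtr(x); recovers p
```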
- scico.scipy.special.owens_t(h, a)¶
Owen’s T function.
JAX implementation of scipy.special.owens_t.
Computes Owen’s T function:
\[T(h, a) = \frac{1}{2\pi} \int_0^a \frac{\exp\!\left(-\tfrac{1}{2}h^2(1+x^2)\right)}{1+x^2} \, dx\]
Computed via 13-point Gauss-type quadrature on the canonical integral form (Patefield & Tandy 2000 method T5). The full 18-region dispatch from Patefield & Tandy is intentionally avoided because XLA evaluates every branch of a where/select unconditionally, which turns per-region dispatch into added cost rather than savings.
- Parameters:
- Return type:
- Returns:
Array of Owen’s T values with dtype matching the promoted inputs.
- scico.scipy.special.polygamma(n, x)¶
The polygamma function.
JAX implementation of scipy.special.polygamma.
\[\mathrm{polygamma}(n, x) = \psi^{(n)}(x) = \frac{\mathrm{d}^{n+1}}{\mathrm{d}x^{n+1}} \log \Gamma(x)\]
where \(\psi\) is the digamma function and \(\Gamma\) is the gamma function.
- Parameters:
- Return type:
- Returns:
array
- scico.scipy.special.rel_entr(p, q)¶
The relative entropy function.
JAX implementation of
scipy.special.rel_entr.\[\begin{split} \mathrm{rel\_entr}(p, q) = \begin{cases} p\log(p/q) & p>0,q>0\\ 0 & p=0,q\ge 0\\ \infty & \mathrm{otherwise} \end{cases}\end{split}\]
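The piecewise definitions of kl_div and rel_entr translate directly into scalar code. The sketch below uses plain Python stand-ins with the same names, operating on single floats rather than arrays, and shows that the two differ only by the term -p + q, which cancels when both inputs are normalized distributions.

```python
import math


def rel_entr(p, q):
    """Scalar relative entropy, following the piecewise definition."""
    if p > 0 and q > 0:
        return p * math.log(p / q)
    if p == 0 and q >= 0:
        return 0.0
    return math.inf


def kl_div(p, q):
    """Scalar Kullback-Leibler term, following the piecewise definition."""
    if p > 0 and q > 0:
        return p * math.log(p / q) - p + q
    if p == 0 and q >= 0:
        return float(q)
    return math.inf


ps = [0.5, 0.5]
qs = [0.9, 0.1]
total_rel_entr = sum(rel_entr(p, q) for p, q in zip(ps, qs))
total_kl_div = sum(kl_div(p, q) for p, q in zip(ps, qs))
print(total_rel_entr, total_kl_div)  # equal because both inputs sum to 1
```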
- scico.scipy.special.softmax(x, /, *, axis=None)¶
Softmax function.
JAX implementation of scipy.special.softmax.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters:
- Return type:
- Returns:
An array of the same shape as x.
Note
If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.
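The normalization and the behavior at +inf can be reproduced with a small list-based sketch. It assumes the common max-shifted formulation, which is an assumption about implementation strategy rather than something stated here, and uses plain Python stand-ins named after the documented functions.

```python
import math


def softmax(xs):
    """Max-shifted softmax over a list of floats (plain-Python sketch)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(xs):
    """Log-softmax computed as x - max - log(sum(exp(x - max)))."""
    m = max(xs)
    log_total = math.log(sum(math.exp(x - m) for x in xs))
    return [x - m - log_total for x in xs]


probs = softmax([1.0, 2.0, 3.0])
print(probs, sum(probs))                   # entries lie in [0, 1] and sum to 1
print(softmax([1000.0, 1001.0]))           # the shift avoids overflow in exp
print(softmax([1.0, float("inf")]))        # inf - inf yields NaN, which propagates
```

In the last call the shift produces inf - inf for the infinite entry, and the resulting NaN contaminates the normalizing sum, so every output is NaN.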
- scico.scipy.special.spence(x)¶
Spence’s function, also known as the dilogarithm for real values.
JAX implementation of scipy.special.spence.
It is defined to be:
\[\mathrm{spence}(x) = \int_1^x \frac{\log(t)}{1 - t}\,\mathrm{d}t\]
Unlike the SciPy implementation, this is only defined for positive real values of x. For negative values, NaN is returned.
- Parameters:
x – An array of type float32 or float64.
- Return type:
- Returns:
An array with dtype=x.dtype containing the computed values of Spence’s function.
- Raises:
TypeError – if elements of array x are not in (float32, float64).
Notes
There is a different convention which defines Spence’s function by the integral:
\[-\int_0^z \frac{\log(1 - t)}{t}\,\mathrm{d}t\]
This is our spence(1 - z).
- scico.scipy.special.sph_harm_y(n, m, theta, phi, diff_n=None, n_max=None)¶
Computes the spherical harmonics.
The JAX version has one extra argument n_max, the maximum value in n.
The spherical harmonic of degree n and order m can be written as \(Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \theta) * \exp(i m \phi)\), where \(N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!} {4 \pi \left(n+m\right)!}}\) is the normalization factor and \(\theta\) and \(\phi\) are the colatitude and longitude, respectively. \(N_n^m\) is chosen so that the spherical harmonics form a set of orthonormal basis functions of \(L^2(S^2)\).
- Parameters:
n (Array) – The degree of the harmonic; must have n >= 0. The standard notation for degree in descriptions of spherical harmonics is l (lower case L). We use n here to be consistent with scipy.special.sph_harm_y. Return values for n < 0 are undefined.
m (Array) – The order of the harmonic; must have |m| <= n. Return values for |m| > n are undefined.
theta (Array) – The polar (colatitudinal) coordinate; must be in [0, pi].
phi (Array) – The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
n_max (int|None) – The maximum degree max(n). If the supplied n_max is not the true maximum value of n, the results are clipped to n_max. For example, sph_harm_y(n=jnp.array([10]), m=jnp.array([2]), theta=theta, phi=phi, n_max=6) actually returns sph_harm_y(n=jnp.array([6]), m=jnp.array([2]), theta=theta, phi=phi, n_max=6).
- Return type:
- Returns:
A 1D array containing the spherical harmonics at (n, m, theta, phi).
- scico.scipy.special.xlog1py(x, y)¶
Compute x*log(1 + y), returning 0 for x=0.
JAX implementation of scipy.special.xlog1py.
This is defined to return 0 when \((x, y) = (0, -1)\), with a custom derivative rule so that automatic differentiation is well-defined at this point.
- Parameters:
- Return type:
- Returns:
array containing xlog1py values.
See also
jax.scipy.special.xlogy
- scico.scipy.special.xlogy(x, y)¶
Compute x*log(y), returning 0 for x=0.
JAX implementation of scipy.special.xlogy.
This is defined to return zero when \((x, y) = (0, 0)\), with a custom derivative rule so that automatic differentiation is well-defined at this point.
- Parameters:
- Return type:
- Returns:
array containing xlogy values.
See also
jax.scipy.special.xlog1py
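The special-casing of x = 0 in xlogy and xlog1py can be sketched for scalars with the standard library. These stand-ins only reproduce the returned values. The custom derivative rule is a feature of the JAX implementation and is not modeled here.

```python
import math


def xlogy(x, y):
    """x * log(y), defined as 0 whenever x == 0 (scalar sketch)."""
    return 0.0 if x == 0 else x * math.log(y)


def xlog1py(x, y):
    """x * log(1 + y), defined as 0 whenever x == 0 (scalar sketch)."""
    return 0.0 if x == 0 else x * math.log1p(y)


print(xlogy(0.0, 0.0))      # the naive 0 * log(0) would raise a domain error
print(xlog1py(0.0, -1.0))   # likewise for 0 * log(1 + (-1))
print(xlogy(2.0, math.e))
```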
- scico.scipy.special.zeta(x, q=None)¶
The Hurwitz zeta function.
JAX implementation of scipy.special.zeta. JAX does not implement the Riemann zeta function (i.e. q = None).
\[\zeta(x, q) = \sum_{n=0}^\infty \frac{1}{(n + q)^x}\]
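The defining series can be approximated for x > 1 and q > 0 by summing a finite number of terms and estimating the remainder. The following standard-library sketch uses a hypothetical helper with a simple Euler-Maclaurin style tail correction. It is meant to illustrate the definition, not to mirror the JAX algorithm.

```python
import math


def hurwitz_zeta(x, q, terms=1000):
    """Approximate sum_{n>=0} (n + q)**(-x) for x > 1 and q > 0."""
    head = sum((n + q) ** (-x) for n in range(terms))
    tail_start = terms + q
    # Tail estimate: integral of (t + q)**(-x) from `terms` to infinity,
    # plus half of the first omitted term.
    tail = tail_start ** (1.0 - x) / (x - 1.0) + 0.5 * tail_start ** (-x)
    return head + tail


print(hurwitz_zeta(2.0, 1.0), math.pi**2 / 6.0)        # q = 1 gives the Riemann zeta value
print(hurwitz_zeta(2.0, 2.0), math.pi**2 / 6.0 - 1.0)  # shifting q by 1 drops the first term
```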