Convolutional Sparse Coding (ADMM)
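This example demonstrates the solution of a simple convolutional sparse coding problem
\[\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \Big\| \mathbf{y} - \sum_k \mathbf{h}_k \ast \mathbf{x}_k \Big\|_2^2 + \lambda \big( \| \mathbf{x} \|_1 - \| \mathbf{x} \|_2 \big) \;,\]
where the \(\mathbf{h}_k\) are a set of filters comprising the dictionary, the \(\mathbf{x}_k\) are a corresponding set of coefficient maps, \(\mathbf{x}\) denotes all of the \(\mathbf{x}_k\) taken together, and \(\mathbf{y}\) is the signal to be represented. The problem is solved via an ADMM algorithm using the frequency-domain approach proposed in [59].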
[1]:
import numpy as np
import komplot as kplt
import scico.numpy as snp
from scico.examples import create_conv_sparse_phantom
from scico.functional import L1MinusL2Norm
from scico.linop import CircularConvolve, Identity, Sum
from scico.loss import SquaredL2Loss
from scico.optimize.admm import ADMM, FBlockCircularConvolveSolver
from scico.util import device_info
kplt.config_notebook_plotting()
Set the problem size and create a random convolutional dictionary (a set of filters) and a corresponding sparse, random set of coefficient maps.
[2]:
N = 128 # image size
Nnz = 128 # number of non-zeros in coefficient maps
h, x0 = create_conv_sparse_phantom(N, Nnz)
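To make the data layout concrete, the following is a minimal sketch of how a sparse random set of coefficient maps with shape (K, N, N) and exactly Nnz non-zero entries could be generated. The helper name `random_sparse_maps` and the uniformly distributed amplitudes are hypothetical choices for illustration; this is not the implementation of `create_conv_sparse_phantom`, which also constructs the filters.

```python
import numpy as np

def random_sparse_maps(K, N, Nnz, seed=0):
    """Hypothetical helper: K coefficient maps of size N x N with Nnz non-zeros in total."""
    rng = np.random.default_rng(seed)
    x = np.zeros((K, N, N), dtype=np.float32)
    # Pick Nnz distinct flat positions across all K maps (no repeats)
    idx = rng.choice(K * N * N, size=Nnz, replace=False)
    # Strictly positive amplitudes guarantee that every chosen entry is non-zero
    np.put(x, idx, rng.uniform(0.5, 1.5, size=Nnz))
    return x

x = random_sparse_maps(K=3, N=128, Nnz=128)
print(x.shape, np.count_nonzero(x))  # (3, 128, 128) 128
```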
Normalize the dictionary filters to unit \(\ell_2\) norm and scale the coefficient maps by the same factors, so that the synthesized image is unchanged.
[3]:
hnorm = np.sqrt(np.sum(h**2, axis=(1, 2), keepdims=True))
h /= hnorm
x0 *= hnorm
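The rescaling is harmless because convolution is linear in each argument: \((\mathbf{h}_k / c_k) \ast (c_k \mathbf{x}_k) = \mathbf{h}_k \ast \mathbf{x}_k\). The NumPy check below illustrates this; for brevity it assumes 1-D signals and FFT-based circular convolution (the example itself works in 2-D), and the helper `circ_conv` is hypothetical.

```python
import numpy as np

def circ_conv(h, x):
    """1-D circular convolution of h (zero-padded to len(x)) with x, via the FFT."""
    n = x.shape[-1]
    return np.fft.ifft(np.fft.fft(h, n=n) * np.fft.fft(x)).real

rng = np.random.default_rng(0)
h = rng.standard_normal((3, 8))   # 3 filters of length 8
x = rng.standard_normal((3, 64))  # 3 coefficient signals of length 64

y_before = sum(circ_conv(h[k], x[k]) for k in range(3))

# Normalize each filter to unit l2 norm and compensate in the coefficients
hnorm = np.sqrt(np.sum(h**2, axis=1, keepdims=True))
h_n = h / hnorm
x_s = x * hnorm
y_after = sum(circ_conv(h_n[k], x_s[k]) for k in range(3))

print(np.allclose(y_before, y_after))                  # True
print(np.allclose(np.linalg.norm(h_n, axis=1), 1.0))  # True
```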
Convert numpy arrays to jax arrays.
[4]:
h = snp.array(h)
x0 = snp.array(x0)
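Set up the sum-of-convolutions forward operator \(A\mathbf{x} = \sum_k \mathbf{h}_k \ast \mathbf{x}_k\): C applies the circular convolution with each filter, and S sums the results over the filter index.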
[5]:
C = CircularConvolve(h, input_shape=x0.shape, ndims=2)
S = Sum(input_shape=C.output_shape, axis=0)
A = S @ C
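The composition S @ C can be sketched directly with 2-D FFTs in NumPy, together with its adjoint, which the ADMM \(\mathbf{x}\)-update also needs. The function names below are hypothetical, and details such as the placement of the filter origin may differ from SCICO's CircularConvolve, so this illustrates the structure of the operator rather than reproducing SCICO's output. A dot-product test checks that the adjoint is consistent with the forward operator.

```python
import numpy as np

def forward(hf, x):
    """A x = sum_k h_k (*) x_k, where hf holds the 2-D DFTs of the zero-padded filters."""
    return np.fft.ifft2(np.sum(hf * np.fft.fft2(x), axis=0)).real

def adjoint(hf, y):
    """A^T y: circular correlation of y with each filter, giving one map per filter."""
    return np.fft.ifft2(np.conj(hf) * np.fft.fft2(y)[np.newaxis]).real

rng = np.random.default_rng(0)
K, M, N = 3, 9, 32
h = rng.standard_normal((K, M, M))
hf = np.fft.fft2(h, s=(N, N))  # zero-pad each filter to N x N, then take its 2-D DFT
x = rng.standard_normal((K, N, N))
y = rng.standard_normal((N, N))

# Dot-product test: <A x, y> should equal <x, A^T y>
lhs = np.vdot(forward(hf, x), y)
rhs = np.vdot(x, adjoint(hf, y))
print(np.isclose(lhs, rhs))  # True
```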
Construct test image from dictionary \(\mathbf{h}\) and coefficient maps \(\mathbf{x}_0\).
[6]:
y = A(x0)
Set functional and solver parameters.
[7]:
λ = 1e0 # ℓ1-ℓ2 norm regularization parameter
ρ = 2e0 # ADMM penalty parameter
maxiter = 200 # number of ADMM iterations
Define the loss function and regularization term. Note the use of the \(\ell_1 - \ell_2\) penalty (the difference between the \(\ell_1\) and \(\ell_2\) norms), which has been found to give slightly better performance than the \(\ell_1\) norm for this type of problem [60].
[8]:
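f = SquaredL2Loss(y=y, A=A)  # data fidelity term (1/2) ||y - A x||_2^2
g0 = λ * L1MinusL2Norm()  # regularization term λ (||x||_1 - ||x||_2)
C0 = Identity(input_shape=x0.shape)  # g0 is applied directly to x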
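The \(\ell_1 - \ell_2\) penalty is zero for any vector with at most one non-zero entry, whatever its amplitude, and for a fixed \(\ell_2\) norm it grows as the energy is spread over more entries. The short computation below illustrates this; the helper `l1_minus_l2` is hypothetical and corresponds to L1MinusL2Norm only under the assumption that the \(\ell_2\) term has unit weight, as in the default construction used above.

```python
import numpy as np

def l1_minus_l2(x):
    """Hypothetical helper: ||x||_1 - ||x||_2 computed over all entries of x."""
    return np.sum(np.abs(x)) - np.linalg.norm(x.ravel())

one_sparse = np.array([3.0, 0.0, 0.0, 0.0])
spread = np.array([1.0, 1.0, 1.0, 1.0])

print(l1_minus_l2(one_sparse))  # 0.0  (||x||_1 = ||x||_2 = 3)
print(l1_minus_l2(spread))      # 2.0  (||x||_1 = 4, ||x||_2 = 2)
```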
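Initialize the ADMM solver, using FBlockCircularConvolveSolver to solve the \(\mathbf{x}\)-subproblem in the frequency domain and an over-relaxation parameter \(\alpha = 1.8\).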
[9]:
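solver = ADMM(
    f=f,
    g_list=[g0],
    C_list=[C0],
    rho_list=[ρ],
    alpha=1.8,  # over-relaxation parameter
    maxiter=maxiter,
    # frequency-domain solver for the x-subproblem; check_solve=True reports its residual
    subproblem_solver=FBlockCircularConvolveSolver(check_solve=True),
    itstat_options={"display": True, "period": 10},
)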
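For reference, with \(C_0 = I\) the over-relaxed ADMM iterations for this problem can be written in the standard scaled form below, where \(\mathbf{z}\) is the auxiliary variable, \(\mathbf{u}\) the scaled dual variable, \(\rho\) the penalty parameter, and \(\alpha\) the over-relaxation parameter (alpha=1.8 in the solver construction). This is the textbook form of the algorithm; SCICO's internal ordering or scaling may differ in detail. The \(\mathbf{x}\)-update is the linear system handled by the frequency-domain subproblem solver.
\[
\begin{aligned}
\mathbf{x}^{(j+1)} &= \big( A^T A + \rho I \big)^{-1} \big( A^T \mathbf{y} + \rho \, (\mathbf{z}^{(j)} - \mathbf{u}^{(j)}) \big) \\
\mathbf{v}^{(j+1)} &= \alpha \, \mathbf{x}^{(j+1)} + (1 - \alpha) \, \mathbf{z}^{(j)} \\
\mathbf{z}^{(j+1)} &= \mathrm{prox}_{\lambda (\| \cdot \|_1 - \| \cdot \|_2) / \rho} \big( \mathbf{v}^{(j+1)} + \mathbf{u}^{(j)} \big) \\
\mathbf{u}^{(j+1)} &= \mathbf{u}^{(j)} + \mathbf{v}^{(j+1)} - \mathbf{z}^{(j+1)}
\end{aligned}
\]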
Run the solver.
[10]:
print(f"Solving on {device_info()}\n")
x1 = solver.solve()
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
Iter Time Objective Prml Rsdl Dual Rsdl Slv Res
----------------------------------------------------------
0 2.45e+00 2.107e+03 3.851e+01 5.291e+01 6.771e-06
10 3.74e+00 2.862e+03 4.623e+00 9.837e+00 7.471e-06
20 3.85e+00 2.702e+03 2.097e+00 6.156e+00 8.159e-06
30 3.97e+00 2.626e+03 1.528e+00 4.603e+00 1.060e-05
40 4.08e+00 2.578e+03 1.222e+00 3.719e+00 1.312e-05
50 4.20e+00 2.542e+03 1.124e+00 3.478e+00 1.011e-05
60 4.33e+00 2.511e+03 1.053e+00 3.260e+00 9.227e-06
70 4.46e+00 2.486e+03 9.754e-01 3.012e+00 1.460e-05
80 4.59e+00 2.464e+03 8.965e-01 2.773e+00 9.706e-06
90 4.73e+00 2.444e+03 8.411e-01 2.613e+00 1.212e-05
100 4.87e+00 2.426e+03 8.010e-01 2.495e+00 1.139e-05
110 4.99e+00 2.411e+03 7.589e-01 2.367e+00 6.826e-06
120 5.10e+00 2.398e+03 7.139e-01 2.224e+00 7.160e-06
130 5.21e+00 2.386e+03 6.679e-01 2.076e+00 4.309e-06
140 5.33e+00 2.377e+03 6.111e-01 1.881e+00 1.699e-05
150 5.43e+00 2.368e+03 5.547e-01 1.717e+00 1.301e-05
160 5.54e+00 2.361e+03 5.137e-01 1.595e+00 1.007e-05
170 5.64e+00 2.356e+03 4.663e-01 1.439e+00 1.254e-05
180 5.74e+00 2.352e+03 4.245e-01 1.301e+00 6.956e-06
190 5.86e+00 2.348e+03 3.735e-01 1.135e+00 1.350e-05
199 5.96e+00 2.347e+03 3.066e-01 8.973e-01 7.575e-06
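The \(\mathbf{x}\)-update requires solving \((A^T A + \rho I)\mathbf{x} = \mathbf{b}\). In the DFT domain, \(A\) acts at each frequency as a single row vector with one entry per filter, so each per-frequency system is a rank-one update of \(\rho I\) and can be solved in closed form with the Sherman–Morrison formula, which is what makes a frequency-domain solution cheap. The NumPy sketch below works in 1-D for brevity and assumes this structure; it is an illustration, not a claim about how FBlockCircularConvolveSolver is implemented internally. The helper names are hypothetical. The final check mirrors, in spirit, the "Slv Res" column in the log, which reports the linear-solve residual enabled by check_solve=True.

```python
import numpy as np

def solve_x_subproblem(hf, b, rho):
    """Solve (A^T A + rho I) x = b for A x = sum_k h_k (*) x_k (1-D, circular).

    hf: (K, N) DFTs of the zero-padded filters; b: (K, N) real right-hand side.
    """
    bf = np.fft.fft(b, axis=-1)
    # At frequency m, A is the row vector a_m = hf[:, m], so A^H A = conj(a_m) a_m^T
    atb = np.sum(hf * bf, axis=0)          # a_m^T b_m for every frequency m
    ata = np.sum(np.abs(hf) ** 2, axis=0)  # a_m^T conj(a_m) = sum_k |hf[k, m]|^2
    # Sherman-Morrison: (rho I + u v^T)^{-1} b = (b - u (v^T b) / (rho + v^T u)) / rho
    xf = (bf - np.conj(hf) * atb / (rho + ata)) / rho
    return np.fft.ifft(xf, axis=-1).real

def normal_op(hf, x, rho):
    """Apply (A^T A + rho I) to x, used to check the solution."""
    Ax_f = np.sum(hf * np.fft.fft(x, axis=-1), axis=0)
    return np.fft.ifft(np.conj(hf) * Ax_f, axis=-1).real + rho * x

rng = np.random.default_rng(0)
K, L, N, rho = 3, 8, 64, 2.0
hf = np.fft.fft(rng.standard_normal((K, L)), n=N, axis=-1)  # zero-pad filters to length N
b = rng.standard_normal((K, N))
x = solve_x_subproblem(hf, b, rho)

# Relative residual of the linear solve
res = np.linalg.norm(normal_op(hf, x, rho) - b) / np.linalg.norm(b)
print(res < 1e-10)  # True
```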
Show the recovered coefficient maps.
[11]:
fig, ax = kplt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(12, 8.6))
kplt.imview(x0[0], title="Coef. map 0", cmap=kplt.cm.Blues, ax=ax[0, 0])
ax[0, 0].set_ylabel("Ground truth")
kplt.imview(x0[1], title="Coef. map 1", cmap=kplt.cm.Blues, ax=ax[0, 1])
kplt.imview(x0[2], title="Coef. map 2", cmap=kplt.cm.Blues, ax=ax[0, 2])
kplt.imview(x1[0], cmap=kplt.cm.Blues, ax=ax[1, 0])
ax[1, 0].set_ylabel("Recovered")
kplt.imview(x1[1], cmap=kplt.cm.Blues, ax=ax[1, 1])
kplt.imview(x1[2], cmap=kplt.cm.Blues, ax=ax[1, 2])
fig.tight_layout()
fig.show()
Show test image and reconstruction from recovered coefficient maps.
[12]:
fig, ax = kplt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(12, 6))
kplt.imview(y, title="Test image", cmap=kplt.cm.gist_heat_r, ax=ax[0])
kplt.imview(A(x1), title="Reconstructed image", cmap=kplt.cm.gist_heat_r, ax=ax[1])
fig.show()
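To complement the visual comparison with a single number, one could compute the signal-to-noise ratio (SNR) of the reconstruction \(A\mathbf{x}_1\) with respect to \(\mathbf{y}\). The helper `snr_db` below is a hypothetical plain-NumPy definition (SNR in dB relative to the variance of the reference), shown on a toy signal; with the arrays of this example it could be applied as `snr_db(y, A(x1))`.

```python
import numpy as np

def snr_db(reference, estimate):
    """SNR in dB: 10 log10(var(reference) / mean squared error)."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    mse = np.mean((reference - estimate) ** 2)
    return 10.0 * np.log10(np.var(reference) / mse)

ref = np.array([1.0, -1.0, 1.0, -1.0])           # variance 1
est = ref + np.array([0.1, -0.1, 0.1, -0.1])     # mean squared error 0.01
print(round(snr_db(ref, est), 6))  # 20.0
```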
Plot convergence statistics.
[13]:
fig, ax = kplt.subplots(nrows=1, ncols=2, figsize=(12, 5))
kplt.plot(
    hist.Objective,
    title="Objective function",
    xlabel="Iteration",
    ylabel="Functional value",
    ax=ax[0],
)
kplt.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ylog=True,
    title="Residuals",
    xlabel="Iteration",
    legend=("Primal", "Dual"),
    ax=ax[1],
)
fig.show()
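The right-hand panel shows the primal and dual residuals. In the standard ADMM formulation with \(C_0 = I\), where \(\mathbf{z}\) is the auxiliary variable that stands in for \(\mathbf{x}\), these are commonly defined as below, with their \(\ell_2\) norms being the quantities tracked; SCICO's exact definitions (for example, any normalization or the treatment of over-relaxation) may differ, so this is a guide to reading the plot rather than a specification. Both tend to zero as the iterations converge; in the run above they fall from 3.851e+01 and 5.291e+01 at iteration 0 to 3.066e-01 and 8.973e-01 at iteration 199.
\[
\mathbf{r}^{(j+1)} = \mathbf{x}^{(j+1)} - \mathbf{z}^{(j+1)} \;, \qquad
\mathbf{s}^{(j+1)} = \rho \, \big( \mathbf{z}^{(j+1)} - \mathbf{z}^{(j)} \big)
\]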