Friday, December 9, 2016

How to de-noise images in Python

So you tried to take a beautiful picture of a snowstorm last night, but because it was so dark your picture turned out all noisy... What to do now? Don't worry: ask a Swiss! Image de-noising is the process of removing noise from an image, while at the same time preserving details and structures. In the following tutorial, we will implement a simple noise reduction algorithm in Python.


The Theory

Removing noise from images is important for many applications, from making your holiday photos look better to improving the quality of satellite images. One way to do this is by means of the Rudin-Osher-Fatemi (ROF) algorithm, which has the interesting property of finding a smoother version of an image while preserving edges and structures. The ROF algorithm typically produces a "posterized" version of an image with flat domains separated by sharp edges. The degree of posterization can be controlled by changing the trade-off between de-noising and faithfulness to the original image.

The ROF model is also known as the total variation filter. The idea behind the algorithm is that signals with excessive and possibly spurious detail have high total variation; that is, the integral of the absolute gradient of the signal is high.

The total variation (TV) of a (grayscale) image \(I\) is defined as the sum of the gradient norm. In a continuous representation this is:

\[ J(I) = \int |\nabla I| dx \]

In a discrete setting, the total variation becomes:

\[ J(I) = \sum_{\mathbf{x}} |\nabla I|, \]

where the sum is over all image coordinates \(\mathbf{x} = [x,y] \).
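As a small worked example of the discrete definition, the sketch below computes \(J(I)\) with forward differences in NumPy. The function name total_variation and the boundary handling (a zero difference past the last row and column) are our own assumptions; other boundary conventions exist. A clean step edge has a low total variation, while adding noise makes it much larger:

```python
import numpy as np

def total_variation(img):
    """Discrete isotropic total variation: sum over pixels of |grad I|.

    Forward differences; the difference past the last row/column is set
    to zero (an assumed boundary convention).
    """
    img = np.asarray(img, dtype=float)
    gx = np.zeros_like(img)
    gy = np.zeros_like(img)
    gx[:, :-1] = np.diff(img, axis=1)  # I(x+1, y) - I(x, y)
    gy[:-1, :] = np.diff(img, axis=0)  # I(x, y+1) - I(x, y)
    return np.sum(np.sqrt(gx ** 2 + gy ** 2))

# a 32x32 image with a single vertical step edge of height 1
step = np.zeros((32, 32))
step[:, 16:] = 1.0

rng = np.random.default_rng(0)
noisy = step + 0.1 * rng.standard_normal(step.shape)

print(total_variation(step))   # 32.0: one unit jump in each of the 32 rows
print(total_variation(noisy))  # much larger: noise adds gradient everywhere
```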

According to this principle, reducing the total variation of the signal, subject to it remaining a close match to the original signal, removes unwanted detail whilst preserving important details such as edges. In the Chambolle version of ROF, the goal is to find a de-noised image \(U\) that minimizes the following cost function:

\[ \min_U ||I - U||^2 + 2 \lambda J(U), \]

where the norm \(||I - U||\) measures the difference between \(U\) and the original image \(I\). What this means in essence is that the model looks for images that are "flat" but allow "jumps" at edges between regions.

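The result of this filter is an image that strikes a balance between a small total variation and closeness to the initial image, with \(\lambda\) setting the trade-off: the larger \(\lambda\), the smoother and more posterized the result. The total variation is the L1 norm of the gradient of the image.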
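To make the trade-off concrete, here is a sketch that evaluates the cost function above for two candidate images \(U\): the noisy image itself and the clean step it came from. The helper names tv and rof_cost are hypothetical, and tv uses forward differences with a zero difference at the far border (an assumed convention). For \(\lambda = 0\) keeping the noisy image is optimal (cost 0), while for \(\lambda = 1\) the clean step is already far cheaper than keeping the noise:

```python
import numpy as np

def tv(u):
    # forward differences; appending the last column/row makes the
    # difference at the far border zero
    gx = np.diff(u, axis=1, append=u[:, -1:])
    gy = np.diff(u, axis=0, append=u[-1:, :])
    return np.sum(np.hypot(gx, gy))

def rof_cost(I, U, lam):
    """Cost ||I - U||^2 + 2 * lam * J(U) from the equation above."""
    return np.sum((I - U) ** 2) + 2 * lam * tv(U)

# clean step edge plus Gaussian noise
step = np.zeros((32, 32))
step[:, 16:] = 1.0
rng = np.random.default_rng(0)
I = step + 0.1 * rng.standard_normal(step.shape)

for lam in (0.0, 1.0):
    # cost of keeping the noise vs. cost of the clean step
    print(lam, rof_cost(I, I, lam), rof_cost(I, step, lam))
```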

However, in practice we are faced with the problem of actually solving the above minimization, which is not easy: the total variation term is not differentiable wherever the gradient of \(U\) vanishes, which makes plain gradient descent problematic. Luckily, smart mathematicians have come up with what is called a dual formulation, where the primal variable \(U\) is replaced by a dual variable \(P = (p_x, p_y)\), which is chosen in a smart way such that the new optimization problem becomes more amenable to numerical implementation. I'll spare you the math, but the connection to noise is this: the original ROF formulation assumes that the added noise is Gaussian with a known variance, and \(\lambda\) is related to the Lagrange multiplier that enforces this noise level. The full derivation can be found in Chambolle's paper.
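For readers who do want a glimpse of the math, the key result can be summarized as follows. It is written here in the sign convention used by the code below (Chambolle's paper uses the opposite sign for \(P\)), so treat it as a summary rather than a transcription of the paper. The de-noised image is obtained from the dual variable as

\[ U = I + \lambda \operatorname{div} P^\ast, \qquad P^\ast = \arg\min_{|P_{\mathbf{x}}| \le 1} \| I + \lambda \operatorname{div} P \|^2, \]

where the constraint \(|P_{\mathbf{x}}| \le 1\) holds at every pixel \(\mathbf{x}\). The denoise function developed below approximates \(P^\ast\) with a projected fixed-point iteration of step size \(\tau\):

\[ P^{n+1} = \frac{P^n + \frac{\tau}{\lambda} \nabla U^n}{\max\left(1, \left| P^n + \frac{\tau}{\lambda} \nabla U^n \right|\right)}, \qquad U^n = I + \lambda \operatorname{div} P^n. \]

Unlike the primal problem, this one has a smooth (quadratic) objective and a simple pointwise constraint, which is what makes it easy to implement.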


The Code

Many image processing libraries, such as OpenCV and scikit-image, come with a number of built-in de-noising algorithms, such as the total variation filter and the bilateral filter. However, the goal of this post is to get our hands dirty and develop our own implementation.

We will use NumPy for computation, and matplotlib for plotting:

import numpy as np
import matplotlib.pyplot as plt

De-noising shall be performed by a function called denoise, which takes as input arguments a grayscale image (img) and a de-noising weight (weight). Because we refine the de-noised image iteratively, we also need a stop criterion: we are done either when the iterations have converged, that is, when the error changes by less than eps times its initial value from one iteration to the next, or when we have reached a maximum number of iterations (num_iter_max). The function then returns the de-noised image:
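To see how this stop rule behaves on its own, here is a small sketch that replays it on a hypothetical sequence of per-iteration errors and reports how many iterations the loop would run. The function iterations_until_stop and the error values are made up for illustration:

```python
def iterations_until_stop(errors, eps=1e-3, num_iter_max=200):
    """Replay the stop rule on a hypothetical list of per-iteration errors.

    Stops when the error changes by less than eps times the first error,
    or when num_iter_max iterations have run. Returns the iteration count.
    """
    err_init = err_prev = None
    for i, error in enumerate(errors[:num_iter_max]):
        if i == 0:
            err_init = err_prev = error
        elif abs(err_prev - error) < eps * err_init:
            return i + 1  # converged: the change fell below eps * E_0
        else:
            err_prev = error
    return min(len(errors), num_iter_max)

errors = [10.0, 5.0, 2.5, 1.25, 1.249]
print(iterations_until_stop(errors))                  # 5: |1.25 - 1.249| < 1e-3 * 10
print(iterations_until_stop(errors, num_iter_max=3))  # 3: iteration cap reached first
```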

def denoise(img, weight=0.1, eps=1e-3, num_iter_max=200):
    """Perform total-variation denoising on a grayscale image.
    
    Parameters
    ----------
    img : array
        2-D input data to be de-noised.
    weight : float, optional
        Denoising weight. The greater `weight`, the more
        de-noising (at the expense of fidelity to `img`).
    eps : float, optional
        Relative tolerance that determines the stop criterion.
        With E_n the root-mean-square change of the image in
        iteration n, the algorithm stops when:
            |E_(n-1) - E_n| < eps * E_0
    num_iter_max : int, optional
        Maximum number of iterations used for the
        optimization.

    Returns
    -------
    out : array
        De-noised array of floats.
    
    Notes
    -----
    Rudin, Osher and Fatemi algorithm.
    """

The first step is to allocate memory for an initial guess of the clean image (u), which we will improve on iteratively, and for the two components of the dual variable (px, py). We choose a fixed time step (tau), which should be smaller than or equal to 1/8 (Chambolle, 2005). We then iterate until the stop criterion is met or we hit the maximum number of iterations (num_iter_max):

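    img = np.asarray(img, dtype=float)  # work in floating point
    u = np.zeros_like(img)    # primal variable: current estimate of the clean image
    px = np.zeros_like(img)   # x component of the dual variable
    py = np.zeros_like(img)   # y component of the dual variable
    
    nm = np.prod(img.shape[:2])   # number of pixels
    tau = 0.125                   # fixed time step, at most 1/8
    
    i = 0
    while i < num_iter_max:
        u_old = u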

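To compute the gradient of the primal variable, NumPy's roll comes in handy. As the name suggests, the function "rolls" the values of an array cyclically along an axis. This makes it convenient for computing differences between neighboring pixels, in this case forward differences that approximate the derivatives. Because the shift is cyclic, the image is treated as periodic: the difference in the last column (or row) wraps around to the first one.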

        # x and y components of u's gradient
        ux = np.roll(u, -1, axis=1) - u
        uy = np.roll(u, -1, axis=0) - u
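Here is what these two lines do on a tiny 2×3 array (the values are arbitrary). Note how the last column of ux and the last row of uy wrap around to the opposite border, because np.roll shifts cyclically:

```python
import numpy as np

u = np.array([[1.0, 2.0, 4.0],
              [0.0, 3.0, 9.0]])

# forward differences along x (columns) and y (rows), periodic boundary
ux = np.roll(u, -1, axis=1) - u
uy = np.roll(u, -1, axis=0) - u

print(ux)  # values [[1, 2, -3], [3, 6, -9]]; last column wraps to the first
print(uy)  # values [[-1, 1, 5], [1, -1, -5]]; last row wraps to the first
```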

Per Eq. 11 of Chambolle (2005), the dual variable depends on the gradient of the primal variable, which we just computed. Note that weight corresponds to \(\lambda\) in the paper.

        
        # update the dual variable
        px_new = px + (tau / weight) * ux
        py_new = py + (tau / weight) * uy

We then project the dual variable back onto the unit disk, so that at every pixel the vector \((p_x, p_y)\) has length at most 1:

        norm_new = np.maximum(1, np.sqrt(px_new **2 + py_new ** 2))
        px = px_new / norm_new
        py = py_new / norm_new
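To see the effect of dividing by max(1, |p|), consider dual vectors at two pixels: one too long, and one already inside the unit disk. The helper project_unit_disk is a hypothetical name wrapping the normalization step above:

```python
import numpy as np

def project_unit_disk(px, py):
    # vectors longer than 1 are rescaled to length 1; shorter ones are unchanged
    norm = np.maximum(1, np.sqrt(px ** 2 + py ** 2))
    return px / norm, py / norm

px = np.array([3.0, 0.3])
py = np.array([4.0, 0.4])
qx, qy = project_unit_disk(px, py)
# (3, 4) has length 5 and becomes (0.6, 0.8); (0.3, 0.4) has length 0.5 and stays put
print(qx, qy)
```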

Then we calculate the divergence of \(P\) with backward differences, again by using NumPy's roll function, and update our current best guess of a cleaned-up image (u):

        # calculate divergence
        rx = np.roll(px, 1, axis=1)
        ry = np.roll(py, 1, axis=0)
        div_p = (px - rx) + (py - ry)
        
        # update image
        u = img + weight * div_p
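Why backward differences for the divergence? Chambolle's method relies on the divergence being the negative adjoint of the gradient, i.e. \( \langle \nabla u, p \rangle = -\langle u, \operatorname{div} p \rangle \). A quick numerical check with random arrays confirms that the roll-based forward-difference gradient and backward-difference divergence satisfy this exactly with periodic boundaries. The helpers grad and div are hypothetical names wrapping the expressions used in denoise:

```python
import numpy as np

def grad(u):
    # forward differences with periodic boundary, as in denoise()
    return np.roll(u, -1, axis=1) - u, np.roll(u, -1, axis=0) - u

def div(px, py):
    # backward differences with periodic boundary, as in denoise()
    return (px - np.roll(px, 1, axis=1)) + (py - np.roll(py, 1, axis=0))

rng = np.random.default_rng(0)
u = rng.standard_normal((6, 7))
px = rng.standard_normal((6, 7))
py = rng.standard_normal((6, 7))

ux, uy = grad(u)
lhs = np.sum(ux * px + uy * py)   # <grad u, p>
rhs = -np.sum(u * div(px, py))    # -<u, div p>
print(np.isclose(lhs, rhs))       # True
```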

We can measure the progress of the current iteration step by computing the root-mean-square difference between u and u_old from the previous step:

        
        # calculate error
        error = np.linalg.norm(u - u_old) / np.sqrt(nm)
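The error is simply the root-mean-square change per pixel: np.linalg.norm of a 2-D array without an ord argument is the Frobenius norm, and nm is the number of pixels. A quick check on a made-up difference image:

```python
import numpy as np

d = np.array([[3.0, 4.0],
              [0.0, 0.0]])   # hypothetical u - u_old

nm = np.prod(d.shape[:2])
error = np.linalg.norm(d) / np.sqrt(nm)   # as in denoise()
rms = np.sqrt(np.mean(d ** 2))            # root-mean-square change
print(error, rms)  # 2.5 2.5
```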

Finally, we check the stop criterion to determine whether we are done; otherwise, we keep iterating:

        if i == 0:
            # remember the first error as the reference scale
            err_init = error
            err_prev = error
        else:
            # break if the error no longer changes appreciably
            if np.abs(err_prev - error) < eps * err_init:
                break
            else:
                err_prev = error

        # don't forget to update the iteration counter
        i += 1

    return u

In order to test our function, let's create an image with three different grayscale levels, then spoil it with noise:

# create an image with 3 different grayscale levels
img = np.zeros((500,500))
img[100:400,100:400] = 128
img[200:300,200:300] = 255

# add noise to the grayscale values
img = img + 30*np.random.standard_normal((500,500))
img = np.clip(img, 0, 255)

Now we can run our de-noising algorithm on the image. By varying weight, we can control how much of the noise shall be removed:

# plot the noisy image
plt.subplot(141)
plt.imshow(img, cmap='viridis')
plt.title('noisy')

plt.subplot(142)
plt.imshow(denoise(img, weight=1), cmap='viridis')
plt.title('denoising with small weight')

plt.subplot(143)
plt.imshow(denoise(img, weight=10), cmap='viridis')
plt.title('denoising with medium weight')

plt.subplot(144)
plt.imshow(denoise(img, weight=100), cmap='viridis')
plt.title('denoising with strong weight')
plt.show()

The same works for real images, but be sure to convert to grayscale first. Here we will use scikit-image for our image processing needs:

from skimage.io import imread
from skimage.color import rgb2gray

# rgb2gray returns values in [0, 1]; rescale to [0, 255] as in the synthetic example
img = rgb2gray(imread('teddy.jpg')) * 255
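If you want to see what the grayscale conversion does, a minimal NumPy stand-in is a weighted sum of the color channels. The weights below are the luminance coefficients documented for scikit-image's rgb2gray; the function name to_gray is hypothetical, and unlike rgb2gray it assumes a float RGB array in [0, 1] and does not handle alpha channels or integer inputs:

```python
import numpy as np

def to_gray(rgb):
    """Hypothetical stand-in for rgb2gray: expects an (H, W, 3) float array in [0, 1]."""
    weights = np.array([0.2125, 0.7154, 0.0721])  # R, G, B luminance weights
    return np.asarray(rgb, dtype=float) @ weights

pixels = np.array([[[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]])  # white, pure red
print(to_gray(pixels))  # white maps to (about) 1.0, pure red to 0.2125
```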

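We then add Gaussian noise to the image and feed it into the de-noising algorithm:

# add zero-mean Gaussian noise (std: half the image's std) and clip to [0, 255]
noisy = img + 0.5 * img.std() * np.random.standard_normal(img.shape)
noisy = np.clip(noisy, 0, 255)
# weight=10 is the medium weight from the synthetic example; adjust to taste
denoised = denoise(noisy, weight=10)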

Note: the scikit-image library actually comes with de-noising filters straight out of the box, including a total variation filter (skimage.restoration.denoise_tv_chambolle) and a bilateral filter (skimage.restoration.denoise_bilateral).
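For comparison, here is a sketch using scikit-image's built-in total variation filter on a synthetic image with three gray levels, similar to the one above but scaled to [0, 1] (an assumption, since float images in scikit-image are conventionally in that range). Its weight parameter acts on that intensity scale, so its values are not directly interchangeable with the weight of our denoise function. We only pass the image and weight, and check that the error with respect to the clean image goes down:

```python
import numpy as np
from skimage.restoration import denoise_tv_chambolle

# three gray levels in [0, 1], plus zero-mean Gaussian noise
rng = np.random.default_rng(0)
clean = np.full((100, 100), 0.2)
clean[20:80, 20:80] = 0.5
clean[40:60, 40:60] = 0.8
noisy = clean + 0.05 * rng.standard_normal(clean.shape)

denoised = denoise_tv_chambolle(noisy, weight=0.1)

def rmse(a, b):
    # root-mean-square error between two images
    return np.sqrt(np.mean((a - b) ** 2))

print(rmse(noisy, clean), rmse(denoised, clean))
```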

View this code in a Jupyter Notebook hosted on GitHub Gist.