Let’s say we have a complete input image for testing purposes. We first mask it by setting randomly chosen pixels to NaN, and then use the masked image to reconstruct the original image using matrix factorization.
Since our image is in RGB, we perform all operations on the three channels separately.
```python
import torch

# function to mask the image: each pixel is masked independently with probability `prop`
def mask_image(img, prop, device):
    img_copy = img.clone().to(device)
    # uniform distribution; create the mask on the same device as img_copy
    mask = torch.rand(img.shape[1:], device=device) < prop
    # set the masked pixels to NaN in all three channels
    img_copy[:, mask] = float('nan')
    return img_copy, mask

# function that randomly removes exactly 900 pixels from the image
def random_mask_image(img, device):
    img_copy = img.clone().to(device)
    h, w = img.shape[1], img.shape[2]
    total_pixels = h * w
    # Randomly select 900 distinct flat pixel indices
    random_indices = torch.randperm(total_pixels, device=device)[:900]
    # Mark the selected pixels through a flattened view of the (h, w) mask
    mask = torch.zeros(h, w, dtype=torch.bool, device=device)
    mask.view(-1)[random_indices] = True
    # Apply the NaN mask to each channel
    img_copy[:, mask] = float('nan')
    return img_copy, mask
```
Implementing Matrix Factorization
To factorize a matrix \(A_{m \times n}\), we want to learn two matrices \(W_{m \times k}\) and \(H_{k \times n}\) such that \[A \approx W\cdot H\] More precisely, we want \[W, H = \arg\min_{W',H'} \|A-W'\cdot H'\|_F^2\] This can be achieved using two methods: gradient descent and alternating least squares (ALS).
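The rest of this section uses gradient descent. For comparison, here is a minimal sketch of the alternating-least-squares route, assuming the same NaN-masked input. The function name `factorize_als` and the small ridge term `reg` (added so every solve stays well-posed) are assumptions, not part of the original code. The idea: with \(H\) fixed, each row of \(W\) is an ordinary least-squares problem over that row's observed entries, and with \(W\) fixed, each column of \(H\) is one too.

```python
import torch

def factorize_als(A, k, n_iters=50, reg=1e-3, seed=0):
    """Hypothetical masked ALS sketch: alternately solve ridge least-squares
    problems for W and H using only the observed (non-NaN) entries of A."""
    g = torch.Generator().manual_seed(seed)
    m, n = A.shape
    observed = ~torch.isnan(A)
    A0 = torch.nan_to_num(A, nan=0.0)           # NaNs are never read below
    W = torch.randn(m, k, generator=g, dtype=A.dtype)
    H = torch.randn(k, n, generator=g, dtype=A.dtype)
    eye = reg * torch.eye(k, dtype=A.dtype)     # assumed ridge term
    for _ in range(n_iters):
        # Fix H: solve (Hc Hc^T + reg I) w_i = Hc a_i for each row i of W
        for i in range(m):
            cols = observed[i]
            Hc = H[:, cols]                     # k x (#observed in row i)
            W[i] = torch.linalg.solve(Hc @ Hc.T + eye, Hc @ A0[i, cols])
        # Fix W: solve (Wr^T Wr + reg I) h_j = Wr^T a_j for each column j of H
        for j in range(n):
            rows = observed[:, j]
            Wr = W[rows]                        # (#observed in column j) x k
            H[:, j] = torch.linalg.solve(Wr.T @ Wr + eye, Wr.T @ A0[rows, j])
    return W, H
```

Each ALS step solves its subproblem exactly, so there is no learning rate to tune; the cost is one small \(k \times k\) linear solve per row and per column.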
Let us first implement the factorize(A, k, device) function using gradient descent. This function takes as input the matrix to be factorized \(A_{m \times n}\), the rank \(k\) of the decomposition, and the PyTorch device on which to run.
We first randomly initialize \(W_{m\times k}\) and \(H_{k\times n}\). Then, at each iteration of gradient descent, we apply the following update rules, where \(\alpha\) is the learning rate: \[W = W - \alpha \frac{\partial \|A-W\cdot H\|_F^2}{\partial W}\] \[H = H - \alpha \frac{\partial \|A-W\cdot H\|_F^2}{\partial H}\]
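Working out the partial derivatives makes these update rules concrete. Writing \(R = W\cdot H - A\), we get \(\frac{\partial \|A-W\cdot H\|_F^2}{\partial W} = 2RH^\top\) and \(\frac{\partial \|A-W\cdot H\|_F^2}{\partial H} = 2W^\top R\). The sketch below applies these updates literally with plain gradient descent (no optimizer), zeroing \(R\) at NaN entries so that only observed pixels contribute. The name `factorize_gd` and the small 0.1-scaled initialization are assumptions for illustration.

```python
import torch

def factorize_gd(A, k, lr=0.01, n_iters=2000, seed=0):
    """Hypothetical sketch: plain gradient descent on the masked squared
    Frobenius loss, using the closed-form gradients 2 R H^T and 2 W^T R."""
    g = torch.Generator().manual_seed(seed)
    observed = ~torch.isnan(A)
    A0 = torch.nan_to_num(A, nan=0.0)
    W = 0.1 * torch.randn(A.shape[0], k, generator=g, dtype=A.dtype)
    H = 0.1 * torch.randn(k, A.shape[1], generator=g, dtype=A.dtype)
    for _ in range(n_iters):
        R = (W @ H - A0) * observed   # residual, zeroed at missing entries
        grad_W = 2 * R @ H.T          # d loss / dW
        grad_H = 2 * W.T @ R          # d loss / dH
        # both updates use gradients computed from the same residual R
        W = W - lr * grad_W
        H = H - lr * grad_H
    loss = (((W @ H - A0) * observed) ** 2).sum()
    return W, H, loss
```

Because the learning rate is fixed, too large an \(\alpha\) can make this diverge when the entries of \(A\) are large; this is one reason an adaptive optimizer is convenient in practice.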
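In practice, instead of this plain update, we use the torch.optim.Adam optimizer with a learning rate of 0.01, which adapts the step size for each parameter.
Since the input matrix can contain NaN pixels, we use a boolean mask so that the loss is the norm of the residuals at the non-NaN entries only. Without it, a single NaN would make the whole loss, and every gradient, NaN.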
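A tiny example of why the mask matters, using a hypothetical \(2 \times 2\) matrix with one NaN entry and all-ones factors: the unmasked norm is NaN, while the masked norm only sees the three observed residuals \((0, -2, -3)\), giving \(\sqrt{13}\) and finite gradients for \(W\) and \(H\).

```python
import torch

A = torch.tensor([[1.0, float('nan')], [3.0, 4.0]])
W = torch.ones(2, 1, requires_grad=True)
H = torch.ones(1, 2, requires_grad=True)

diff = W @ H - A
naive_loss = torch.norm(diff)          # the NaN entry leaks into the loss
mask = ~torch.isnan(A)
masked_loss = torch.norm(diff[mask])   # only observed entries contribute
masked_loss.backward()

print(naive_loss)    # a NaN loss
print(masked_loss)   # sqrt(13), about 3.6056
print(W.grad)        # finite gradients
print(H.grad)
```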
```python
import torch

def factorize(A, k, device):
    """Factorize the matrix A (m x n, may contain NaNs) into W (m x k) and H (k x n)"""
    A = A.to(device)
    # Randomly initialize W, H
    W = torch.randn(A.shape[0], k, requires_grad=True, device=device)
    H = torch.randn(k, A.shape[1], requires_grad=True, device=device)
    # Optimizer
    optimizer = torch.optim.Adam([W, H], lr=0.01)
    # True where A is observed (not NaN)
    mask = ~torch.isnan(A)
    # Train the model
    for i in range(1000):
        # Compute the loss over the observed entries only
        diff_matrix = torch.mm(W, H) - A
        diff_vector = diff_matrix[mask]  # makes a 1D tensor
        # Frobenius norm of the observed residuals (not squared; same minimizers)
        loss = torch.norm(diff_vector)
        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update the parameters
        optimizer.step()
    return W, H, loss
```
```python
import torch
import torch.nn as nn
from sklearn.kernel_approximation import RBFSampler

def extract_not_nan_coordinates_pixels(image, device):
    # collect the [row, col] coordinates and RGB values of every non-NaN pixel
    channels, height, width = image.shape
    coords = []
    pixel_values = []
    # loop over rows (x) then columns (y) so image[0][x][y] is valid for non-square images
    for x in range(height):
        for y in range(width):
            if image[0][x][y].isnan():
                continue
            coords.append([x, y])
            pixel_values.append(image[:, x, y].tolist())
    coords = torch.tensor(coords, dtype=torch.float32)
    pixel_values = torch.tensor(pixel_values, dtype=torch.float32)
    return coords.to(device), pixel_values.to(device)

def create_linear_model(input_dim, output_dim, device):
    return nn.Linear(input_dim, output_dim).to(device)

def train(coords, pixels, model, learning_rate=0.01, epochs=1000, threshold=1e-6, verbose=True):
    criterion = nn.MSELoss()  # define the loss function (MSE)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # Adam with the specified learning rate
    previous_loss = float('inf')  # initialize with a very large value (for early stopping)
    # training loop
    for epoch in range(epochs):
        optimizer.zero_grad()  # reset the gradients
        outputs = model(coords)  # compute the output
        loss = criterion(outputs, pixels)  # calculate the loss defined above
        loss.backward()  # compute the gradients of the loss with respect to the parameters
        optimizer.step()  # update the parameters based on the gradients computed above
        # check for early stopping
        if abs(previous_loss - loss.item()) < threshold:
            print(f"Stopping early at epoch {epoch} with loss: {loss.item():.6f}")
            break
        previous_loss = loss.item()  # update the previous loss
        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss.item():.6f}")
    return loss.item()

def create_rff_features(tensor, num_features, sigma, device):
    # random Fourier features approximating an RBF kernel with bandwidth sigma
    rff = RBFSampler(n_components=num_features, gamma=1 / (2 * sigma**2), random_state=42)
    tensor = torch.tensor(rff.fit_transform(tensor.cpu().numpy())).float().to(device)
    return tensor

def create_coordinate_map(height, width, device, inv=False):
    # given the height and width of an image, create and return a map of all pixel coordinates
    coords = []
    if not inv:
        # [row, col] pairs in row-major order
        for x in range(height):
            for y in range(width):
                coords.append([x, y])
    else:
        # [col, row] pairs
        for y in range(height):
            for x in range(width):
                coords.append([x, y])
    return torch.tensor(coords, dtype=torch.float32).to(device)
```
Defining metrics for the problem
We use peak_signal_noise_ratio and mean_squared_error from the torchmetrics library as metrics to evaluate the performance of our model.

Consider a single channel of an image \(A_{m \times n}\) and its learned factorization \(A'=W_{m \times k}H_{k \times n}\). The mean squared error (MSE) is defined as \[MSE = \frac{1}{mn}\sum_{i=0}^{m-1} \sum_{j=0}^{n-1} \left(A(i,j)-A'(i,j)\right)^2\] and the root mean squared error we report is \(RMSE = \sqrt{MSE}\). The peak signal-to-noise ratio (PSNR, in dB) is defined as \[PSNR = 10\log_{10}\left(\frac{MAX_A^2}{MSE}\right) = 20\log_{10}\left(MAX_A\right)-10\log_{10}\left(MSE\right)\] Here \(MAX_A\) is the maximum possible pixel value of the channel \(A\). In our case \(MAX_A = 1\), so the first term vanishes and \(PSNR = -10\log_{10}(MSE)\); a higher PSNR means a better reconstruction.
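A small worked example of these formulas, computed directly in PyTorch rather than through torchmetrics (the helper names `mse` and `psnr` are ours, for illustration only): if every pixel of a channel is off by 0.1, then \(MSE = 0.01\), \(RMSE = 0.1\), and with \(MAX_A = 1\), \(PSNR = -10\log_{10}(0.01) = 20\) dB.

```python
import math
import torch

def mse(a, b):
    # mean of squared pixel differences
    return torch.mean((a - b) ** 2).item()

def psnr(a, b, max_val=1.0):
    # 20 log10(MAX) - 10 log10(MSE)
    return 20 * math.log10(max_val) - 10 * math.log10(mse(a, b))

A = torch.full((4, 4), 0.5)
A_rec = A + 0.1                 # every pixel off by 0.1
print(mse(A, A_rec))            # about 0.01
print(math.sqrt(mse(A, A_rec))) # RMSE, about 0.1
print(psnr(A, A_rec))           # about 20.0 dB
```

Halving the error to 0.05 per pixel would quarter the MSE and add about 6 dB of PSNR, since \(10\log_{10} 4 \approx 6.02\).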
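```python
from torchmetrics.functional import mean_squared_error, peak_signal_noise_ratio

# function to compute the RMSE and PSNR
def metrics(img, reconstructed_img):
    # squared=False returns the root mean squared error
    rmse = mean_squared_error(target=img.reshape(-1), preds=reconstructed_img.reshape(-1), squared=False)
    # pass data_range explicitly so that MAX_A = 1, matching the definition above
    psnr = peak_signal_noise_ratio(target=img.reshape(-1), preds=reconstructed_img.reshape(-1), data_range=1.0)
    return rmse, psnr
```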
Testing
We now write a function plot_image_completion() that takes as input:
img: the original image
prop_list: a list of proportions of pixels to mask
factors_list: a list of decomposition ranks \(r\) (the \(k\) passed to factorize) to try
device: PyTorch device
For each proportion prop in prop_list, we first mask that proportion of the image's pixels. Then, for each rank r in factors_list, we reconstruct each of the three channels separately and stack them to obtain the complete reconstructed image.
We also keep track of the RMSE and PSNR metrics for plotting later on.
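A minimal sketch of that loop, assuming the image is a `(3, H, W)` tensor with values in \([0, 1]\). The names `image_completion_metrics` and `_factorize_channel` are hypothetical; the sketch returns a dictionary of metrics instead of plotting, and computes RMSE and PSNR directly rather than through torchmetrics so that it stays self-contained.

```python
import math
import torch

def _factorize_channel(A, k, iters=500, lr=0.01, seed=0):
    # Adam on the norm of the observed residuals, in the spirit of factorize()
    torch.manual_seed(seed)
    W = torch.randn(A.shape[0], k, requires_grad=True)
    H = torch.randn(k, A.shape[1], requires_grad=True)
    opt = torch.optim.Adam([W, H], lr=lr)
    observed = ~torch.isnan(A)
    for _ in range(iters):
        loss = torch.norm((W @ H - A)[observed])
        opt.zero_grad()
        loss.backward()
        opt.step()
    return (W @ H).detach()

def image_completion_metrics(img, prop_list, factors_list, seed=0):
    """Hypothetical sketch of the masking/reconstruction loop:
    returns {(prop, r): (rmse, psnr)} for later plotting."""
    g = torch.Generator().manual_seed(seed)
    results = {}
    for prop in prop_list:
        # mask the same pixels in every channel
        mask = torch.rand(img.shape[1:], generator=g) < prop
        masked = img.clone()
        masked[:, mask] = float('nan')
        for r in factors_list:
            # reconstruct each channel separately, then stack back to (C, H, W)
            channels = [_factorize_channel(masked[c], r) for c in range(img.shape[0])]
            rec = torch.stack(channels)
            mse = torch.mean((img - rec) ** 2).item()
            results[(prop, r)] = (math.sqrt(mse), -10 * math.log10(mse))  # MAX = 1
    return results
```

Keeping the metrics in a dictionary keyed by `(prop, r)` makes it straightforward to plot RMSE or PSNR against the rank for each masking proportion afterwards.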