# Explore Diffeomorphisms of Space

## Setup & Imports

import setup


%load_ext autoreload
%autoreload 2
import pyLDDMM
from pyLDDMM.utils.visualization import loadimg, saveimg, save_animation, plot_warpgrid
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
import torch
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import geomstats.backend as gs

from geomstats.geometry.pullback_metric import PullbackMetric
## Example: Circle to Square

i0 = loadimg("./example_images/circle.png")
i1 = loadimg("./example_images/square.png")

# Show the template (left) and the target (right) side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
for panel, picture, caption in ((ax1, i0, "Template Image"), (ax2, i1, "Target Image")):
    panel.imshow(picture, cmap="gray")
    panel.set_title(caption)

### Algorithm inputs

\(I_0\): image, ndarray of dimension \(H \times W \times n\)

\(I_1\): image, ndarray of dimension \(H \times W \times n\)

\(T\): int, number of discrete time steps used to simulate the flow

\(K\): int, maximum iterations

\(\sigma\): float, scale of the \(L_2\) image-matching loss; lower values strengthen the \(L_2\) loss

\(\alpha\): float, smoothness regularization; higher values regularize more strongly

\(\gamma\): float, norm penalty; must be positive so that the regularization operator is injective

\(\epsilon\): float, learning rate

### Algorithm outputs

\(\phi_0\): forward flow

\(\phi_1\): backward flow

\(J_0\): time-series images generated by forward-pushing \(I_0\) using forward flow

\(J_1\): time-series images generated by pulling-back \(I_1\) using backward flow

\(\text{length}\): length of the path on the manifold

\(v\): final velocity field


\(\text{im}\): final deformed template, \(\text{im} = J_0[-1]\)

# Register the template i0 onto the target i1 and show the deformed template.
lddmm = pyLDDMM.LDDMM2D()
im, v, energies, length, Phi0, Phi1, J0, J1 = lddmm.register(
    i0, i1, sigma=0.05, alpha=1, epsilon=0.0001, K=50
)
plt.imshow(im, cmap="gray");

### Visualize the velocity vector field over time

# v is indexed as v[t, row, col, component] by the code below.
time_points, num_rows, num_columns = v.shape[:3]

# Integer arrow positions: column index along x, row index along y.
x = np.arange(num_columns, dtype=float)
y = np.arange(num_rows, dtype=float)
X, Y = np.meshgrid(x, y)

fig, ax = plt.subplots(figsize=(8, 8))

# Draw the field at the first time step; later frames replace U/V in place.
first_frame = v[0]
Q = ax.quiver(X, Y, first_frame[:, :, 0], first_frame[:, :, 1], color="red")
ax.set_title("Vector field visualization")

# Update function for the animation
def update(num):
    """Animation callback: load the velocity field of frame ``num`` into ``Q``."""
    frame = v[num]
    Q.set_UVC(frame[:, :, 0], frame[:, :, 1])
    # With blit=True the callback must return the artists it modified.
    return (Q,)

# Create the animation. The inline backend does not play a FuncAnimation by
# itself, so render it to JavaScript/HTML explicitly.
ani = FuncAnimation(fig, update, frames=range(time_points), blit=True)
plt.close(fig)  # avoid also emitting the static first frame as a separate image
HTML(ani.to_jshtml())

### Visualize the grid distortion

fig = plt.figure(figsize=(10, 5))

# plot_warpgrid is not given an axes object and presumably draws on the current
# axes, so each subplot is created (and becomes current) before plotting.
fig.add_subplot(121)
ax0 = plot_warpgrid(Phi0[-1], interval=1, show_axis=True)  # last forward-flow entry
ax0.set_title("Forward flow warpgrid after registration")

fig.add_subplot(122)
ax1 = plot_warpgrid(Phi1[0], interval=1, show_axis=True)  # first backward-flow entry
ax1.set_title("Backward flow warpgrid after registration");

## Example: Distorted Gaussians

from scipy.stats import multivariate_normal

def create_gaussian_image(width, height, resolution, x_variance, y_variance, distance):
    """Return a normalised image of two identical axis-aligned Gaussian blobs.

    The blobs sit on the horizontal line y = height / 2, at
    x = width / 2 -/+ distance / 2, and share the covariance
    diag(x_variance, y_variance).

    Args:
        width, height: extent of the sampled domain (same units as ``distance``).
        resolution: sample spacing used to derive the grid size:
            ``width / resolution`` columns and ``height / resolution`` rows,
            truncated to an integer.
        x_variance, y_variance: diagonal entries of the shared covariance.
        distance: horizontal separation between the two blob centres.

    Returns:
        ndarray of shape (rows, columns) whose entries sum to 1.
    """
    # The tiny epsilon keeps float error from dropping a sample
    # (0.6 / 0.1 == 5.999999999999999 would otherwise truncate to 5);
    # genuinely non-integral ratios are still truncated as before.
    n_cols = int(width / resolution + 1e-9)
    n_rows = int(height / resolution + 1e-9)

    # NOTE(review): samples span [0, width - 1] x [0, height - 1] while the
    # blobs are centred on (width / 2, height / 2), i.e. half a unit past the
    # grid centre along both axes - confirm this offset is intended.
    grid_x, grid_y = np.meshgrid(
        np.linspace(0, width - 1, n_cols), np.linspace(0, height - 1, n_rows)
    )
    pos = np.dstack((grid_x, grid_y))

    cov = np.diag([x_variance, y_variance])  # shared by both blobs
    centre_x, centre_y = width / 2, height / 2
    left = multivariate_normal(np.array([centre_x - distance / 2, centre_y]), cov)
    right = multivariate_normal(np.array([centre_x + distance / 2, centre_y]), cov)

    intensity = left.pdf(pos) + right.pdf(pos)
    return intensity / np.sum(intensity)

def create_uniform_image(width, height, resolution):
    """Return a constant image whose entries sum to 1.

    Args:
        width, height: extent of the domain.
        resolution: sample spacing; the image has ``height / resolution`` rows
            and ``width / resolution`` columns, truncated to an integer.

    Returns:
        ndarray of shape (rows, columns) filled with ``1 / (rows * columns)``.
    """
    # The tiny epsilon keeps float error from dropping a sample
    # (0.6 / 0.1 == 5.999999999999999 would otherwise truncate to 5).
    n_cols = int(width / resolution + 1e-9)
    n_rows = int(height / resolution + 1e-9)

    # Every pixel gets the same intensity; only the grid size matters.
    intensity_values = np.ones((n_rows, n_cols))
    return intensity_values / np.sum(intensity_values)
# Two Gaussian pairs that differ only in horizontal spread (x_variance 0.2 vs 1).
image1 = create_gaussian_image(
    width=10, height=10, resolution=0.2, x_variance=0.2, y_variance=0.4, distance=4
)
image2 = create_gaussian_image(
    width=10, height=10, resolution=0.2, x_variance=1, y_variance=0.4, distance=4
)
# Alternative target: image2 = create_uniform_image(width=5, height=5, resolution=0.1)

# Create a figure with two subplots side by side
fig, axs = plt.subplots(1, 2, figsize=(16, 6))

# Plot the first image on the first subplot
axs[0].imshow(image1, cmap="RdPu")
axs[0].set_title("Image 1")

# Plot the second image on the second subplot
axs[1].imshow(image2, cmap="RdPu")
axs[1].set_title("Image 2")

# Display the plots
plt.show()

## Learn the deformation with a fully connected network (FCNET)

# hyperparameters
train_ratio = 0.8
batch_size = 10
learning_rate = 0.001

# hardware: prefer CUDA, then Apple-silicon MPS, and otherwise fall back to the
# CPU (unconditionally choosing "mps" fails on machines without it).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

### Load data

# Regression pairs: each point of Phi0[0] is mapped to its position in Phi0[-1].
x_in = Phi0[0]
x_out = Phi0[-1]

# reshape to one row per grid point
# assumes the last axis of Phi0[t] holds the 2 coordinates - TODO confirm
input_data = x_in.reshape((-1, 2))
output_data = x_out.reshape((-1, 2))

# Shuffle before splitting. Without a permutation the validation set is just the
# last (1 - train_ratio) fraction of the flattened points - a contiguous strip of
# the grid - which measures extrapolation, not held-out fit. Seeded for reproducibility.
rng = np.random.default_rng(0)
idx = rng.permutation(input_data.shape[0])
input_data = input_data[idx]
output_data = output_data[idx]

input_data = torch.from_numpy(input_data).float()
output_data = torch.from_numpy(output_data).float()

n_train = int(input_data.shape[0] * train_ratio)

train_dataset = torch.utils.data.TensorDataset(
    input_data[:n_train], output_data[:n_train]
)
val_dataset = torch.utils.data.TensorDataset(
    input_data[n_train:], output_data[n_train:]
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
# Evaluation order is irrelevant, so the validation loader is not shuffled.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)
# FCNET with residual connection
class FCNET(torch.nn.Module):
    """Fully connected network with a global residual (skip) connection.

    ``forward(x)`` returns ``x + out(act(L_n(... act(L_1(x)))))``. Because of
    the skip connection, ``forward`` only works when ``input_dim == output_dim``.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
        n_layers: number of hidden ``Linear`` layers.
        n_neurons: width of every hidden layer.
        activation: callable applied after every hidden layer
            (e.g. ``torch.nn.Tanh()``).
    """

    def __init__(self, input_dim, output_dim, n_layers, n_neurons, activation):
        super().__init__()
        self.n_layers = n_layers
        self.n_neurons = n_neurons
        # Hidden layers must live in the ModuleList: forward() indexes them, and
        # only registered modules are seen by the optimizer and state_dict().
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(input_dim if i == 0 else n_neurons, n_neurons)
            for i in range(n_layers)
        )
        self.activation = activation
        self.out = torch.nn.Linear(n_neurons, output_dim)

    def forward(self, x):
        residual = x
        for layer in self.layers:
            x = self.activation(layer(x))
        # The network learns the displacement; the skip adds the input back.
        return self.out(x) + residual
# train the neural network on the training dataset and validate on the validation dataset
def train_and_validate(
    net,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    n_epochs,
    checkpoint_num,
    device,
):
    """Train ``net`` for ``n_epochs`` epochs, evaluating after every epoch.

    Args:
        net: module to optimise; it must already live on ``device``.
        train_dataloader, val_dataloader: iterables of ``(x, y)`` batches.
        criterion: loss ``criterion(y_pred, y)`` returning a scalar tensor.
        optimizer: optimizer built over ``net.parameters()``.
        n_epochs: number of passes over ``train_dataloader``.
        checkpoint_num: print losses every ``checkpoint_num`` epochs.
        device: device every batch is moved to.

    Returns:
        ``(train_loss, val_loss)``: lists of per-epoch mean losses, each of
        length ``n_epochs``.
    """
    train_loss = []
    val_loss = []
    for epoch in range(n_epochs):
        # Training pass: zero_grad/backward/step are what update the weights.
        net.train()
        total, count = 0.0, 0
        for x, y in train_dataloader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            loss = criterion(net(x), y)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is exact for a
            # mean-reduced criterion even when the last batch is smaller.
            total += loss.item() * x.shape[0]
            count += x.shape[0]
        train_loss.append(total / max(count, 1))
        if epoch % checkpoint_num == 0:
            print("Epoch %d, train loss %.4f" % (epoch, train_loss[-1]))

        # Validation pass without gradient tracking.
        net.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for x, y in val_dataloader:
                x = x.to(device)
                y = y.to(device)
                loss = criterion(net(x), y)
                total += loss.item() * x.shape[0]
                count += x.shape[0]
        val_loss.append(total / max(count, 1))
        if epoch % checkpoint_num == 0:
            print("Epoch %d, val loss %.4f" % (epoch, val_loss[-1]))
    return train_loss, val_loss


# Define the model on the same device the batches are sent to
net = FCNET(
    input_dim=2, output_dim=2, n_layers=3, n_neurons=100, activation=torch.nn.Tanh()
)
net = net.float().to(device)

# Define the optimizer and the loss function
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, amsgrad=True)
criterion = torch.nn.MSELoss()
n_epochs = 200
checkpoint_num = 4
train_loss, test_loss = train_and_validate(
    net,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    n_epochs=n_epochs,
    checkpoint_num=checkpoint_num,
    device=device,
)
plt.plot(train_loss, label="train loss")
plt.plot(test_loss, label="test loss")
plt.xlabel("epoch")
plt.ylabel("MSE loss")
plt.legend()
# Predict the warped position of every point of x_in with the trained network.
# Inference only: eval mode and no autograd graph over the whole grid.
net.eval()
x_in_reshape = torch.from_numpy(x_in.reshape((-1, 2))).float().to(device)
with torch.no_grad():
    x_out_pred = net(x_in_reshape)
x_out_pred = x_out_pred.cpu().numpy().reshape(x_in.shape)
# Left: LDDMM forward flow (the regression target); right: the network's prediction.
fig = plt.figure(figsize=(10, 10))
fig.add_subplot(121)
ax_in = plot_warpgrid(x_out, interval=1)
ax_in.set_title("LDDMM forward flow (target)")
fig.add_subplot(122)
ax_out = plot_warpgrid(x_out_pred, interval=1)
ax_out.set_title("FCNET prediction")

# Persist the trained weights so they can be reloaded into a fresh model.
torch.save(net.state_dict(), "model.pt")
# Rebuild the architecture and restore the saved weights into it.
res_net = FCNET(
    input_dim=2, output_dim=2, n_layers=3, n_neurons=100, activation=torch.nn.Tanh()
)
res_net = res_net.float()
# map_location="cpu" lets a checkpoint written from a CUDA/MPS run load anywhere.
res_state_dict = torch.load("model.pt", map_location="cpu")

for key in res_state_dict:
    res_state_dict[key] = res_state_dict[key].float()

# Without this call res_net would keep its random initialisation.
res_net.load_state_dict(res_state_dict)
res_net = res_net.to(device).eval()

import torchlens as tl

def get_model_device(model):
    """Return the device holding ``model``'s first parameter.

    Raises StopIteration when the model has no parameters.
    """
    for param in model.parameters():
        return param.device
    raise StopIteration

# Log one forward pass of the trained network on a few training inputs.
x = next(iter(train_dataloader))[0][:3]  # a sample of our batched training inputs
x = x.to(get_model_device(net))
model_history = tl.log_forward_pass(net, x, vis_opt="unrolled")

## Compute the pullback metric

def get_learned_diffeo(model, device=None):
    """Wrap ``model`` as a point map (e.g. for use as an immersion).

    Args:
        model: torch module mapping points to points (here an FCNET with
            2 inputs and 2 outputs).
        device: device on which to evaluate ``model``. Defaults to the device of
            the model's first parameter (CPU for parameter-free models), so the
            inputs always land where the weights are.

    Returns:
        ``diffeo(x)``: evaluates ``model`` on ``x`` cast to float32, and returns
        the result on ``x``'s device (and ``x``'s dtype when ``x`` is floating
        point). All casts are differentiable, so autograd Jacobians with respect
        to ``x`` still work.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    def diffeo(x):
        y = model(x.float().to(device))
        # Hand the result back on the caller's device/dtype, so it can be mixed
        # with the caller's other tensors.
        out_dtype = x.dtype if torch.is_floating_point(x) else y.dtype
        return y.to(device=x.device, dtype=out_dtype)

    return diffeo
 In [39]:
# Metric on the 2-D domain pulled back through the learned map.
diffeo = get_learned_diffeo(res_net)
pullback_metric = PullbackMetric(dim=2, embedding_dim=2, immersion=diffeo)

# Evaluation grid: grid_size evenly spaced samples on [0, grid_size] per axis.
grid_size = 64
x = gs.linspace(0, grid_size, grid_size)
y = gs.linspace(0, grid_size, grid_size)
x_grid, y_grid = gs.meshgrid(x, y)

# Flatten both coordinate grids and pair them up: one (x, y) row per point.
flat_x = x_grid.ravel()
flat_y = y_grid.ravel()
points = gs.vstack((flat_x, flat_y)).T

def volume_element(x, y):
    """Return sqrt(|det g|) of ``pullback_metric``'s matrix at the point (x, y)."""
    metric_mat = pullback_metric.metric_matrix(gs.array([x, y]))
    return gs.sqrt(gs.abs(gs.linalg.det(metric_mat)))

# Evaluate the volume element at every grid point (one metric matrix per point).
values = gs.array([volume_element(px, py) for px, py in points])

# Reshape the values back into a 2D grid
z_values = values.reshape(x_grid.shape)

# Create the heatmap. Matplotlib gets a plain numpy array and float extents
# (the geomstats backend is pytorch here, so these are torch tensors).
# NOTE(review): the extent labelling assumes gs.meshgrid uses "xy" indexing
# (rows follow y); with "ij" indexing the picture is transposed - confirm.
plt.imshow(
    z_values.detach().cpu().numpy(),
    origin="lower",
    extent=[float(x.min()), float(x.max()), float(y.min()), float(y.max())],
    cmap="RdPu",
)
plt.colorbar(label="Function Value")
plt.title("Heatmap of volume element at (x, y)")
 In [ ]: