Explore Diffeomorphisms of Space#

Setup & Imports#

 In [2]:
import setup

setup.main()

%load_ext autoreload
%autoreload 2
Working directory:  /Users/facosta/Desktop/code/neurometry/neurometry/neuralwarp
Directory added to path:  /Users/facosta/Desktop/code/neurometry/neurometry
Directory added to path:  /Users/facosta/Desktop/code/neurometry/neurometry/neuralwarp
['/Users/facosta/Desktop/code/neurometry/neurometry/neuralwarp', '/Users/facosta/miniconda3/envs/neurometry/lib/python38.zip', '/Users/facosta/miniconda3/envs/neurometry/lib/python3.8', '/Users/facosta/miniconda3/envs/neurometry/lib/python3.8/lib-dynload', '', '/Users/facosta/miniconda3/envs/neurometry/lib/python3.8/site-packages', '/Users/facosta/Desktop/code/neurometry', '/Users/facosta/Desktop/code/neurometry/neurometry', '/Users/facosta/Desktop/code/neurometry/neurometry/neuralwarp']

Imports#

 In [3]:
import pyLDDMM
from pyLDDMM.utils.visualization import loadimg, saveimg, save_animation, plot_warpgrid
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
import torch
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import geomstats.backend as gs

from geomstats.geometry.pullback_metric import PullbackMetric
INFO: Using pytorch backend

Examples: Circle to Square#

 In [4]:
i0 = loadimg("./example_images/circle.png")
i1 = loadimg("./example_images/square.png")

fig = plt.figure(figsize=(10, 5))
ax1 = fig.add_subplot(121)
ax1.imshow(i0, cmap="gray")
ax1.set_title("Template Image")
ax2 = fig.add_subplot(122)
ax2.imshow(i1, cmap="gray")
ax2.set_title("Target Image");
../_images/notebooks_05_explore_diffeomorphisms_of_space_7_0.png

Algorithm Inputs:#

\(I_0\): image, ndarray of dimension \(H \times W \times n\)

\(I_1\): image, ndarray of dimension \(H \times W \times n\)

\(T\): int, simulated discrete time steps

\(K\): int, maximum iterations

\(\sigma\): float, sigma for \(L_2\) loss. lower values strengthen the \(L_2\) loss

\(\alpha\): float, smoothness regularization. Higher values regularize stronger

\(\gamma\): float, norm penalty. Positive value to ensure injectivity of the regularizer

\(\epsilon\): float, learning rate

Algorithm Outputs:#

\(\phi_0\): forward flow

\(\phi_1\): backward flow

\(J_0\): time-series images generated by forward-pushing \(I_0\) using forward flow

\(J_1\): time-series images generated by pulling-back \(I_1\) using backward flow

\(\text{length}\): length of path on the manifold

\(v\): final velocity field

\(\text{energies}\):

\(\text{im}\): \(\text{im} = J_0[-1]\)

 In [5]:
lddmm = pyLDDMM.LDDMM2D()
im, v, energies, length, Phi0, Phi1, J0, J1 = lddmm.register(
    i0, i1, sigma=0.05, alpha=1, epsilon=0.0001, K=50
)
iteration   0, energy 110400.00, thereof 0.00 regularization and 110400.00 intensity difference
iteration   1, energy 31376.69, thereof 56.71 regularization and 31319.99 intensity difference
iteration   2, energy 12797.85, thereof 68.09 regularization and 12729.77 intensity difference
iteration   3, energy 8087.11, thereof 73.68 regularization and 8013.43 intensity difference
iteration   4, energy 5748.45, thereof 77.48 regularization and 5670.97 intensity difference
iteration   5, energy 4167.99, thereof 79.65 regularization and 4088.34 intensity difference
iteration   6, energy 3341.06, thereof 81.71 regularization and 3259.35 intensity difference
iteration   7, energy 3302.86, thereof 82.79 regularization and 3220.07 intensity difference
iteration   8, energy 3206.30, thereof 84.61 regularization and 3121.69 intensity difference
iteration   9, energy 3263.73, thereof 85.24 regularization and 3178.50 intensity difference
iteration  10, energy 3093.66, thereof 86.96 regularization and 3006.69 intensity difference
iteration  11, energy 3088.58, thereof 87.38 regularization and 3001.20 intensity difference
iteration  12, energy 2909.72, thereof 89.07 regularization and 2820.65 intensity difference
iteration  13, energy 2882.24, thereof 89.41 regularization and 2792.83 intensity difference
iteration  14, energy 2341.25, thereof 91.07 regularization and 2250.18 intensity difference
iteration  15, energy 2642.59, thereof 91.37 regularization and 2551.22 intensity difference
iteration  16, energy 2000.54, thereof 92.95 regularization and 1907.58 intensity difference
iteration  17, energy 2275.22, thereof 93.24 regularization and 2181.99 intensity difference
iteration  18, energy 1803.73, thereof 94.73 regularization and 1709.00 intensity difference
iteration  19, energy 2022.94, thereof 95.03 regularization and 1927.91 intensity difference
iteration  20, energy 1643.14, thereof 96.41 regularization and 1546.72 intensity difference
iteration  21, energy 1851.62, thereof 96.66 regularization and 1754.96 intensity difference
iteration  22, energy 1573.71, thereof 97.93 regularization and 1475.77 intensity difference
iteration  23, energy 1737.14, thereof 98.32 regularization and 1638.82 intensity difference
iteration  24, energy 1508.85, thereof 99.55 regularization and 1409.30 intensity difference
iteration  25, energy 1652.01, thereof 99.97 regularization and 1552.04 intensity difference
iteration  26, energy 1453.80, thereof 101.19 regularization and 1352.60 intensity difference
iteration  27, energy 1581.09, thereof 101.62 regularization and 1479.48 intensity difference
iteration  28, energy 1407.52, thereof 102.81 regularization and 1304.71 intensity difference
iteration  29, energy 1574.88, thereof 103.23 regularization and 1471.65 intensity difference
iteration  30, energy 1261.04, thereof 104.66 regularization and 1156.38 intensity difference
iteration  31, energy 1478.42, thereof 104.77 regularization and 1373.65 intensity difference
iteration  32, energy 1274.33, thereof 106.02 regularization and 1168.31 intensity difference
iteration  33, energy 1411.64, thereof 106.34 regularization and 1305.30 intensity difference
iteration  34, energy 1271.60, thereof 107.45 regularization and 1164.15 intensity difference
iteration  35, energy 1367.43, thereof 107.86 regularization and 1259.57 intensity difference
iteration  36, energy 1259.58, thereof 108.90 regularization and 1150.68 intensity difference
iteration  37, energy 1335.03, thereof 109.33 regularization and 1225.70 intensity difference
iteration  38, energy 1243.86, thereof 110.32 regularization and 1133.54 intensity difference
iteration  39, energy 1415.82, thereof 110.74 regularization and 1305.07 intensity difference
iteration  40, energy 1103.52, thereof 112.05 regularization and 991.47 intensity difference
iteration  41, energy 1306.41, thereof 112.13 regularization and 1194.28 intensity difference
iteration  42, energy 1133.39, thereof 113.22 regularization and 1020.16 intensity difference
iteration  43, energy 1244.86, thereof 113.54 regularization and 1131.32 intensity difference
iteration  44, energy 1147.82, thereof 114.45 regularization and 1033.38 intensity difference
iteration  45, energy 1210.97, thereof 114.87 regularization and 1096.10 intensity difference
iteration  46, energy 1151.12, thereof 115.69 regularization and 1035.43 intensity difference
iteration  47, energy 1191.13, thereof 116.16 regularization and 1074.97 intensity difference
iteration  48, energy 1149.42, thereof 116.92 regularization and 1032.51 intensity difference
iteration  49, energy 1300.43, thereof 117.39 regularization and 1183.04 intensity difference
 In [6]:
plt.imshow(im, cmap="gray");
../_images/notebooks_05_explore_diffeomorphisms_of_space_10_0.png

Visualize velocity vector field over time#

 In [9]:
time_points = v.shape[0]
num_rows = v.shape[1]
num_columns = v.shape[2]

# Create a grid of points
x = np.linspace(0, num_columns - 1, num_columns)
y = np.linspace(0, num_rows - 1, num_rows)
X, Y = np.meshgrid(x, y)

fig, ax = plt.subplots(figsize=(8, 8))

# Create a quiver plot for the initial vector field
Q = ax.quiver(X, Y, v[0, :, :, 0], v[0, :, :, 1], color="red")

ax.set_title("Vector field visualization")
ax.set_xlabel("X")
ax.set_ylabel("Y")


# Update function for the animation
def update(num):
    U = v[num, :, :, 0]
    V = v[num, :, :, 1]

    # Update the data for the quiver plot
    Q.set_UVC(U, V)

    return (Q,)


# Create the animation
ani = FuncAnimation(fig, update, frames=range(time_points), blit=True)

HTML(ani.to_jshtml());
INFO: Animation.save using <class 'matplotlib.animation.HTMLWriter'>
 Out [9]:
../_images/notebooks_05_explore_diffeomorphisms_of_space_12_2.png

Visualize grid distortion#

 In [7]:
fig = plt.figure(figsize=(10, 5))
ax0 = fig.add_subplot(121)
ax0 = plot_warpgrid(Phi0[-1], interval=1, show_axis=True)
ax0.set_title("Forward flow warpgrid after registration")

ax1 = fig.add_subplot(122)
ax1 = plot_warpgrid(Phi1[0], interval=1, show_axis=True)
ax1.set_title("Backward flow warpgrid after registration");
../_images/notebooks_05_explore_diffeomorphisms_of_space_15_0.png

Examples: Distorted Gaussians#

 In [26]:
from scipy.stats import multivariate_normal


def create_gaussian_image(width, height, resolution, x_variance, y_variance, distance):
    x = np.linspace(0, width - 1, int(width / resolution))
    y = np.linspace(0, height - 1, int(height / resolution))
    x, y = np.meshgrid(x, y)

    pos = np.dstack((x, y))

    mu1 = np.array([width / 2 - distance / 2, height / 2])
    cov1 = np.array([[x_variance, 0], [0, y_variance]])  # covariance matrix

    rv1 = multivariate_normal(mu1, cov1)

    mu2 = np.array([width / 2 + distance / 2, height / 2])
    cov2 = np.array([[x_variance, 0], [0, y_variance]])  # covariance matrix

    rv2 = multivariate_normal(mu2, cov2)

    intensity_values1 = rv1.pdf(pos)
    intensity_values2 = rv2.pdf(pos)
    intensity_values = intensity_values1 + intensity_values2
    intensity_values = intensity_values / np.sum(intensity_values)

    return intensity_values


def create_uniform_image(width, height, resolution):
    x = np.linspace(-1, 1, int(width / resolution))
    y = np.linspace(-1, 1, int(height / resolution))
    x, y = np.meshgrid(x, y)

    # Calculate intensity values based on distance from center
    intensity_values = np.ones_like(x)

    intensity_values = intensity_values / np.sum(intensity_values)

    return intensity_values
 In [27]:
image1 = create_gaussian_image(
    width=10, height=10, resolution=0.2, x_variance=0.2, y_variance=0.4, distance=4
)

image2 = create_gaussian_image(
    width=10, height=10, resolution=0.2, x_variance=1, y_variance=0.4, distance=4
)
# image2 = create_uniform_image(width=5,height=5,resolution=0.1)
 In [28]:
# Create a figure with two subplots side by side
fig, axs = plt.subplots(1, 2, figsize=(16, 6))

# Plot the first image on the first subplot
axs[0].imshow(image1, cmap="RdPu")
axs[0].set_title("Image 1")

# Plot the second image on the second subplot
axs[1].imshow(image2, cmap="RdPu")
axs[1].set_title("Image 2")

# Display the plots
plt.show()
../_images/notebooks_05_explore_diffeomorphisms_of_space_19_0.png

Learn deformation with FNN#

 In [29]:
# hyperparameters
train_ratio = 0.8
batch_size = 10
learning_rate = 0.001

# hardware
device = "cuda" if torch.cuda.is_available() else "mps"

Load data#

 In [30]:
x_in = Phi0[0]
x_out = Phi0[-1]

# reshape
input_data = x_in.reshape((-1, 2))
output_data = x_out.reshape((-1, 2))

# shuffle data
idx = np.arange(input_data.shape[0])
np.random.shuffle(idx)
input_data = input_data[idx]
output_data = output_data[idx]

input_data = torch.from_numpy(input_data).float()
output_data = torch.from_numpy(output_data).float()

n_train = int(input_data.shape[0] * train_ratio)

train_dataset = torch.utils.data.TensorDataset(
    input_data[:n_train], output_data[:n_train]
)
val_dataset = torch.utils.data.TensorDataset(
    input_data[n_train:], output_data[n_train:]
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=True
)
 In [31]:
# FCNET with residual connection
class FCNET(torch.nn.Module):
    def __init__(self, input_dim, output_dim, n_layers, n_neurons, activation):
        super(FCNET, self).__init__()
        self.n_layers = n_layers
        self.n_neurons = n_neurons
        self.layers = torch.nn.ModuleList()
        self.activation = activation
        for i in range(n_layers):
            self.layers.append(
                torch.nn.Linear(input_dim if i == 0 else n_neurons, n_neurons)
            )
        self.out = torch.nn.Linear(n_neurons, output_dim)
        self.layers.append(self.out)

    def forward(self, x):
        x_in = x
        for i in range(self.n_layers):
            x = self.layers[i](x)
            x = self.activation(x)
        x = self.out(x)
        # return x
        return x + x_in
 In [32]:
# train the neural network on the training dataset and validate on the validation dataset
def train_and_validate(
    net,
    train_dataloader,
    val_dataloader,
    optimizer,
    criterion,
    n_epochs,
    checkpoint_num,
):
    train_loss = []
    val_loss = []
    for epoch in range(n_epochs):
        for x, y in train_dataloader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = net(x)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
        train_loss.append(loss.item())
        if epoch % checkpoint_num == 0:
            print("Epoch %d, train loss %.4f" % (epoch, loss.item()))
        with torch.no_grad():
            for x, y in val_dataloader:
                x = x.to(device)
                y = y.to(device)
                y_pred = net(x)
                loss = criterion(y_pred, y)
            val_loss.append(loss.item())
            if epoch % checkpoint_num == 0:
                print("Epoch %d, val loss %.4f" % (epoch, loss.item()))
    return train_loss, val_loss
 In [33]:
# Define the model
net = FCNET(
    input_dim=2, output_dim=2, n_layers=3, n_neurons=100, activation=torch.nn.Tanh()
)
net = net.float()
net.to(device)

# Define the optimizer and the loss function
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, amsgrad=True)
criterion = torch.nn.MSELoss()
 In [34]:
n_epochs = 200
checkpoint_num = 4
train_loss, test_loss = train_and_validate(
    net,
    train_dataloader,
    val_dataloader,
    optimizer,
    criterion,
    n_epochs,
    checkpoint_num,
)
Epoch 0, train loss 0.0223
Epoch 0, val loss 0.1029
Epoch 4, train loss 0.0593
Epoch 4, val loss 0.0691
Epoch 8, train loss 0.2307
Epoch 8, val loss 0.5349
Epoch 12, train loss 0.1046
Epoch 12, val loss 0.4948
Epoch 16, train loss 0.0361
Epoch 16, val loss 0.1221
Epoch 20, train loss 0.2845
Epoch 20, val loss 0.1438
Epoch 24, train loss 0.0319
Epoch 24, val loss 0.1421
Epoch 28, train loss 0.0254
Epoch 28, val loss 0.0926
Epoch 32, train loss 0.0366
Epoch 32, val loss 0.1128
Epoch 36, train loss 0.0230
Epoch 36, val loss 0.0238
Epoch 40, train loss 0.0384
Epoch 40, val loss 0.0407
Epoch 44, train loss 0.0745
Epoch 44, val loss 0.5351
Epoch 48, train loss 0.0909
Epoch 48, val loss 0.0523
Epoch 52, train loss 0.0533
Epoch 52, val loss 0.0190
Epoch 56, train loss 0.1869
Epoch 56, val loss 0.0060
Epoch 60, train loss 0.0217
Epoch 60, val loss 0.6439
Epoch 64, train loss 0.0349
Epoch 64, val loss 0.1460
Epoch 68, train loss 0.0447
Epoch 68, val loss 0.1220
Epoch 72, train loss 0.0145
Epoch 72, val loss 0.0481
Epoch 76, train loss 0.0425
Epoch 76, val loss 0.0418
Epoch 80, train loss 0.0025
Epoch 80, val loss 0.0049
Epoch 84, train loss 0.0166
Epoch 84, val loss 0.0066
Epoch 88, train loss 0.0556
Epoch 88, val loss 0.2860
Epoch 92, train loss 0.0196
Epoch 92, val loss 0.2982
Epoch 96, train loss 0.0060
Epoch 96, val loss 0.0026
Epoch 100, train loss 0.0317
Epoch 100, val loss 0.0593
Epoch 104, train loss 0.1573
Epoch 104, val loss 0.0060
Epoch 108, train loss 0.0044
Epoch 108, val loss 0.0084
Epoch 112, train loss 0.0051
Epoch 112, val loss 0.0049
Epoch 116, train loss 0.0147
Epoch 116, val loss 0.0146
Epoch 120, train loss 0.0219
Epoch 120, val loss 0.0851
Epoch 124, train loss 0.0432
Epoch 124, val loss 0.0105
Epoch 128, train loss 0.0255
Epoch 128, val loss 0.0152
Epoch 132, train loss 0.0425
Epoch 132, val loss 0.0412
Epoch 136, train loss 0.2181
Epoch 136, val loss 0.1264
Epoch 140, train loss 0.0485
Epoch 140, val loss 0.0164
Epoch 144, train loss 0.0241
Epoch 144, val loss 0.0230
Epoch 148, train loss 0.0270
Epoch 148, val loss 0.0708
Epoch 152, train loss 0.0170
Epoch 152, val loss 0.0205
Epoch 156, train loss 0.0165
Epoch 156, val loss 0.0056
Epoch 160, train loss 0.0095
Epoch 160, val loss 0.0281
Epoch 164, train loss 0.0150
Epoch 164, val loss 0.2030
Epoch 168, train loss 0.0221
Epoch 168, val loss 0.0075
Epoch 172, train loss 0.0370
Epoch 172, val loss 0.0030
Epoch 176, train loss 0.0428
Epoch 176, val loss 0.0059
Epoch 180, train loss 0.0022
Epoch 180, val loss 0.0257
Epoch 184, train loss 0.0033
Epoch 184, val loss 0.0063
Epoch 188, train loss 0.0239
Epoch 188, val loss 0.0029
Epoch 192, train loss 0.0023
Epoch 192, val loss 0.0041
Epoch 196, train loss 0.1460
Epoch 196, val loss 0.0913
 In [35]:
plt.plot(train_loss, label="train loss")
plt.plot(test_loss, label="test loss")
plt.legend()
plt.show()
../_images/notebooks_05_explore_diffeomorphisms_of_space_28_0.png
 In [36]:
x_in_reshape = x_in.reshape((-1, 2))
x_in_reshape = torch.from_numpy(x_in_reshape).float()
x_in_reshape = x_in_reshape.to(device)
x_out_pred = net(x_in_reshape)
x_out_pred = x_out_pred.cpu().detach().numpy()
x_out_pred = x_out_pred.reshape(x_in.shape)
 In [37]:
fig = plt.figure(figsize=(10, 10))
ax_in = fig.add_subplot(121)

ax_in = plot_warpgrid(x_out, interval=1)
ax_out = fig.add_subplot(122)
ax_out = plot_warpgrid(x_out_pred, interval=1)
../_images/notebooks_05_explore_diffeomorphisms_of_space_30_0.png
 In [22]:
torch.save(net.state_dict(), "model.pt")
 In [21]:
res_net = FCNET(
    input_dim=2, output_dim=2, n_layers=3, n_neurons=100, activation=torch.nn.Tanh()
)
res_net = res_net.float()
res_state_dict = torch.load("model.pt")

for key in res_state_dict.keys():
    res_state_dict[key] = res_state_dict[key].float()


res_net.load_state_dict(res_state_dict)
res_net.to(device)
res_net.eval()
 Out [21]:
FCNET(
  (layers): ModuleList(
    (0): Linear(in_features=2, out_features=100, bias=True)
    (1-2): 2 x Linear(in_features=100, out_features=100, bias=True)
    (3): Linear(in_features=100, out_features=2, bias=True)
  )
  (activation): Tanh()
  (out): Linear(in_features=100, out_features=2, bias=True)
)
 In [33]:
!pip install -q git+https://github.com/johnmarktaylor91/torchlens
 In [ ]:
import torchlens as tl


# a little helper function to tell us where our model is
def get_model_device(model):
    return next(model.parameters()).device


x = next(iter(dataloaders["train"]))[0][:3]  # a sample of our batched training inputs
x = x.to(get_model_device(model))
model_history = tl.log_forward_pass(model, x, vis_opt="unrolled")

Compute pullback metric#

 In [38]:
def get_learned_diffeo(model):
    def diffeo(x):
        # x = x.to(device)
        x = x.float().to(device)
        y = model(x)
        return y

    return diffeo
 In [39]:
diffeo = get_learned_diffeo(res_net)

pullback_metric = PullbackMetric(dim=2, embedding_dim=2, immersion=diffeo)
 In [40]:
x = gs.linspace(0, 64, 64)
y = gs.linspace(0, 64, 64)
x_grid, y_grid = gs.meshgrid(x, y)

# Combine and reshape the x and y coordinates into a list of 2D points
points = gs.vstack((x_grid.ravel(), y_grid.ravel())).T


# Define your function
def volume_element(x, y):
    point = gs.array([x, y])
    g = pullback_metric.metric_matrix(point)
    vol = gs.sqrt(gs.abs(gs.linalg.det(g)))
    return vol


# Apply the function to each point in the list
values = gs.array([volume_element(x, y) for x, y in points])

# Reshape the values back into a 2D grid
z_values = values.reshape(x_grid.shape)

# Create the heatmap
plt.imshow(
    z_values, origin="lower", extent=[x.min(), x.max(), y.min(), y.max()], cmap="RdPu"
)
plt.colorbar(label="Function Value")
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Heatmap of volume element at (x, y)")
plt.show()
../_images/notebooks_05_explore_diffeomorphisms_of_space_38_0.png
 In [ ]: