LDDMM: parallel transport

Imagine we are given three meshes A, B, C.

In this notebook we will:

  • register A to B and C to B, keeping B fixed in both registrations

  • parallel transport the deformation BC (from B to C) along BA (from B to A), and apply the transported deformation to A

[1]:
import shutil
import string

import numpy as np
import pyvista as pv

import polpo.utils as putils
from polpo.mesh.deformetrica import LddmmMetric, Point
from polpo.mesh.generation.blob import create_blob
from polpo.plot.pyvista import RegisteredMeshesGifPlotter
from polpo.preprocessing.mesh.registration import RigidAlignment
[KeOps] Warning : CUDA was detected, but driver API could not be initialized. Switching to CPU only.
[2]:
# Set RECOMPUTE to True to wipe OUTPUTS_DIR and redo every computation.
RECOMPUTE = False
# Render PyVista figures as static images instead of interactive widgets.
STATIC_VIZ = True

# With RECOMPUTE = False, previously written outputs are kept (presumably
# reused), so the randomly generated meshes must be reproducible: keep this
# seed fixed before disabling recomputation. Assumes create_blob draws from
# NumPy's global RNG — TODO confirm.
np.random.seed(42)

if STATIC_VIZ:
    pv.set_jupyter_backend("static")
[3]:
OUTPUTS_DIR = putils.get_results_path() / "transport_abc_blob_example"

# Start from a clean results directory only when recomputation is requested.
if RECOMPUTE and OUTPUTS_DIR.exists():
    shutil.rmtree(OUTPUTS_DIR)

Generate meshes

[4]:
n_meshes = 3
# Bump amplitude of the blobs; also reused below to scale the registration kernels.
bump_amp = 0.2

raw_meshes = [
    create_blob(resolution=10, bump_amp=bump_amp, n_bumps=5, smoothing_iter=10)
    for _ in range(n_meshes)
]

# The rigid alignment below is run with known_correspondences=True, which
# needs vertex-to-vertex correspondence: fail early if the blobs differ in size.
assert len({mesh.points.shape for mesh in raw_meshes}) == 1, (
    "generated meshes do not share the same vertex array shape"
)

raw_meshes[0].points.shape
[4]:
(82, 3)
[5]:
# Overlay the raw (not yet aligned) meshes; transparency keeps overlaps visible.
pl = pv.Plotter(border=False)
for raw_mesh in raw_meshes:
    pl.add_mesh(raw_mesh, show_edges=True, opacity=0.5)
pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_transport_abc_7_0.png
[6]:
# Rigidly align the meshes to each other. known_correspondences=True presumably
# means vertex i of every mesh is matched with vertex i of the others.
prep_pipe = RigidAlignment(known_correspondences=True)
meshes = prep_pipe(raw_meshes)
[7]:
# Same overlay as before, now after rigid alignment.
pl = pv.Plotter(border=False)
for aligned_mesh in meshes:
    pl.add_mesh(aligned_mesh, show_edges=True, opacity=0.5)
pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_transport_abc_9_0.png

Registrations

[8]:
# Kernel widths are tied to the bump amplitude of the generated blobs
# (presumably: kernel_width drives the deformation kernel and
# attachment_kernel_width the varifold data-attachment kernel — confirm in LddmmMetric).
kernel_width = 2 * bump_amp
registration_kwargs = {
    "kernel_width": kernel_width,
    "regularisation": 1.0,
    "max_iter": 2000,
    "freeze_control_points": False,
    "metric": "varifold",
    # NOTE(review): a tolerance this small likely makes the optimizer run
    # until max_iter — confirm against the backend's stopping criterion.
    "tol": 1e-16,
    "attachment_kernel_width": bump_amp,
}

# Start with the pole-ladder transport scheme; the fanning section turns it off.
metric = LddmmMetric(OUTPUTS_DIR, use_pole_ladder=True, **registration_kwargs)
[9]:
# Wrap each aligned mesh in a Point labelled "A", "B", "C" (in mesh order).
point_a, point_b, point_c = [
    Point(id_=letter, pv_surface=mesh, dirname=metric.meshes_dir)
    for letter, mesh in zip(string.ascii_uppercase, meshes)
]

Closely following the notebook "LDDMM: how to register a mesh to a template?", we register A to B and C to B, with B as the fixed mesh.

[10]:
# Register A to B and C to B. Each call presumably follows the
# `log(point, base_point)` convention, returning a tangent vector based at B
# that points to A (resp. C) — rendered below as "B -> A" / "B -> C".
vec_ba = metric.log(point_a, point_b)
vec_bc = metric.log(point_c, point_b)
[11]:
# Animate the flow of each registration, titled "<base> -> <target>".
for tangent_vec in (vec_ba, vec_bc):
    pl = RegisteredMeshesGifPlotter()
    pl.add_title(f"{tangent_vec.base_point.id} -> {tangent_vec.point.id}")
    pl.add_meshes(tangent_vec.flow())
    # close() comes before show() (presumably finalizes the GIF being written).
    pl.close()
    pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_transport_abc_15_0.png
../../../_images/_generated_notebooks_how_to_deformetrica_transport_abc_15_1.png

Parallel transport (pole ladder)

[12]:
# Transport vec_bc (based at B) along the direction vec_ba, using the pole
# ladder scheme (metric was built with use_pole_ladder=True).
# TODO(review): the original note asked whether this is much slower than the
# fanning scheme — time both (e.g. %%time) and report it in prose.
trans_vec_bc_pole = metric.parallel_transport(vec_bc, point_b, direction=vec_ba)

This is how the deformation B -> C looks when applied to A: we shoot the transported vector from A.

[13]:
# Shoot the pole-ladder-transported vector from A, i.e. apply the B -> C deformation to A.
trans_point_c_pole = metric.exp(trans_vec_bc_pole, point_a)
[22]:
# Animate A deformed by the pole-ladder-transported B -> C deformation.
pl = RegisteredMeshesGifPlotter()
pl.add_title("A + BC")
pl.add_meshes(trans_point_c_pole.flow())
# Keep close() before show(), as in the registration animations.
pl.close()
pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_transport_abc_20_0.png

Parallel transport (fanning)

[15]:
# Switch the shared metric to the fanning scheme. This mutates `metric` in
# place: re-running the pole-ladder cells after this one would presumably use
# fanning as well — set it back to True before doing so.
metric.use_pole_ladder = False
[16]:
# Same transport as in the pole-ladder section, now with use_pole_ladder = False (fanning).
trans_vec_bc_fan = metric.parallel_transport(vec_bc, point_b, direction=vec_ba)

This is how the deformation B -> C looks when applied to A: we shoot the transported vector from A.

[17]:
# Shoot the fanning-transported vector from A, i.e. apply the B -> C deformation to A.
trans_point_c_fan = metric.exp(trans_vec_bc_fan, point_a)
[23]:
# Animate A deformed by the fanning-transported B -> C deformation.
pl = RegisteredMeshesGifPlotter()
pl.add_title("A + BC")
pl.add_meshes(trans_point_c_fan.flow())
# Keep close() before show(), as in the other animations.
pl.close()
pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_transport_abc_26_0.png

Some comparisons

The fanning scheme also computes extra information:

  • an approximation of the end point of `direction` (here, an approximation of A, the end point of BA)

  • a mesh shot from that approximation using the parallel transport

[19]:
pl = pv.Plotter()

# Ground-truth A against two reconstructions of it: the end point approximated
# by the fanning scheme and (presumably) the end point of the B -> A registration.
reconstruction_layers = (
    (point_a.as_pv(), "red", "truth"),
    (trans_vec_bc_fan.reconstructed().as_pv(), "green", "fan-rec"),
    (vec_ba.reconstructed().as_pv(), "yellow", "registration-rec"),
)
for layer_mesh, layer_color, layer_label in reconstruction_layers:
    pl.add_mesh(layer_mesh, color=layer_color, opacity=0.5, label=layer_label)

pl.add_legend()
pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_transport_abc_29_0.png
[20]:
pl = pv.Plotter()

# Compare the three shot versions of "A + BC": shot from the fanning
# approximation of A, and shot from A with fanning and pole-ladder transport.
shot_layers = (
    (trans_vec_bc_fan.reconstructed_shooted().as_pv(), "red", "fan-rec-shoot"),
    (trans_point_c_fan.as_pv(), "green", "fan-shoot"),
    (trans_point_c_pole.as_pv(), "yellow", "pole-shoot"),
)
for layer_mesh, layer_color, layer_label in shot_layers:
    pl.add_mesh(layer_mesh, color=layer_color, opacity=0.5, label=layer_label)

pl.add_legend()
pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_transport_abc_30_0.png