LDDMM: how to estimate a deterministic atlas?

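This is analogous to a registration problem for multiple meshes: instead of deforming a given template onto a single target, we estimate a template (the atlas) together with the deformations that map it onto each mesh of the dataset.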

For registration between two meshes, see “LDDMM: how to register a mesh to a template?”.

[1]:
import shutil
import string

import numpy as np
import pyvista as pv

import polpo.utils as putils
from polpo.mesh.deformetrica import FrechetMean, LddmmMetric, Point
from polpo.mesh.generation.blob import create_blob
from polpo.preprocessing.mesh.registration import RigidAlignment
[KeOps] Warning : CUDA was detected, but driver API could not be initialized. Switching to CPU only.
[2]:
# NB: keep this seed fixed whenever cached results are reused (RECOMPUTE = False):
# the cached atlas was computed from meshes generated with this seed
# (presumably `create_blob` draws from numpy's global RNG).
np.random.seed(42)

# Set to True to wipe OUTPUTS_DIR and recompute everything from scratch.
RECOMPUTE = False

# Render PyVista figures as static images instead of interactive widgets.
STATIC_VIZ = True
if STATIC_VIZ:
    pv.set_jupyter_backend("static")
[3]:
OUTPUTS_DIR = putils.get_results_path() / "deterministic_atlas_blob_example"

# Only start from an empty directory when a recomputation is requested;
# otherwise existing contents are kept (presumably so the fit below can reuse them).
if RECOMPUTE and OUTPUTS_DIR.exists():
    shutil.rmtree(OUTPUTS_DIR)

OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)

Generate meshes

[4]:
# Number of blobs in the synthetic dataset.
n_meshes = 3
# Bump amplitude of the blobs; also used below to size the LDDMM kernels.
bump_amp = 0.2

raw_meshes = []
for _ in range(n_meshes):
    raw_meshes.append(
        create_blob(resolution=10, bump_amp=bump_amp, n_bumps=5, smoothing_iter=10)
    )

# (number of vertices, spatial dimension) of the first blob.
raw_meshes[0].points.shape
[4]:
(82, 3)
[5]:
# Overlay the raw, not-yet-aligned blobs as semi-transparent surfaces with edges shown.
pl = pv.Plotter(border=False)
for mesh in raw_meshes:
    pl.add_mesh(mesh, opacity=0.5, show_edges=True)
pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_deterministic_atlas_7_0.png
[6]:
# Rigidly align all blobs with each other.
# `known_correspondences=True` assumes vertex i of one mesh matches vertex i of
# the others (all blobs share `resolution=10`) — confirm against RigidAlignment.
prep_pipe = RigidAlignment(known_correspondences=True)
meshes = prep_pipe(raw_meshes)
[7]:
# Same overlay as above, now after rigid alignment.
pl = pv.Plotter(border=False)
for mesh in meshes:
    pl.add_mesh(mesh, opacity=0.5, show_edges=True)
pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_deterministic_atlas_9_0.png

Deterministic atlas

[8]:
# Deformation kernel width is tied to the size of the bumps on the blobs.
kernel_width = 2 * bump_amp

# Key order matches the original call; all options are forwarded to LddmmMetric.
registration_kwargs = {
    "kernel_width": kernel_width,
    "regularisation": 1.0,
    "max_iter": 2000,
    "freeze_control_points": False,
    "metric": "varifold",
    # Very small tolerance: presumably the optimisation is meant to be bounded
    # by `max_iter` rather than by convergence — confirm.
    "tol": 1e-16,
    "attachment_kernel_width": bump_amp,
}

metric = LddmmMetric(OUTPUTS_DIR, **registration_kwargs)
[9]:
def _point_id(index):
    """Return a spreadsheet-style ID for a 0-based mesh index.

    0 -> "A", ..., 25 -> "Z", 26 -> "AA", 27 -> "AB", ...
    This matches plain ``string.ascii_uppercase[index]`` for the first 26 meshes
    but, unlike direct indexing, does not raise ``IndexError`` when ``n_meshes``
    is increased beyond 26.
    """
    alphabet = string.ascii_uppercase
    letters = ""
    remaining = index + 1
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, len(alphabet))
        letters = alphabet[offset] + letters
    return letters


# One deformetrica `Point` per aligned mesh, each associated with `metric.meshes_dir`.
dataset = [
    Point(id_=_point_id(index), pv_surface=mesh, dirname=metric.meshes_dir)
    for index, mesh in enumerate(meshes)
]

Use LDDMM to estimate the atlas, i.e. the Fréchet mean of the aligned meshes under the LDDMM metric.

[10]:
# Fréchet-mean estimator under the LDDMM metric configured above.
estimator = FrechetMean(
    metric,
    initial_step_size=1e-1,
)

# The estimated template is identified as "atlas". The trailing `;` suppresses the
# estimator repr, which only shows a memory address that changes on every run.
estimator.fit(dataset, atlas_id="atlas");
[10]:
<polpo.mesh.deformetrica.FrechetMean at 0x768e092eb890>
[11]:
# Fitted atlas; below it is used for its template mesh (`as_pv`), the data
# `points`, its `control_points` and the per-mesh `tangent_vecs`.
atlas = estimator.estimate_
[12]:
pl = pv.Plotter(border=False)

# Data meshes in faint red. Only the first one gets a label so that the legend
# shows a single "data" entry instead of one per mesh.
for index, point in enumerate(atlas.points):
    pl.add_mesh(
        point.as_pv(),
        opacity=0.2,
        color="red",
        label="data" if index == 0 else None,
    )

# Estimated template.
pl.add_mesh(
    atlas.as_pv(),
    show_edges=True,
    opacity=0.5,
    color="green",
    label="atlas",
)

# Control points of the atlas, labelled so the legend explains the green dots.
pl.add_points(
    pv.PolyData(atlas.control_points(as_path=False)),
    color="green",
    label="control points",
)
pl.add_legend()

pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_deterministic_atlas_16_0.png

Reconstruction

[13]:
# For every data mesh, compare its reconstruction from the atlas (`rec`)
# with the original mesh (`target`, in red).
for tangent_vec in atlas.tangent_vecs():
    pl = pv.Plotter()

    reconstruction = tangent_vec.reconstructed().as_pv()
    pl.add_mesh(reconstruction, label="rec", opacity=0.5, show_edges=True)

    target = tangent_vec.point
    pl.add_mesh(target.as_pv(), label="target", color="red", opacity=0.5)

    pl.add_title(target.id)
    pl.add_legend()
    pl.show()
../../../_images/_generated_notebooks_how_to_deformetrica_deterministic_atlas_18_0.png
../../../_images/_generated_notebooks_how_to_deformetrica_deterministic_atlas_18_1.png
../../../_images/_generated_notebooks_how_to_deformetrica_deterministic_atlas_18_2.png

Further reading