# LDDMM: how to estimate a deterministic atlas?

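This is analogous to a registration problem involving multiple meshes: a deterministic atlas estimates a single template mesh, together with control points and momenta that deform the template onto each of the input meshes.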

For registration between two meshes, see "LDDMM: how to register a mesh to a template?".

[1]:
import shutil
from pathlib import Path

import numpy as np
import pyvista as pv

import polpo.lddmm as plddmm
import polpo.preprocessing.dict as ppdict
import polpo.utils as putils
from polpo.plot.pyvista import RegisteredMeshesGifPlotter
from polpo.preprocessing.load.pregnancy.jacobs import MeshLoader
from polpo.preprocessing.mesh.io import DictMeshWriter
from polpo.preprocessing.mesh.registration import RigidAlignment
No CUDA runtime is found, using CUDA_HOME='/usr'
[2]:
# Run-time switches for this notebook.
DEBUG = False  # if True, only sessions 1, 2 and 3 are loaded (quick run)
RECOMPUTE = False  # if True, previously written outputs are deleted and recomputed
STATIC_VIZ = True  # if True, pyvista figures are rendered as static images

if STATIC_VIZ:
    pv.set_jupyter_backend("static")
[3]:
# Every artefact of this example lives under OUTPUTS_DIR; the estimated atlas
# gets its own sub-directory inside it.
OUTPUTS_DIR = Path("results") / "atlas_example"
ATLAS_DIR = OUTPUTS_DIR / "atlas"

# A full recomputation starts from an empty output directory. ATLAS_DIR is
# nested inside OUTPUTS_DIR, so the cached atlas is discarded as well.
if RECOMPUTE and OUTPUTS_DIR.exists():
    shutil.rmtree(OUTPUTS_DIR)

OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)

## Loading meshes

[4]:
# Load the meshes of a single subject and a single structure ("L_Hipp",
# presumably the left hippocampus).
subject_id = "01"
# In debug mode only sessions 1-3 are loaded; `None` presumably means
# "no restriction", i.e. all available sessions — confirm in MeshLoader.
session_subset = [1, 2, 3] if DEBUG else None


path2mesh = MeshLoader(
    subject_subset=[subject_id],
    session_subset=session_subset,
    struct_subset=["L_Hipp"],
    as_mesh=True,
)

# NOTE(review): ExtractUniqueKey(nested=True) presumably collapses the
# single-subject / single-structure nesting so that `meshes` maps each
# session to its mesh — confirm against polpo.preprocessing.dict.
pipe = path2mesh + ppdict.ExtractUniqueKey(nested=True)

meshes = pipe()
[5]:
# Overlay all loaded meshes, before rigid alignment, to show how they differ.
pl = pv.Plotter(border=False)

for mesh in meshes.values():
    # Semi-transparent with visible edges so overlapping meshes stay readable.
    pl.add_mesh(mesh, show_edges=True, opacity=0.5)

# A title lets this figure be told apart from the post-alignment one below.
pl.add_title("Input meshes (before rigid alignment)", font_size=10)
pl.show()
../../../_images/_generated_notebooks_how_to_lddmm_deterministic_atlas_7_0.png
[6]:
# Rigidly align the meshes before estimating the atlas.
# NOTE(review): presumably every mesh is aligned to a common reference mesh
# (e.g. the first one) — confirm in polpo's RigidAlignment.
prep_pipe = RigidAlignment(max_iterations=500)

# NOTE(review): `meshes` is rebound to the aligned meshes, so re-running this
# cell feeds already-aligned meshes back into the alignment and the raw meshes
# are no longer available. A new name (e.g. `aligned_meshes`) would avoid this.
meshes = prep_pipe(meshes)
[7]:
# Overlay the meshes again, now after rigid alignment.
pl = pv.Plotter(border=False)

for mesh in meshes.values():
    # Semi-transparent with visible edges so overlapping meshes stay readable.
    pl.add_mesh(mesh, show_edges=True, opacity=0.5)

# A title lets this figure be told apart from the pre-alignment one above.
pl.add_title("Meshes after rigid alignment", font_size=10)
pl.show()
../../../_images/_generated_notebooks_how_to_lddmm_deterministic_atlas_9_0.png

## Deterministic atlas

Save the meshes in VTK format, as required by Deformetrica.

[8]:
# Write the (aligned) meshes to OUTPUTS_DIR with the "vtk" extension. The
# writer's return value is what the atlas estimation receives as `targets`.
dataset = DictMeshWriter(dirname=OUTPUTS_DIR, ext="vtk")(meshes)

Use LDDMM to estimate the atlas.

[9]:
# LDDMM settings forwarded to the atlas estimation.
# NOTE(review): parameter meanings are not visible here; they presumably
# mirror deformetrica's options — confirm against polpo.lddmm.
kernel_width = 10.0  # NB: influences the number of cp
registration_kwargs = {
    "kernel_width": kernel_width,
    "regularisation": 1,
    "max_iter": 2000,
    "freeze_control_points": False,  # presumably lets control points be optimised too
    "metric": "varifold",
    "attachment_kernel_width": 2.0,
    "tol": 1e-10,
}

# Reuse a previously estimated atlas when its directory is present.
# NOTE(review): the cache check only tests whether ATLAS_DIR exists, so a
# change of DEBUG or of registration_kwargs (or a directory left behind by an
# interrupted run) does not trigger a re-estimation; set RECOMPUTE = True then.
if not ATLAS_DIR.exists():
    plddmm.learning.estimate_deterministic_atlas(
        targets=dataset,
        output_dir=ATLAS_DIR,
        initial_step_size=1e-1,
        **registration_kwargs,
    )

## Visualization

Visualize the estimated template together with its control points.

[10]:
# Load the outputs of the atlas estimation from ATLAS_DIR.
cp = plddmm.io.load_cp(ATLAS_DIR)  # control points (plotted as a point cloud below)
momenta = plddmm.io.load_momenta(ATLAS_DIR)  # NOTE(review): not used by the cells shown here
template = plddmm.io.load_template(ATLAS_DIR, as_pv=True)  # estimated template as a pyvista mesh
[11]:
# Show the estimated template with its control points drawn on top.
pl = pv.Plotter()

pl.add_mesh(template, show_edges=True)

# Without an explicit style, the control points would use the same default
# colour as the template and be hard to distinguish from its surface.
pl.add_mesh(
    pv.PolyData(cp),
    color="red",
    point_size=8,
    render_points_as_spheres=True,
)

pl.add_title("Template and control points", font_size=10)
pl.show()
../../../_images/_generated_notebooks_how_to_lddmm_deterministic_atlas_18_0.png

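Visualize the reconstruction errors, i.e. the Euclidean distance between corresponding points of each mesh and its reconstruction from the atlas.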

[12]:
# Load each input mesh as reconstructed from the atlas (presumably the template
# deformed by the estimated flow). The next cell indexes this with the keys of
# `meshes`, so it is assumed to share them — TODO confirm in polpo.lddmm.io.
rec_meshes = plddmm.io.load_deterministic_atlas_reconstructions(ATLAS_DIR, as_pv=True)
[13]:
# One subplot per mesh: each reconstruction coloured by its point-wise error.
plot_shape = putils.plot_shape_from_n_plots(len(meshes), n_axis=4)

pl = pv.Plotter(shape=plot_shape, border=False)

for index, key in enumerate(meshes.keys()):
    mesh, rec_mesh = meshes[key], rec_meshes[key]
    # presumably maps a flat index to (row, col) given the number of columns
    pl.subplot(*putils.plot_index_to_shape(index, plot_shape[1]))

    # NB: assumes meshes are registered, i.e. point i of `mesh` corresponds to
    # point i of `rec_mesh`. That cannot be verified here, but a point-count
    # mismatch is caught with an explicit message instead of a broadcast error.
    if mesh.n_points != rec_mesh.n_points:
        raise ValueError(
            f"Mesh {key!r} has {mesh.n_points} points but its reconstruction "
            f"has {rec_mesh.n_points}; point-wise errors need corresponding points."
        )

    # Euclidean distance between corresponding points.
    euc_res = np.linalg.norm(mesh.points - rec_mesh.points, axis=-1)
    pl.add_mesh(
        rec_mesh,
        scalars=euc_res,
        scalar_bar_args={"title": "Euclidean error"},
    )
    pl.add_title(str(key), font_size=8)

pl.show()
../../../_images/_generated_notebooks_how_to_lddmm_deterministic_atlas_21_0.png

Visualize the flow from the template to each of the meshes.

[14]:
# Load, for each mesh, the sequence of meshes describing the flow from the
# template to it. NOTE(review): presumably a dict of lists keyed like `meshes`
# (the next cell passes it through DictListSwapper) — confirm in polpo.lddmm.io.
flows = plddmm.io.load_deterministic_atlas_flows(ATLAS_DIR, as_pv=True)
[15]:
# Animate all flows side by side, one subplot per mesh, using the same grid
# layout as the reconstruction-error figure (presumably rendered as a GIF, per
# the class name).
pl = RegisteredMeshesGifPlotter(
    shape=putils.plot_shape_from_n_plots(len(meshes), n_axis=4),
    fps=10,
    border=False,
    off_screen=True,
    notebook=False,
    subtitle=True,
)

# NOTE(review): DictListSwapper presumably turns {key: [step_0, step_1, ...]}
# into [{key: step_0, ...}, {key: step_1, ...}], i.e. one frame per time step
# — confirm in polpo.preprocessing.dict.
pl.add_meshes(ppdict.DictListSwapper()(flows))
# NOTE(review): `close` is called before `show`; presumably `close` finalises
# the animation file that `show` then displays — confirm in polpo.plot.pyvista.
pl.close()

pl.show()
[15]:
../../../_images/_generated_notebooks_how_to_lddmm_deterministic_atlas_24_0.png

## Further reading