LDDMM: how to register a mesh to a template?#

[1]:
import shutil
from pathlib import Path

import numpy as np
import pyvista as pv

import polpo.lddmm as plddmm
import polpo.preprocessing.dict as ppdict
from polpo.plot.pyvista import RegisteredMeshesGifPlotter
from polpo.preprocessing import Map
from polpo.preprocessing.load.deformetrica import LoadMeshFlow
from polpo.preprocessing.load.pregnancy.pilot import (
    HippocampalSubfieldsSegmentationsLoader,
)
from polpo.preprocessing.mesh.conversion import PvFromData
from polpo.preprocessing.mesh.io import PvWriter
from polpo.preprocessing.mesh.registration import RigidAlignment
from polpo.preprocessing.mesh.smoothing import PvSmoothTaubin
from polpo.preprocessing.mesh.transform import MeshCenterer
from polpo.preprocessing.mri import (
    MeshExtractorFromSegmentedImage,
    MeshExtractorFromSegmentedMesh,
)
No CUDA runtime is found, using CUDA_HOME='/usr'
[2]:
# True: delete any previous outputs in OUTPUTS_DIR and run the registration again.
# False: keep existing outputs; the registration step is skipped if its
# output directory already exists.
RECOMPUTE = True

# Use pyvista's static (image) Jupyter backend for plots.
STATIC_VIZ = True

if STATIC_VIZ:
    pv.set_jupyter_backend("static")
[3]:
# Indices passed as `subset` to the segmentation loader: first source, then target.
SOURCE_INDEX, TARGET_INDEX = 3, 14

# Structure to extract (ASHS encoding) and register.
STRUCT_NAME = "PostHipp"

OUTPUTS_DIR = Path("results") / f"registration_{STRUCT_NAME.lower()}"
REGISTRATION_DIR = OUTPUTS_DIR / "registration"

# Wipe previous results only when a recomputation is requested.
if RECOMPUTE and OUTPUTS_DIR.exists():
    shutil.rmtree(OUTPUTS_DIR)

Load meshes#

Following "How to get a mesh from an MRI image?", we start by loading two selected meshes (see also "How to select a mesh subset?").

[4]:
# True: extract STRUCT_NAME directly from the segmented image.
# False: extract a mesh with struct_id=-1 first (presumably all labels -- confirm),
# then select STRUCT_NAME from that mesh.
struct_from_image = True

path2img = HippocampalSubfieldsSegmentationsLoader(
    subset=[SOURCE_INDEX, TARGET_INDEX],
    as_image=True,
)

if not struct_from_image:
    img2mesh = (
        MeshExtractorFromSegmentedImage(struct_id=-1, encoding="ashs")
        + PvFromData()
        + MeshExtractorFromSegmentedMesh(struct_id=STRUCT_NAME, encoding="ashs")
    )
else:
    img2mesh = MeshExtractorFromSegmentedImage(
        struct_id=STRUCT_NAME, encoding="ashs"
    ) + PvFromData(keep_colors=False)

# loader -> mesh extraction applied to every dict value -> plain list of meshes.
# Later cells treat the first element as the source and the second as the target.
pipe = path2img + ppdict.DictMap(img2mesh) + ppdict.DictToValuesList()
[5]:
# Run the loading pipeline (uses cached data when already downloaded).
meshes = pipe()
INFO: Data has already been downloaded... using cached file ('/home/luisfpereira/.herbrain/data/pregnancy/derivatives/segmentations').
[6]:
# Overlay the two raw meshes: source in red, target in green.
pl = pv.Plotter(border=False)

for mesh, color in ((meshes[0], "red"), (meshes[1], "green")):
    pl.add_mesh(mesh, show_edges=True, color=color)

pl.show()
../../../_images/_notebooks_how_to_lddmm_lddmm_register_mesh_template_8_0.png

Preprocessing#

As the visualization shows, the meshes are not rigidly aligned. Preprocessing takes care of such details before LDDMM is applied. Each mesh is centered and smoothed, and then the meshes are rigidly aligned.

[7]:
# TODO: consider decimation if above a given number of points

# Per mesh: center, then apply Taubin smoothing (20 iterations).
per_mesh_prep = MeshCenterer() + PvSmoothTaubin(n_iter=20)

# Map the per-mesh steps over the list, then rigidly align the meshes.
prep_pipe = Map(per_mesh_prep) + RigidAlignment(max_iterations=10)
[8]:
# NOTE(review): `meshes` is overwritten in place, so re-running this cell
# preprocesses already-preprocessed meshes; re-run from the loading cell instead.
meshes = prep_pipe(meshes)

# Vertex count of each preprocessed mesh.
[len(mesh.points) for mesh in meshes]
[8]:
[768, 681]
[9]:
# Overlay the preprocessed meshes: source in red, target in green.
pl = pv.Plotter(border=False)

for mesh, color in ((meshes[0], "red"), (meshes[1], "green")):
    pl.add_mesh(mesh, show_edges=True, color=color)

pl.show()
../../../_images/_notebooks_how_to_lddmm_lddmm_register_mesh_template_13_0.png

Registration#

Save the meshes in VTK format, as required by Deformetrica.

[10]:
meshes_writer = Map(PvWriter(dirname=OUTPUTS_DIR, ext="vtk"))

# Pair each mesh with a base name ("mesh_source", "mesh_target") for the writer.
named_meshes = [
    (f"mesh_{which}", mesh) for which, mesh in zip(["source", "target"], meshes)
]

# The writer's result is what the registration cell receives as input filenames.
mesh_filenames = meshes_writer(named_meshes)

Use LDDMM to register the meshes.

[11]:
# Width of the deformation (diffeomorphism) kernel; also reused for shooting.
kernel_width = 10.0

registration_kwargs = {
    "kernel_width": kernel_width,
    "regularisation": 1.0,
    "max_iter": 2000,
    "freeze_control_points": False,
    "metric": "varifold",
    "tol": 1e-16,
    "attachment_kernel_width": 2.0,
}

# Skip the estimation when its output directory already exists.
if not REGISTRATION_DIR.exists():
    source_file, target_file = mesh_filenames[0], mesh_filenames[1]
    plddmm.registration.estimate_registration(
        source_file,
        target_file,
        output_dir=REGISTRATION_DIR,
        **registration_kwargs,
    )
Logger has been set to: DEBUG
>> No initial CP spacing given: using diffeo kernel width of 10.0
OMP_NUM_THREADS was not found in environment variables. An automatic value will be set.
OMP_NUM_THREADS will be set to 10
context has already been set
>> No specified state-file. By default, Deformetrica state will by saved in file: results/registration_posthipp/registration/deformetrica-state.p.
>> Using a Sobolev gradient for the template data with the ScipyLBFGS estimator memory length being larger than 1. Beware: that can be tricky.
instantiating kernel torch with kernel_width 10.0 and gpu_mode GpuMode.KERNEL. addr: 0x7684d0876bd0
instantiating kernel torch with kernel_width 2.0 and gpu_mode GpuMode.KERNEL. addr: 0x7684d046e790
>> Set of 10 control points defined.
>> Momenta initialized to zero, for 1 subjects.
dtype=float32
>> Started estimator: ScipyOptimize

>> Scipy optimization method: L-BFGS-B

------------------------------------- Iteration: 1 -------------------------------------

------------------------------------- Iteration: 20 -------------------------------------
>> Log-likelihood = -1.112E+03     [ attachment = -9.075E+02 ; regularity = -2.041E+02 ]

------------------------------------- Iteration: 40 -------------------------------------
>> Log-likelihood = -9.826E+02     [ attachment = -7.854E+02 ; regularity = -1.972E+02 ]

------------------------------------- Iteration: 60 -------------------------------------
>> Log-likelihood = -9.533E+02     [ attachment = -7.490E+02 ; regularity = -2.043E+02 ]
>> Log-likelihood = -9.466E+02     [ attachment = -7.446E+02 ; regularity = -2.019E+02 ]

------------------------------------- Iteration: 80 -------------------------------------
>> Log-likelihood = -9.269E+02     [ attachment = -7.152E+02 ; regularity = -2.117E+02 ]
>> Log-likelihood = -9.256E+02     [ attachment = -7.142E+02 ; regularity = -2.114E+02 ]

------------------------------------- Iteration: 100 -------------------------------------
>> Log-likelihood = -9.094E+02     [ attachment = -6.961E+02 ; regularity = -2.133E+02 ]

------------------------------------- Iteration: 120 -------------------------------------
>> Log-likelihood = -8.942E+02     [ attachment = -6.708E+02 ; regularity = -2.234E+02 ]

------------------------------------- Iteration: 140 -------------------------------------
>> Log-likelihood = -8.730E+02     [ attachment = -6.581E+02 ; regularity = -2.149E+02 ]

------------------------------------- Iteration: 160 -------------------------------------
>> Log-likelihood = -8.622E+02     [ attachment = -6.394E+02 ; regularity = -2.228E+02 ]

------------------------------------- Iteration: 180 -------------------------------------
>> Log-likelihood = -8.524E+02     [ attachment = -6.358E+02 ; regularity = -2.166E+02 ]

------------------------------------- Iteration: 200 -------------------------------------
>> Log-likelihood = -8.457E+02     [ attachment = -6.284E+02 ; regularity = -2.172E+02 ]

------------------------------------- Iteration: 220 -------------------------------------
>> Log-likelihood = -8.410E+02     [ attachment = -6.322E+02 ; regularity = -2.089E+02 ]

------------------------------------- Iteration: 240 -------------------------------------
>> Log-likelihood = -8.379E+02     [ attachment = -6.275E+02 ; regularity = -2.104E+02 ]

------------------------------------- Iteration: 260 -------------------------------------
>> Log-likelihood = -8.360E+02     [ attachment = -6.241E+02 ; regularity = -2.118E+02 ]

------------------------------------- Iteration: 280 -------------------------------------
>> Log-likelihood = -8.348E+02     [ attachment = -6.237E+02 ; regularity = -2.111E+02 ]

------------------------------------- Iteration: 300 -------------------------------------
>> Log-likelihood = -8.340E+02     [ attachment = -6.234E+02 ; regularity = -2.106E+02 ]

------------------------------------- Iteration: 320 -------------------------------------
>> Log-likelihood = -8.333E+02     [ attachment = -6.222E+02 ; regularity = -2.111E+02 ]

------------------------------------- Iteration: 340 -------------------------------------
>> Log-likelihood = -8.325E+02     [ attachment = -6.226E+02 ; regularity = -2.099E+02 ]

------------------------------------- Iteration: 360 -------------------------------------
>> Log-likelihood = -8.319E+02     [ attachment = -6.201E+02 ; regularity = -2.118E+02 ]
>> Gradient at Termination: 439.6210930702169
>> CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
>> Estimation took: 38 seconds
Deformetrica.__del__()

Visualization#

[12]:
# Name the preprocessed meshes; fails loudly unless there are exactly two.
[source, target] = meshes

Template and control points#

[13]:
# Estimated parameters read from the registration output directory.
# NOTE(review): `momenta` is not used by the cells shown here.
cp, momenta = (
    plddmm.io.load_cp(REGISTRATION_DIR),
    plddmm.io.load_momenta(REGISTRATION_DIR),
)

# Template mesh written by the registration, loaded as a pyvista object.
template = plddmm.io.load_template(REGISTRATION_DIR, as_pv=True)

Confirm that source and template are the same mesh.

[14]:
# Largest absolute vertex-coordinate difference; 0.0 means identical vertices.
np.abs(template.points - source.points).max()
[14]:
0.0

Visualize the template together with the control points.

[15]:
pl = pv.Plotter()
pl.add_mesh(template, show_edges=True)

# Control points drawn as a red point cloud.
# TODO: add velocity and allow for filtering
# control_points.set_active_scalars("Velocity")
control_points = pv.PolyData(cp)
pl.add_points(control_points, color="red")

pl.show()
../../../_images/_notebooks_how_to_lddmm_lddmm_register_mesh_template_26_0.png

Target and reconstructed meshes#

[16]:
# Reconstructed mesh from the registration output; compared against the target below.
reconstructed = plddmm.io.load_deterministic_atlas_reconstruction(
    REGISTRATION_DIR,
    as_pv=True,
)
[17]:
pl = pv.Plotter()

# Opaque reconstruction under a semi-transparent target, to inspect the fit.
pl.add_mesh(reconstructed, show_edges=True)
pl.add_mesh(target, opacity=0.55)

pl.show()
../../../_images/_notebooks_how_to_lddmm_lddmm_register_mesh_template_29_0.png

Flow#

[18]:
# Meshes of the registration flow, loaded as pyvista objects.
flow_meshes = plddmm.io.load_deterministic_atlas_flow(REGISTRATION_DIR, as_pv=True)
[19]:
# Animate the registration flow.
pl = RegisteredMeshesGifPlotter()
pl.add_meshes(flow_meshes)

# close() is called before show(); presumably this finalizes the GIF -- confirm.
pl.close()
pl.show()
[19]:
../../../_images/_notebooks_how_to_lddmm_lddmm_register_mesh_template_32_0.png

Reconstruct flow by shooting#

[20]:
# Output directory for the shooting results.
SHOOT_DIR = OUTPUTS_DIR.joinpath("shoot")
[21]:
# Rebuild the flow by shooting from the registration outputs, passed as file paths.
template_path = plddmm.io.load_template(REGISTRATION_DIR, as_path=True)
cp_path = plddmm.io.load_cp(REGISTRATION_DIR, as_path=True)
momenta_path = plddmm.io.load_momenta(REGISTRATION_DIR, as_path=True)

plddmm.geometry.shoot(
    source=template_path,
    control_points=cp_path,
    momenta=momenta_path,
    kernel_width=kernel_width,  # same deformation kernel as the registration
    output_dir=SHOOT_DIR,
    write_adjoint_parameters=False,
)
[ compute_shooting function ]
Defaulting geodesic t0 to 0.
Defaulting geodesic tmax to 1.
Defaulting geodesic tmin to 0.
/home/luisfpereira/miniconda3/lib/python3.11/site-packages/in_out/dataset_functions.py:265: UserWarning: Watch out, I did not get a distance type for the object shape, Please make sure you are running shooting or a parallel transport, otherwise distances are required.
  warnings.warn(msg)
[22]:
# Load the meshes produced by shooting, as a plain list of meshes.
flow_meshes_ = (LoadMeshFlow(as_path=False) + ppdict.DictToValuesList())(SHOOT_DIR)

# Largest absolute vertex-coordinate difference for each pair of flow meshes.
# abs must be applied *before* max: max of the signed differences would
# ignore large negative deviations.
# NOTE(review): zip stops at the shorter list, so a mismatch in the number of
# flow meshes would go unnoticed here.
diffs = [
    np.amax(np.abs(mesh.points - cmp_mesh.points))
    for mesh, cmp_mesh in zip(flow_meshes, flow_meshes_)
]

# TODO: identify origin of differences
np.amax(diffs)
[22]:
0.46335089206695557
[23]:
# Animate the flow reconstructed by shooting.
pl = RegisteredMeshesGifPlotter()
pl.add_meshes(flow_meshes_)

# close() is called before show(); presumably this finalizes the GIF -- confirm.
pl.close()
pl.show()
[23]:
../../../_images/_notebooks_how_to_lddmm_lddmm_register_mesh_template_37_0.png

Further reading#