LDDMM: how to select the deformation kernel
Previously, we showed how to select the varifold kernel. Here, we do something similar for the deformation kernel. The approach is not fully automatic yet, but it gives a good sense of the trade-offs involved in selecting this hyperparameter.
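To make the trade-off concrete: in LDDMM the velocity field is built from momenta α_i attached to control points c_i, v(x) = Σ_i K(x, c_i) α_i, so the deformation kernel width σ sets how far each momentum reaches. The sketch below assumes a Gaussian kernel K(x, y) = exp(-|x - y|² / σ²), which we believe matches Deformetrica's convention (some implementations divide by 2σ² instead). The functions `gaussian_kernel` and `velocity` are illustrative helpers, not part of polpo.

```python
import numpy as np


def gaussian_kernel(x, y, sigma):
    """Kernel matrix K[i, j] = exp(-|x_i - y_j|^2 / sigma^2).

    Assumption: width convention without the factor 2 in the denominator.
    """
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / sigma**2)


def velocity(points, control_points, momenta, sigma):
    """Velocity v(x) = sum_i K(x, c_i) alpha_i evaluated at each query point."""
    return gaussian_kernel(points, control_points, sigma) @ momenta


# A single control point at the origin, pushing in the +x direction.
cp = np.zeros((1, 3))
alpha = np.array([[1.0, 0.0, 0.0]])
# Query points at distance 0, 3, 5 and 10 along the y axis.
query = np.array([[0.0, d, 0.0] for d in (0.0, 3.0, 5.0, 10.0)])

for sigma in (3.0, 10.0):
    # x-component of the velocity at each query point
    print(sigma, np.round(velocity(query, cp, alpha, sigma)[:, 0], 3))
```

With σ = 3 the push has decayed to about 6% of its peak 5 units away, whereas with σ = 10 it is still about 78%: larger widths give smoother, more global deformations with fewer effective degrees of freedom.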
[1]:
import random
import shutil
from pathlib import Path
import pyvista as pv
import polpo.lddmm as plddmm
import polpo.preprocessing.dict as ppdict
from polpo.mesh.surface import PvSurface
from polpo.mesh.varifold.tuning import SigmaFromLengths
from polpo.preprocessing import Map
from polpo.preprocessing.load.pregnancy.jacobs import MeshLoader, get_subject_ids
from polpo.preprocessing.mesh.decimation import PvDecimate
from polpo.preprocessing.mesh.io import PvWriter
from polpo.preprocessing.mesh.registration import RigidAlignment
[KeOps] Warning : CUDA was detected, but driver API could not be initialized. Switching to CPU only.
[2]:
DEBUG = False
STATIC_VIZ = True
VIZ = False
if STATIC_VIZ:
    pv.set_jupyter_backend("static")
[3]:
STRUCT_NAME = "L_Hipp"
OUTPUTS_DIR = Path("results") / f"lddmm_kernel_tuning_{STRUCT_NAME}"
REGISTRATION_DIR = OUTPUTS_DIR / "registration"
if OUTPUTS_DIR.exists() and not DEBUG:
    shutil.rmtree(OUTPUTS_DIR)
[4]:
subject_ids = random.sample(get_subject_ids(include_male=False, sort=True), 2)
subject_ids
[4]:
['1004B', '01']
[5]:
mesh_loader = (
    MeshLoader(
        subject_subset=subject_ids,
        struct_subset=[STRUCT_NAME],
        session_subset=None,
        derivative="enigma",
        as_mesh=True,
    )
    + ppdict.DictMap(ppdict.ExtractRandomKey())
    + ppdict.ExtractUniqueKey(nested=True)
    + ppdict.DictToValuesList()
    + RigidAlignment(
        target=lambda x: x[0],
        known_correspondences=True,
    )
    + Map(PvDecimate(target_reduction=0.6, volume_preservation=True) + PvSurface)
)
meshes = mesh_loader()
[mesh.n_points for mesh in meshes]
[5]:
[1002, 1002]
[6]:
if VIZ:
    pl = pv.Plotter(border=False)
    for mesh in meshes:
        pl.add_mesh(mesh.as_pv(), show_edges=True, opacity=0.6)
    pl.show()
As in the varifold kernel selection, we choose the varifold kernel width from characteristic lengths of the meshes.
[7]:
sigma_search = SigmaFromLengths(
    ratio_charlen_mesh=2.0,
    ratio_charlen=0.25,
)
sigma_search.fit(meshes)
metric = sigma_search.optimal_metric_
sigma_search.sigma_
[7]:
np.float64(6.015737056732178)
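Next, following the notebook "LDDMM: how to register a mesh to a template?", we write the meshes to disk and register the source to the target once per deformation kernel width.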
[8]:
meshes_writer = Map(PvWriter(dirname=OUTPUTS_DIR, ext="vtk"))
mesh_filenames = [f"mesh_{which}" for which in ["source", "target"]]
mesh_filenames = meshes_writer(list(zip(mesh_filenames, meshes)))
[9]:
def _registration_dir(kernel_width):
    # e.g. 3.0 -> REGISTRATION_DIR / "3-0" (no nested quotes in an f-string, valid before Python 3.12)
    return REGISTRATION_DIR / str(kernel_width).replace(".", "-")
kernel_widths = [3.0, 4.0, 5.0, 10.0]
registration_kwargs = dict(
    regularisation=1.0,
    max_iter=2000,
    freeze_control_points=False,
    metric="varifold",
    tol=1e-16,
    attachment_kernel_width=sigma_search.sigma_,
)
for kernel_width in kernel_widths:
    registration_kwargs["kernel_width"] = kernel_width
    registration_dir = _registration_dir(kernel_width)
    # Reuse existing results (e.g. when DEBUG=True keeps OUTPUTS_DIR between runs)
    if not registration_dir.exists():
        plddmm.registration.estimate_registration(
            mesh_filenames[0],
            mesh_filenames[1],
            output_dir=registration_dir,
            **registration_kwargs,
        )
Logger has been set to: WARNING
context has already been set
Logger has been set to: WARNING
context has already been set
Logger has been set to: WARNING
context has already been set
Logger has been set to: WARNING
context has already been set
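Why do smaller widths cost more? With `freeze_control_points=False` the control points are optimised, but they still have to be initialised somewhere. A common choice (Deformetrica's default, as far as we know) is a regular grid over the shapes' bounding box with spacing equal to the deformation kernel width, so the number of control points grows roughly like (extent / σ)³. The hypothetical `regular_grid` helper below sketches that initialisation on an arbitrary 30-unit cube; it is not polpo's implementation, and the cube is not the hippocampus bounding box.

```python
import numpy as np


def regular_grid(lower, upper, spacing):
    """Hypothetical helper: regular 3D grid covering the box [lower, upper].

    Points are `spacing` apart along each axis and centred in the box.
    """
    axes = []
    for lo, hi in zip(lower, upper):
        length = hi - lo
        n = int(np.floor(length / spacing)) + 1  # points that fit along this axis
        offset = (length - (n - 1) * spacing) / 2.0  # centre the grid in the box
        axes.append(lo + offset + spacing * np.arange(n))
    mesh = np.meshgrid(*axes, indexing="ij")
    return np.stack([m.ravel() for m in mesh], axis=1)


# Illustrative 30 x 30 x 30 box.
lower, upper = np.zeros(3), np.full(3, 30.0)
for width in (3.0, 4.0, 5.0, 10.0):
    print(width, regular_grid(lower, upper, width).shape[0])
```

On this cube the counts are 1331, 512, 343 and 64: dividing the spacing by about 3 multiplies the number of control points, and hence the size of the optimisation problem, by about 3³.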
[10]:
source, target = meshes
decimated_target = (
    PvDecimate(target_reduction=0.1, volume_preservation=True) + PvSurface
)(target)
reconstructed = {}
cps = {}
for kernel_width in kernel_widths:
    registration_dir = _registration_dir(kernel_width)
    reconstructed[kernel_width] = PvSurface(
        plddmm.io.load_deterministic_atlas_reconstruction(registration_dir, as_pv=True)
    )
    cps[kernel_width] = plddmm.io.load_cp(registration_dir)
[cps_.shape[0] for cps_ in cps.values()]
[10]:
[1008, 420, 240, 36]
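A quick way to check how these counts scale with the kernel width is a straight-line fit in log-log space, using only the numbers printed above. If the control points filled a volume on a regular grid, the slope would be close to -3.

```python
import numpy as np

# Kernel widths and control-point counts printed above.
widths = np.array([3.0, 4.0, 5.0, 10.0])
n_cps = np.array([1008, 420, 240, 36])

# Fit log(n_cps) = slope * log(width) + intercept.
slope, intercept = np.polyfit(np.log(widths), np.log(n_cps), deg=1)
print(f"n_cps ~ width^{slope:.2f}")
```

The fitted exponent is about -2.75, close to the volumetric -3; the gap is plausibly due to boundary effects on a small, elongated structure.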
[11]:
(
    metric.dist(target, source),
    metric.dist(target, decimated_target),
)
[11]:
(np.float64(129.1424244614494), np.float64(0.787139562713154))
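These two values frame the scale of the problem: about 129 between the unregistered meshes, and about 0.79 between the target and a lightly decimated copy of itself, a reference for what "nearly identical" means under this metric. For intuition on what a varifold distance between triangle meshes computes, here is a minimal sketch assuming a Gaussian spatial kernel of width `sigma` and the unoriented (squared-cosine) orientation kernel. polpo's varifold metric may use different kernels or normalisations, so its values are not comparable with the ones above; the helper names are hypothetical.

```python
import numpy as np


def centers_and_normals(vertices, faces):
    """Triangle centres and area-weighted normals (norm = triangle area)."""
    tri = vertices[faces]  # shape (n_faces, 3, 3)
    centers = tri.mean(axis=1)
    # Half the cross product of two edges has norm equal to the triangle area.
    normals = 0.5 * np.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0])
    return centers, normals


def varifold_inner(c1, n1, c2, n2, sigma):
    """<V1, V2> with a Gaussian spatial kernel and an unoriented orientation kernel."""
    sq_dists = ((c1[:, None, :] - c2[None, :, :]) ** 2).sum(axis=-1)
    k_space = np.exp(-sq_dists / sigma**2)
    areas1 = np.linalg.norm(n1, axis=1)
    areas2 = np.linalg.norm(n2, axis=1)
    # area1 * area2 * cos^2(angle) == (n1 . n2)^2 / (area1 * area2)
    k_orient = (n1 @ n2.T) ** 2 / (areas1[:, None] * areas2[None, :])
    return float((k_space * k_orient).sum())


def varifold_dist(mesh1, mesh2, sigma):
    """Distance ||V1 - V2|| in the varifold RKHS."""
    c1, n1 = centers_and_normals(*mesh1)
    c2, n2 = centers_and_normals(*mesh2)
    sq = (
        varifold_inner(c1, n1, c1, n1, sigma)
        - 2.0 * varifold_inner(c1, n1, c2, n2, sigma)
        + varifold_inner(c2, n2, c2, n2, sigma)
    )
    return float(np.sqrt(max(sq, 0.0)))  # clip tiny negative round-off


# Toy example: a tetrahedron and two translated copies of it.
V = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)
F = np.array([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]])
tetra = (V, F)
shifted = (V + np.array([0.5, 0.0, 0.0]), F)
far = (V + np.array([5.0, 0.0, 0.0]), F)
print(
    varifold_dist(tetra, tetra, 1.0),
    varifold_dist(tetra, shifted, 1.0),
    varifold_dist(tetra, far, 1.0),
)
```

The distance is zero for identical meshes and saturates once the shapes are many kernel widths apart, which is why the attachment kernel width has to match the scale of the differences one wants to measure.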
[12]:
{
    kernel_width: (
        metric.dist(target, reconstructed_),
        metric.dist(source, reconstructed_),
    )
    for kernel_width, reconstructed_ in reconstructed.items()
}
[12]:
{3.0: (np.float64(5.422280459845494), np.float64(128.31281408319776)),
4.0: (np.float64(7.553122099672369), np.float64(128.64740127143878)),
5.0: (np.float64(9.826397645725537), np.float64(128.91025052187894)),
10.0: (np.float64(23.836260213255446), np.float64(128.35104179628564))}
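Smaller deformation kernels bring the reconstruction closer to the target: the varifold distance to the target drops from about 23.8 (width 10) to about 5.4 (width 3). This is expected, since smaller widths use more control points (36 versus 1008) and therefore have more degrees of freedom; the price is a less regular deformation and a larger optimisation problem. Even the best reconstruction remains well above the target-to-decimated-target distance (about 0.79), while the distance to the source stays between 128 and 129 for all widths, close to the initial source-to-target distance.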
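The final choice is still manual. One way to make the trade-off explicit, offered here as a hypothetical rule rather than the method used in this notebook, is to express each reconstruction error as the fraction of the initial source-to-target distance it removes, and to pick the largest width (fewest control points, smoothest deformation) that clears a chosen threshold. The threshold value and the helper `select_kernel_width` are assumptions; the distances are the ones printed above, rounded.

```python
# Varifold distances printed above (rounded): target vs. reconstruction, and target vs. source.
dist_to_target = {3.0: 5.4223, 4.0: 7.5531, 5.0: 9.8264, 10.0: 23.8363}
initial_dist = 129.1424


def select_kernel_width(dist_to_target, initial_dist, min_fraction_removed):
    """Hypothetical rule: largest width removing at least the given fraction of the initial distance."""
    admissible = [
        width
        for width, dist in dist_to_target.items()
        if 1.0 - dist / initial_dist >= min_fraction_removed
    ]
    return max(admissible) if admissible else None


for width, dist in sorted(dist_to_target.items()):
    print(width, round(1.0 - dist / initial_dist, 3))

print(select_kernel_width(dist_to_target, initial_dist, 0.90))
```

With these numbers, widths 3 to 5 all remove more than 92% of the initial distance while width 10 removes about 82%, so a 0.90 threshold selects 5.0 and a stricter 0.95 threshold selects 3.0.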
[13]:
if VIZ:
    for kernel_width, reconstructed_ in reconstructed.items():
        pl = pv.Plotter(border=False)
        for mesh in [target, reconstructed_]:
            pl.add_mesh(mesh.as_pv(), show_edges=True, opacity=0.6)
        pl.add_title(str(kernel_width))
        pl.show()