LDDMM: how to do regression?#
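This notebook shows how to fit an LDDMM spline regression to a time series of meshes: we load the gestational week of each session as the predictor, extract and preprocess the meshes of a hippocampal substructure, estimate control points by registering a template session against a target session, and finally regress the meshes on (normalized) time.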
[1]:
import shutil
from pathlib import Path
import polpo.lddmm as plddmm
import polpo.preprocessing.dict as ppdict
import polpo.preprocessing.pd as ppd
from polpo.preprocessing import NestingSwapper
from polpo.preprocessing.dict import (
DictFilter,
DictMap,
DictMerger,
)
from polpo.preprocessing.load.pregnancy.jacobs import TabularDataLoader
from polpo.preprocessing.load.pregnancy.pilot import (
HippocampalSubfieldsSegmentationsLoader,
)
from polpo.preprocessing.mesh.conversion import PvFromData
from polpo.preprocessing.mesh.io import DictMeshWriter
from polpo.preprocessing.mesh.registration import RigidAlignment
from polpo.preprocessing.mesh.smoothing import PvSmoothTaubin
from polpo.preprocessing.mesh.transform import MeshCenterer
from polpo.preprocessing.mri import (
MeshExtractorFromSegmentedImage,
MeshExtractorFromSegmentedMesh,
)
W1014 14:47:58.272000 63549 site-packages/torch/utils/cpp_extension.py:118] No CUDA runtime is found, using CUDA_HOME='/usr'
[2]:
# When True, the existing outputs directory is deleted in the configuration
# cell below. The registration and regression cells only run when their
# output directories are missing, so this forces them to be recomputed.
RECOMPUTE = False
[3]:
# Gestational-week window used to select the sessions (inclusive bounds).
T_MIN = 1.0
T_MAX = 25.0

# Sessions used as source and target of the registration that provides the
# initial control points.
TEMPLATE_SESSION = 3
TARGET_SESSION = 14

# Name of the hippocampal substructure to extract from the segmentations.
STRUCT_NAME = "PostHipp"

# All artifacts of this notebook live under OUTPUTS_DIR.
OUTPUTS_DIR = Path("results") / f"regression_{STRUCT_NAME.lower()}"
REGISTRATION_DIR = OUTPUTS_DIR / "registration"
REGRESSION_DIR = OUTPUTS_DIR / "spline_regression"

# Wipe previous results only when a recomputation is explicitly requested.
if OUTPUTS_DIR.exists() and RECOMPUTE:
    shutil.rmtree(OUTPUTS_DIR)

# `parents=True` also creates the enclosing "results" directory; without it
# the call raises FileNotFoundError on a fresh checkout.
OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
Load predictor#
[4]:
# Tabular data of subject "01", indexed by session.
loader = TabularDataLoader(subject_subset=["01"], index_by_session=True)

# Named `predictor_pipe` rather than `prep_pipe`: the mesh preprocessing cell
# further down assigns `prep_pipe` to an unrelated pipeline, and reusing the
# name for two different things makes out-of-order re-runs error-prone.
# Presumably: pick the "gestWeek" column, turn it into a dict, and keep only
# the entries whose week lies within [T_MIN, T_MAX] -- verify against polpo.
predictor_pipe = (
    ppd.ColumnsSelector("gestWeek")
    + ppd.SeriesToDict()
    + DictFilter(lambda value: T_MIN <= value <= T_MAX)
)

# session, week
predictor = (loader + predictor_pipe)()
predictor.keys()
INFO: Data has already been downloaded... using cached file ('/home/luisfpereira/.herbrain/data/maternal/maternal_brain_project_pilot/rawdata/28Baby_Hormones.csv').
[4]:
dict_keys([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
Load meshes#
The data loading follows *LDDMM: how to register a mesh against a template?*.
[5]:
# Switch between the two ways of obtaining the structure mesh: directly from
# the segmented image, or from a mesh of the whole segmentation.
struct_from_image = True

# Segmentations are loaded only for the sessions kept in `predictor`.
path2img = HippocampalSubfieldsSegmentationsLoader(
    subset=list(predictor.keys()),
    as_image=True,
)

if not struct_from_image:
    # Mesh the segmentation with struct_id=-1 first, then extract the
    # requested structure from that mesh.
    img2mesh = (
        MeshExtractorFromSegmentedImage(struct_id=-1, encoding="ashs")
        + PvFromData()
        + MeshExtractorFromSegmentedMesh(struct_id=STRUCT_NAME, encoding="ashs")
    )
else:
    # Extract the requested structure straight from the segmented image.
    img2mesh = MeshExtractorFromSegmentedImage(
        struct_id=STRUCT_NAME, encoding="ashs"
    ) + PvFromData(keep_colors=False)

# Apply the image-to-mesh conversion to every loaded session.
pipe = path2img + DictMap(img2mesh)
[6]:
# Run the loading pipeline; the result is keyed by session (see the keys
# displayed below).
raw_meshes = pipe()
raw_meshes.keys()
INFO: Data has already been downloaded... using cached file ('/home/luisfpereira/.herbrain/data/pregnancy/derivatives/segmentations').
[6]:
dict_keys([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
Preprocessing meshes#
Following the preprocessing in *LDDMM: how to register a mesh against a template?*, we center, smooth, and rigidly align the meshes against the template.
[7]:
# TODO: consider decimation if above a given number of points
# Steps applied to each mesh independently: centering, then Taubin smoothing.
per_mesh_steps = MeshCenterer() + PvSmoothTaubin(n_iter=20)

# Rigid alignment is applied after the per-mesh steps, on the whole collection.
prep_pipe = DictMap(per_mesh_steps) + RigidAlignment(max_iterations=10)
[8]:
# Apply the preprocessing pipeline to the raw meshes; the result is stored
# under a new name.
meshes = prep_pipe(raw_meshes)
Save meshes in `vtk` format (as required by `deformetrica`).
[9]:
# Write the preprocessed meshes as .vtk files directly under OUTPUTS_DIR.
# `dataset` is later indexed by session (e.g. `dataset[TEMPLATE_SESSION]`) and
# holds the written file paths (see the paths displayed by the next cell).
meshes_writer = DictMeshWriter(dirname=OUTPUTS_DIR, ext="vtk")
dataset = meshes_writer(meshes)
We can now create the dataset:
[10]:
# Combine predictor (session -> week) and dataset (session -> mesh file) into
# two aligned sequences: one of times and one of mesh filenames (see output).
# NOTE(review): presumably DictMerger pairs the values by session key and
# NestingSwapper transposes the pairs -- confirm against polpo.
(times, mesh_filenames) = (DictMerger() + NestingSwapper())([predictor, dataset])
times, mesh_filenames
[10]:
((1.0, 1.5, 2.0, 3.0, 9.0, 12.0, 14.0, 15.0, 17.0, 19.0, 22.0, 24.0),
(PosixPath('results/regression_posthipp/mesh_3.vtk'),
PosixPath('results/regression_posthipp/mesh_4.vtk'),
PosixPath('results/regression_posthipp/mesh_5.vtk'),
PosixPath('results/regression_posthipp/mesh_6.vtk'),
PosixPath('results/regression_posthipp/mesh_7.vtk'),
PosixPath('results/regression_posthipp/mesh_8.vtk'),
PosixPath('results/regression_posthipp/mesh_9.vtk'),
PosixPath('results/regression_posthipp/mesh_10.vtk'),
PosixPath('results/regression_posthipp/mesh_11.vtk'),
PosixPath('results/regression_posthipp/mesh_12.vtk'),
PosixPath('results/regression_posthipp/mesh_13.vtk'),
PosixPath('results/regression_posthipp/mesh_14.vtk')))
We also normalize time to the interval [0, 1]:
[11]:
# TODO: do it in a sklearn style
# Min-max rescaling: the earliest visit maps to 0 and the latest to 1.
min_time = min(times)
maxmindiff_time = max(times) - min_time
times = [(visit_time - min_time) / maxmindiff_time for visit_time in times]
LDDMM#
Step 1: find control points#
This step follows *LDDMM: how to register a mesh against a template?*.
[12]:
# TODO: need to adapt registration parameters to substructure
# Parameters of the registration; they are reused by the regression cell.
registration_kwargs = {
    "kernel_width": 4.0,
    "regularisation": 1.0,
    "max_iter": 2000,
    "freeze_control_points": False,
    "attachment_kernel_width": 2.0,
    "metric": "varifold",
    "tol": 1e-16,
}

# The registration runs only when its output directory is missing; otherwise
# the results already on disk are reused.
if not REGISTRATION_DIR.exists():
    plddmm.registration.estimate_registration(
        dataset[TEMPLATE_SESSION],
        dataset[TARGET_SESSION],
        output_dir=REGISTRATION_DIR,
        **registration_kwargs,
    )

# Control points found by the registration, loaded from its output directory.
initial_control_points = plddmm.io.load_cp(REGISTRATION_DIR, as_path=True)
Step 2: perform regression#
[15]:
# Peek at the first mesh file; it is the one passed as `source` to the
# regression call below.
mesh_filenames[0]
[15]:
PosixPath('results/regression_posthipp/mesh_3.vtk')
[13]:
# Regression-specific options, layered on top of the registration parameters.
spline_kwargs = dict(
    initial_step_size=100,
    regularisation=1.0,
    freeze_external_forces=True,
    freeze_control_points=True,
)

# Start from a copy so that `registration_kwargs` itself is left unmodified;
# keys present in both dicts take the value from `spline_kwargs`.
kwargs = registration_kwargs.copy()
kwargs.update(spline_kwargs)

# Uniform weights: every target contributes equally.
target_weights = [1 / len(times)] * len(times)

# TODO: revisit
# The regression runs only when its output directory is missing.
# NOTE(review): the recorded run of this cell ends in `KeyError: 0`, raised in
# Deformetrica's `create_dataset` at `dataset_filenames[i][j]`. Presumably the
# dataset specification built from `targets` / `subject_id` does not have the
# nested per-subject, per-visit layout Deformetrica expects -- confirm against
# `polpo.lddmm.learning.estimate_spline_regression`.
# NOTE(review): if the failed run already created REGRESSION_DIR, this guard
# will skip the computation on the next run -- verify and remove the
# directory (or set RECOMPUTE = True) before retrying.
if not REGRESSION_DIR.exists():
    plddmm.learning.estimate_spline_regression(
        source=mesh_filenames[0],
        targets=mesh_filenames,
        output_dir=REGRESSION_DIR,
        times=times,
        subject_id=[""],
        t0=min(times),
        target_weights=target_weights,
        initial_control_points=initial_control_points,
        **kwargs,
    )
Logger has been set to: DEBUG
OMP_NUM_THREADS was not found in environment variables. An automatic value will be set.
OMP_NUM_THREADS will be set to 10
>> Initial t0 set by the user to 0.00 ; note that the mean visit age is 0.46
context has already been set
>> No specified state-file. By default, Deformetrica state will by saved in file: results/regression_posthipp/spline_regression/deformetrica-state.p.
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[13], line 14
11 target_weights = [1 / len(times)] * len(times)
13 if not REGRESSION_DIR.exists():
---> 14 plddmm.learning.estimate_spline_regression(
15 source=mesh_filenames[0],
16 targets=mesh_filenames,
17 output_dir=REGRESSION_DIR,
18 times=times,
19 subject_id=[""],
20 t0=min(times),
21 target_weights=target_weights,
22 initial_control_points=initial_control_points,
23 **kwargs,
24 )
File ~/Repos/github/polpo/polpo/lddmm/learning.py:403, in estimate_spline_regression(source, targets, output_dir, times, subject_id, t0, max_iter, kernel_width, regularisation, number_of_time_steps, initial_step_size, kernel_type, kernel_device, initial_control_points, tol, freeze_control_points, use_rk2_for_flow, use_rk2_for_shoot, dimension, freeze_external_forces, target_weights, geodesic_weight, metric, attachment_kernel_width, print_every)
400 patient_output_dir = output_dir
402 deformetrica = Deformetrica(patient_output_dir, verbosity="DEBUG")
--> 403 deformetrica.estimate_spline_regression(
404 template_specifications=template,
405 dataset_specifications=data_set,
406 model_options=model,
407 estimator_options=optimization_parameters,
408 )
File ~/miniconda3/envs/deformetrica/lib/python3.12/site-packages/api/deformetrica.py:405, in Deformetrica.estimate_spline_regression(self, template_specifications, dataset_specifications, model_options, estimator_options, write_output)
401 template_specifications, model_options, estimator_options = self.further_initialization(
402 'Regression', template_specifications, model_options, dataset_specifications, estimator_options)
404 # Instantiate dataset.
--> 405 dataset = create_dataset(template_specifications,
406 dimension=model_options['dimension'], **dataset_specifications)
407 assert (dataset.is_time_series()), "Cannot estimate a spline regression from a non-time-series dataset."
409 # Instantiate model.
File ~/miniconda3/envs/deformetrica/lib/python3.12/site-packages/in_out/dataset_functions.py:33, in create_dataset(template_specifications, visit_ages, dataset_filenames, subject_ids, dimension)
31 reader = DeformableObjectReader()
32 for object_id in template_specifications.keys():
---> 33 if object_id not in dataset_filenames[i][j]:
34 raise RuntimeError('The template object with id ' + object_id + ' is not found for the visit '
35 + str(j) + ' of subject ' + str(i) + '. Check the dataset xml.')
36 else:
KeyError: 0