How to do mesh-valued regression?

[1]:
from pathlib import Path

import numpy as np
import pyvista as pv
from IPython.display import Image
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import FunctionTransformer, StandardScaler

import polpo.preprocessing.pd as ppd
from polpo.models import ObjectRegressor
from polpo.preprocessing import IndexMap, NestingSwapper, PartiallyInitializedStep
from polpo.preprocessing.dict import DictMap, DictMerger
from polpo.preprocessing.load.pregnancy import (
    DenseMaternalCsvDataLoader,
    DenseMaternalMeshLoader,
)
from polpo.preprocessing.mesh.io import PvReader
from polpo.preprocessing.mesh.registration import PvAlign
from polpo.preprocessing.sklearn.adapter import AdapterPipeline
from polpo.preprocessing.sklearn.mesh import InvertibleMeshesToVertices
from polpo.preprocessing.sklearn.np import InvertibleFlattenButFirst
[KeOps] Warning : cuda was detected, but driver API could not be initialized. Switching to cpu only.
[2]:
STATIC_VIZ = True  # set to False to leave pyvista's current Jupyter backend untouched

if STATIC_VIZ:
    # Switch pyvista to its "static" (non-interactive) notebook backend.
    pv.set_jupyter_backend("static")

Loading meshes

[3]:
# Alignment step: wraps `PvAlign` in a `DictMap` so that every mesh of the
# incoming dict is aligned. `_target` selects the reference mesh (the first
# entry of the dict, in insertion order); presumably `PartiallyInitializedStep`
# evaluates it on the incoming data and forwards it to `Step` -- TODO confirm.
prep_pipe = PartiallyInitializedStep(
    Step=lambda **kwargs: DictMap(PvAlign(**kwargs)),
    # `next(iter(...))` fetches the first value without materialising the
    # full list of keys.
    _target=lambda meshes: next(iter(meshes.values())),
    max_iterations=500,
)
[4]:
subject_id = "01"

# Locate this subject's mesh files as a dict (left side, "Hipp" structure --
# presumably the left hippocampus).
file_finder = DenseMaternalMeshLoader(
    subject_id=subject_id, as_dict=True, left=True, struct="Hipp"
)

# Read every file into a pyvista mesh, then run the alignment step on the
# whole dict.
pipe = file_finder + DictMap(PvReader()) + prep_pipe
meshes = pipe()

Loading tabular data

[5]:
# Subject "01" is loaded with the `pilot` flag set; every other subject without it.
pilot = (subject_id == "01")
pipe = DenseMaternalCsvDataLoader(pilot=pilot, subject_id=subject_id)
df = pipe()
INFO: Data has already been downloaded... using cached file ('/home/luisfpereira/.herbrain/data/maternal/28Baby_Hormones.csv').

Here, we filter the tabular data.

[6]:
# Filter on the "stage" column against ["post"]; `negation=True` presumably
# inverts the is-in test, i.e. drops the "post" sessions -- confirm.
session_selector = ppd.DfIsInFilter("stage", ["post"], negation=True)

# Filter sessions, keep only the gestational-week column, convert it to a dict.
predictor_selector = (
    session_selector
    + ppd.ColumnsSelector("gestWeek")
    + ppd.SeriesToDict()
)
[7]:
# Predictor values ("gestWeek") as a dict, presumably keyed by the dataframe
# index (session) -- confirm against `SeriesToDict`.
x_dict = predictor_selector(df)

Create and fit regressor

Following How to perform dimensionality reduction on a mesh?, we create a pipeline that maps the output variable (the meshes) to a low-dimensional PCA representation.

[8]:
pca = PCA(n_components=4)

# Output-side transform: meshes -> vertex arrays -> one flat row per sample
# -> centered features -> 4 PCA scores. The "Invertible*" steps presumably let
# predicted scores be mapped back to meshes.
objs2y = AdapterPipeline(
    steps=[
        # Undo the extra 2D wrapping sklearn puts around the targets.
        FunctionTransformer(func=np.squeeze),
        InvertibleMeshesToVertices(index=0),
        # Stack the per-mesh vertex arrays into a single array.
        FunctionTransformer(func=np.stack),
        # Flatten every axis but the first: one row per mesh.
        InvertibleFlattenButFirst(),
        # Center only (no variance scaling) before PCA.
        StandardScaler(with_std=False),
        pca,
    ],
)

We get the data in the proper format for fitting and instantiate a regressor model.

[9]:
# Merge the predictor dict with the mesh dict, swap the nesting to get
# (predictors, meshes), then turn the predictors into an (n_samples, 1) column.
dataset_pipe = DictMerger() + NestingSwapper() + IndexMap(
    lambda x: np.expand_dims(np.array(x), axis=1), index=0
)

X, meshes_ = dataset_pipe((x_dict, meshes))

X.shape, len(meshes_)
[9]:
((19, 1), 19)
[10]:
# Linear regression on the output of `objs2y` (the PCA pipeline above).
model = ObjectRegressor(
    LinearRegression(fit_intercept=True),
    objs2y=objs2y,
)
[11]:
# Fit gestational week -> meshes; the meshes presumably go through `objs2y`
# before reaching the linear model. The fitted estimator is the cell output.
model.fit(X, meshes_)
[11]:
ObjectRegressor()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

Visualize predictions

[12]:
# Ten evenly spaced gestational weeks from -3 to 42, as a column vector
# because the regressor takes 2D inputs.
X_pred = np.linspace(start=-3, stop=42, num=10).reshape(-1, 1)

meshes_pred = model.predict(X_pred)
[13]:
# Render one GIF frame per predicted mesh, titled with its gestational week.
outputs_dir = Path("_images")
outputs_dir.mkdir(exist_ok=True)

gif_name = outputs_dir / "regression.gif"

pl = pv.Plotter(
    border=False,
    off_screen=True,
    notebook=False,
)

try:
    pl.open_gif(gif_name.as_posix(), fps=3)

    # A single actor is added once; its points are overwritten for each frame
    # instead of adding a new mesh per prediction.
    rendered_mesh = meshes_pred[0].copy()
    pl.add_mesh(rendered_mesh, show_edges=True)

    for gest_week, mesh in zip(X_pred[:, 0], meshes_pred):
        rendered_mesh.points = mesh.points
        pl.add_title(f"{gest_week:.0f}")
        # Render only after both the geometry and the title are updated, so
        # the captured frame never shows a stale title.
        pl.render()
        pl.write_frame()
finally:
    # Release the plotter even if a frame fails; closing presumably also
    # finalises the GIF writer opened by `open_gif` -- confirm in pyvista docs.
    pl.close()
[14]:
# Embed the GIF bytes in the notebook. `Path.read_bytes` closes the file
# handle, unlike the bare `open(...).read()` it replaces.
Image(gif_name.read_bytes())
[14]:
../../../_images/_notebooks_how_to_all_mesh_valued_regression_21_0.png

Let’s check the predicted volumes.

[15]:
# Volume of each predicted mesh, plotted against the gestational week it was
# predicted for.
volumes = [mesh.volume for mesh in meshes_pred]

fig, ax = plt.subplots()
ax.scatter(X_pred[:, 0], volumes)
ax.set(
    xlabel="Gestational week",
    ylabel="Volume",
    title="Predicted mesh volume vs. gestational week",
);
../../../_images/_notebooks_how_to_all_mesh_valued_regression_23_0.png