# How to do mesh-valued regression?

[1]:
from pathlib import Path

import numpy as np
import pyvista as pv
from IPython.display import Image
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.preprocessing import FunctionTransformer, StandardScaler

import polpo.preprocessing.dict as ppdict
import polpo.preprocessing.pd as ppd
from polpo.models import ObjectRegressor
from polpo.preprocessing import (
    Map,
    PartiallyInitializedStep,
)
from polpo.preprocessing.learning import DictsToXY
from polpo.preprocessing.load.pregnancy import (
    DenseMaternalCsvDataLoader,
    DenseMaternalMeshLoader,
)
from polpo.preprocessing.mesh.conversion import ToVertices
from polpo.preprocessing.mesh.io import PvReader
from polpo.preprocessing.mesh.registration import PvAlign
from polpo.sklearn.adapter import AdapterPipeline
from polpo.sklearn.mesh import BiMeshesToVertices
from polpo.sklearn.np import BiFlattenButFirst, FlattenButFirst
[KeOps] Warning : cuda was detected, but driver API could not be initialized. Switching to cpu only.
[2]:
# When True, PyVista renders static images instead of interactive widgets
# (handy when the notebook is exported or viewed without a live kernel).
STATIC_VIZ: bool = True

if STATIC_VIZ:
    pv.set_jupyter_backend("static")

## Loading meshes

[3]:
def first_mesh(meshes):
    """Return the first mesh of an insertion-ordered dict of meshes.

    Used as the alignment target; avoids materialising the whole key list
    just to read its first entry.
    """
    return next(iter(meshes.values()))


# Align every mesh of the dict (via PvAlign) to the first one.
# `_target` is a callable of the incoming mesh dict — presumably resolved by
# PartiallyInitializedStep when the pipeline runs (TODO confirm).
prep_pipe = PartiallyInitializedStep(
    Step=lambda **kwargs: ppdict.DictMap(PvAlign(**kwargs)),
    _target=first_mesh,
    max_iterations=500,  # iteration cap forwarded to PvAlign
)
[4]:
subject_id = "01"

# Find this subject's left-side "Hipp" mesh files (presumably hippocampus),
# returned as a dict, read each one with PyVista and align them with prep_pipe.
file_finder = DenseMaternalMeshLoader(
    subject_id=subject_id,
    as_dict=True,
    left=True,
    struct="Hipp",
)

# Dedicated name: the generic `pipe` is reused for an unrelated CSV pipeline.
mesh_pipe = file_finder + ppdict.DictMap(PvReader()) + prep_pipe

meshes = mesh_pipe()

## Loading tabular data

[5]:
# NOTE(review): subject "01" is treated as the pilot subject, whose tabular data
# is loaded with pilot=True — confirm against DenseMaternalCsvDataLoader.
pilot = subject_id == "01"

# Dedicated name instead of reusing the generic `pipe` from the mesh cell.
csv_pipe = DenseMaternalCsvDataLoader(pilot=pilot, subject_id=subject_id)

df = csv_pipe()
INFO: Data has already been downloaded... using cached file ('/home/luisfpereira/.herbrain/data/maternal/28Baby_Hormones.csv').

Here, we filter the tabular data: we drop the sessions whose stage is "post" and keep the gestational week as the predictor.

[6]:
# Keep only rows whose "stage" is NOT "post" (negate=True inverts the
# membership test).
session_selector = ppd.DfIsInFilter("stage", ["post"], negate=True)

# Filtered rows -> "gestWeek" column -> plain dict.
predictor_selector = (
    session_selector
    + ppd.ColumnsSelector("gestWeek")
    + ppd.SeriesToDict()
)
[7]:
# Gestational week per retained session, as a dict (presumably keyed like the
# mesh dict so that the two can be paired below — TODO confirm).
x_dict = predictor_selector(df)

## Merge data

We put the data in the proper format for fitting: a predictor array `X` and the matching sequence of meshes.

[8]:
# Turn the (predictor dict, mesh dict) pair into a predictor array `X` and the
# corresponding meshes (presumably matched on shared keys — TODO confirm).
dataset_pipe = DictsToXY()

X, meshes_ = dataset_pipe((x_dict, meshes))

# Invariant: one row of predictors per mesh; fail loudly if pairing went wrong.
assert X.shape[0] == len(meshes_), "predictors and meshes are misaligned"

X.shape, len(meshes_)
[8]:
((19, 1), 19)

## Create and fit the regressor

Following *How to perform dimensionality reduction on a mesh?*, we create a pipeline that transforms the output variable (the meshes) into a low-dimensional PCA representation.

[9]:
pca = PCA(n_components=4)

# Output transform: meshes -> principal-component scores. The "Bi*" steps are
# presumably invertible so predictions can be mapped back to meshes (TODO confirm).
objs2y = AdapterPipeline(
    steps=[
        BiMeshesToVertices(index=0),         # meshes -> vertex arrays
        FunctionTransformer(func=np.stack),  # sequence of arrays -> one stacked array
        BiFlattenButFirst(),                 # flatten all axes but the first
        StandardScaler(with_std=False),      # centre only, no variance scaling
        pca,                                 # keep the leading components
    ],
)

Tip: `polpo.models.Meshes2FlatVertices` is syntactic sugar for the code above.

[10]:
# Linear model on the PCA scores; objs2y maps meshes to those scores, and
# model.predict returns meshes.
model = ObjectRegressor(
    LinearRegression(fit_intercept=True),
    objs2y=objs2y,
)
[11]:
# Fit on the paired data. objs2y is presumably fitted as part of this call,
# since objs2y.transform is used directly afterwards. Displays the estimator.
model.fit(X, meshes_)
[11]:
ObjectRegressor(objs2y=AdapterPipeline(steps=[('step_0', BiMeshesToVertices()),
                                              ('step_1',
                                               FunctionTransformer(func=<function stack at 0x7d97b8240d30>)),
                                              ('step_2', BiFlattenButFirst()),
                                              ('step_3',
                                               StandardScaler(with_std=False)),
                                              ('step_4', PCA(n_components=4))]))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

## Evaluate the fit

`model.predict` outputs meshes, but `LinearRegression` operates on PCA components. We can therefore evaluate `r2_score` in PCA space by applying `objs2y.transform` to both the observed and the predicted meshes.

NB: these scores are computed on the training data (in-sample).

Tip: objs2y is available in model.objs2y.

[12]:
meshes_pred = model.predict(X)

# Score in the space the regressor works in: project the observed meshes and
# then the predicted meshes through the same objs2y pipeline.
y_true, y_pred = (
    objs2y.transform(mesh_list) for mesh_list in (meshes_, meshes_pred)
)

r2_score(y_true, y_pred, multioutput="raw_values")
[12]:
array([0.04356306, 0.00644144, 0.24583469, 0.00145244])
[13]:
# One summary number: the unweighted mean of the per-component scores above.
r2_score(y_true=y_true, y_pred=y_pred, multioutput="uniform_average")
[13]:
0.07432290901629174

The low $R^2$ values show that the model fits poorly. (NB: the goal of this notebook is not to find a great model, but to show how the analysis can be performed. Adapting the pipeline to use a different model is straightforward.)

The analysis can also be done at the mesh level, by computing $R^2$ directly on the flattened vertex coordinates (which implicitly assumes a Euclidean distance between meshes).

[14]:
meshes2flatvertices = Map(ToVertices()) + np.stack + FlattenButFirst()

# Every vertex coordinate is treated as one output; observed meshes are
# converted first, predicted meshes second.
r2_score(
    *(meshes2flatvertices(mesh_list) for mesh_list in (meshes_, meshes_pred)),
    multioutput="uniform_average",
)
[14]:
0.0484419438715457

To build intuition, let’s plot the PCA components against gestational week, i.e. the data the model actually “sees”.

[15]:
# One scatter panel per PCA component. The grid is derived from y_true instead
# of hard-coding 4 components / a 2x2 layout, so changing n_components works.
n_comps = y_true.shape[1]
n_cols = 2
n_rows = -(-n_comps // n_cols)  # ceiling division

# squeeze=False keeps `axes` 2-D even for a single row.
fig, axes = plt.subplots(n_rows, n_cols, sharex=True, squeeze=False)

for index, ax in enumerate(axes.flat):
    if index >= n_comps:
        ax.set_visible(False)  # unused cell of an incomplete last row
        continue
    ax.scatter(X[:, 0], y_true[:, index])
    ax.set_title(f"Comp {index}", fontsize=10)
    # Label only the lowest visible panel of each column.
    if index + n_cols >= n_comps:
        ax.set_xlabel("Gestational week")

fig.tight_layout()
../../../_images/_notebooks_how_to_maternal_mesh_valued_regression_29_0.png

## Visualize predictions

[16]:
# Ten evenly spaced gestational weeks from -3 to 42, as a single-feature column.
X_pred = np.linspace(-3, 42, num=10).reshape(-1, 1)

# NOTE: overwrites the in-sample `meshes_pred` computed in the evaluation section.
meshes_pred = model.predict(X_pred)
[17]:
outputs_dir = Path("_images")
# exist_ok replaces the check-then-create pattern and is safe on re-runs.
outputs_dir.mkdir(parents=True, exist_ok=True)

gif_name = outputs_dir / "regression.gif"

pl = pv.Plotter(
    border=False,
    off_screen=True,
    notebook=False,
)

pl.open_gif(gif_name.as_posix(), fps=3)

try:
    # Add the first predicted mesh once; for later frames only its points are
    # replaced in place, so just the geometry and the title change per frame.
    rendered_mesh = meshes_pred[0].copy()
    pl.add_mesh(rendered_mesh, show_edges=True)

    for frame_index, (week, mesh) in enumerate(zip(X_pred[:, 0], meshes_pred)):
        if frame_index > 0:
            rendered_mesh.points = mesh.points
            pl.render()
        # The title shows the predictor value (gestational week) of this frame.
        pl.add_title(f"{week:.0f}")
        pl.write_frame()
finally:
    # Always release the plotter and its GIF writer, even if a frame fails.
    pl.close()
[18]:
# read_bytes() opens, reads and closes the file; open(...).read() leaked the handle.
Image(gif_name.read_bytes())
[18]:
../../../_images/_notebooks_how_to_maternal_mesh_valued_regression_33_0.png

Let’s check the predicted volumes.

[19]:
volumes = [mesh.volume for mesh in meshes_pred]

# Explicit figure/axes instead of the implicit pyplot state; X_pred is a
# single-feature column, so plot its only column.
fig, ax = plt.subplots()
ax.scatter(X_pred[:, 0], volumes)
ax.set(
    title="Predicted volume vs. gestational week",
    xlabel="Gestational week",
    ylabel="Volume",
);
../../../_images/_notebooks_how_to_maternal_mesh_valued_regression_35_0.png

## Further reading