How to do mesh-valued regression?#

[ ]:
import numpy as np
import pyvista as pv
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.preprocessing import FunctionTransformer, StandardScaler

import polpo.preprocessing.dict as ppdict
import polpo.preprocessing.pd as ppd
from polpo.models import ObjectRegressor
from polpo.plot.pyvista import RegisteredMeshesGifPlotter
from polpo.preprocessing import Map
from polpo.preprocessing.learning import DictsToXY
from polpo.preprocessing.load.pregnancy.jacobs import MeshLoader, TabularDataLoader
from polpo.preprocessing.mesh.conversion import ToVertices
from polpo.preprocessing.mesh.registration import RigidAlignment
from polpo.sklearn.adapter import AdapterPipeline
from polpo.sklearn.mesh import BiMeshesToVertices
from polpo.sklearn.np import BiFlattenButFirst, FlattenButFirst
[ ]:
STATIC_VIZ = True

if STATIC_VIZ:
    pv.set_jupyter_backend("static")

Loading meshes#

[ ]:
subject_id = "01"

file_finder = MeshLoader(
    subject_subset=[subject_id],
    struct_subset=["L_Hipp"],
    as_mesh=True,
)

prep_pipe = RigidAlignment(max_iterations=500)

pipe = file_finder + ppdict.ExtractUniqueKey(nested=True) + prep_pipe


meshes = pipe()

Loading tabular data#

[ ]:
subject_id = "01"

pipe = TabularDataLoader(subject_id=subject_id)

df = pipe()

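Here, we filter the tabular data: sessions whose stage is post are dropped (negate=True), and gestWeek (gestational week) is kept as the only predictor.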

[ ]:
session_selector = ppd.DfIsInFilter("stage", ["post"], negate=True)

predictor_selector = (
    session_selector + ppd.ColumnsSelector("gestWeek") + ppd.SeriesToDict()
)
[ ]:
x_dict = predictor_selector(df)
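For readers more used to plain pandas, the selector above can be read roughly as the sketch below. It runs on a made-up toy table, and it assumes (this is not confirmed by the notebook) that the dataframe index holds the session identifiers that end up as dictionary keys.

```python
import pandas as pd

# Toy stand-in for the tabular data. The column names mirror the ones used
# above; the values and the index are made up for illustration.
df_toy = pd.DataFrame(
    {
        "stage": ["pre", "first", "second", "post"],
        "gestWeek": [-2.0, 9.0, 22.0, 43.0],
    },
    index=[1, 2, 3, 4],  # hypothetical session ids
)

# Keep the rows whose stage is NOT in the list (the negate=True behaviour).
kept = df_toy[~df_toy["stage"].isin(["post"])]

# Select one column, then turn the resulting Series into a dict keyed by index.
x_dict_toy = kept["gestWeek"].to_dict()
print(x_dict_toy)
```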

Merge data#

We put the data in the format expected for fitting: a predictor array X and a sequence of meshes aligned with its rows.

[ ]:
dataset_pipe = DictsToXY()

X, meshes_ = dataset_pipe((x_dict, meshes))

X.shape, len(meshes_)
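As a rough mental model of this merge step, the sketch below pairs two dictionaries on their shared keys. The helper `dicts_to_xy` is hypothetical, and the choice to keep only keys present in both dictionaries is an assumption about the behaviour, not a description of the DictsToXY implementation.

```python
import numpy as np


def dicts_to_xy(x_dict, y_dict):
    """Pair two dicts on their shared keys (hypothetical helper)."""
    # Assumption: only keys present in both dicts are kept, in sorted order.
    shared = sorted(set(x_dict) & set(y_dict))
    # One column per predictor, so that X[:, 0] works as in the notebook.
    X = np.array([x_dict[key] for key in shared], dtype=float).reshape(-1, 1)
    y = [y_dict[key] for key in shared]
    return X, y


# Toy inputs: strings stand in for mesh objects.
x_toy = {1: -2.0, 2: 9.0, 3: 22.0}
meshes_toy = {2: "mesh_2", 3: "mesh_3", 4: "mesh_4"}

X_toy, y_toy = dicts_to_xy(x_toy, meshes_toy)
print(X_toy.shape, len(y_toy))
```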

Create and fit regressor#

Following How to perform dimensionality reduction on a mesh?, we create a pipeline that transforms the output variable (the meshes) into four PCA components.

[ ]:
pca = PCA(n_components=4)

objs2y = AdapterPipeline(
    steps=[
        BiMeshesToVertices(index=0),
        FunctionTransformer(func=np.stack),
        BiFlattenButFirst(),
        StandardScaler(with_std=False),
        pca,
    ],
)

Tip: polpo.models.Meshes2FlatVertices is syntactic sugar for the code above.
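To make the steps of this pipeline concrete, here is a NumPy-only sketch of the same sequence (stack, flatten, center, project on principal components) applied to random arrays standing in for vertex coordinates. PCA is computed through an SVD here; this is an illustration of the idea under those assumptions, not the polpo or scikit-learn code path.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for registered meshes: 6 "meshes" with 50 vertices in 3D.
# Registered meshes share vertex count and ordering, so they can be stacked.
vertices = [rng.normal(size=(50, 3)) for _ in range(6)]

stacked = np.stack(vertices)               # shape (6, 50, 3)
flat = stacked.reshape(len(stacked), -1)   # shape (6, 150): flatten all but first axis

# Centering only, mirroring StandardScaler(with_std=False).
mean = flat.mean(axis=0)
centered = flat - mean

# PCA via SVD: the rows of vt are the principal directions.
n_components = 4
_, _, vt = np.linalg.svd(centered, full_matrices=False)
components = vt[:n_components]             # shape (4, 150)
y = centered @ components.T                # shape (6, 4): what the regressor sees

# The inverse map goes from component scores back to vertex arrays.
approx = (y @ components + mean).reshape(stacked.shape)
print(y.shape, approx.shape)
```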

[ ]:
model = ObjectRegressor(LinearRegression(fit_intercept=True), objs2y=objs2y)
[ ]:
model.fit(X, meshes_)
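Conceptually, fitting amounts to regressing the PCA scores on the predictor and mapping predictions back to vertex space. The sketch below shows that idea with NumPy on synthetic shapes that drift linearly with the predictor; all names are illustrative, and it assumes the regressor is fit on the transformed targets and the transform is inverted at prediction time.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_vertices = 12, 30

# Synthetic data: every vertex moves linearly with x, plus a little noise.
x = np.linspace(0.0, 40.0, n_samples)[:, None]
base = rng.normal(size=(n_vertices, 3))
direction = rng.normal(size=(n_vertices, 3))
shapes = base + x[:, :, None] * direction
shapes = shapes + 0.01 * rng.normal(size=shapes.shape)

# Output transform: flatten, center, project on 4 principal components.
flat = shapes.reshape(n_samples, -1)
mean = flat.mean(axis=0)
_, _, vt = np.linalg.svd(flat - mean, full_matrices=False)
components = vt[:4]
y = (flat - mean) @ components.T

# Ordinary least squares with an intercept, one column of coef per component.
design = np.hstack([np.ones_like(x), x])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)

# Predict in component space, then map back to vertex coordinates.
x_new = np.array([[30.0]])
y_new = np.hstack([np.ones_like(x_new), x_new]) @ coef
shape_new = (y_new @ components + mean).reshape(-1, n_vertices, 3)
print(shape_new.shape)
```

Because the synthetic shapes are linear in x, the prediction at x = 30 lands close to `base + 30 * direction`; real meshes need not behave this way.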

Evaluate fit#

model.predict outputs meshes, but LinearRegression only sees PCA components. We can therefore compute r2_score in component space by applying objs2y.transform to both the true and the predicted meshes.

NB: these are values on the training data.

Tip: objs2y is available in model.objs2y.

[ ]:
meshes_pred = model.predict(X)

y_true = objs2y.transform(meshes_)
y_pred = objs2y.transform(meshes_pred)

r2_score(y_true, y_pred, multioutput="raw_values")
[ ]:
r2_score(y_true, y_pred, multioutput="uniform_average")

This shows the model performs poorly. (NB: the goal of this notebook is not to find a great model, but to show how the analysis can be performed. Adapting the pipeline to use a different model is straightforward.)
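To read the two scores above, it helps to recall how they relate: the per-output score is 1 minus the ratio of residual to total sum of squares for each column, and the uniform average is the plain mean of those values. The following NumPy sketch, on made-up numbers, shows how a single output that is unrelated to the prediction pulls the average down.

```python
import numpy as np


def r2_per_output(y_true, y_pred):
    """Coefficient of determination computed column by column."""
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot


# Made-up targets: the first output is predicted perfectly, the second one
# is predicted by its mean only.
y_true_toy = np.array([[1.0, 0.0], [2.0, 1.0], [3.0, 0.0], [4.0, 1.0]])
y_pred_toy = np.array([[1.0, 0.5], [2.0, 0.5], [3.0, 0.5], [4.0, 0.5]])

raw = r2_per_output(y_true_toy, y_pred_toy)  # analogue of multioutput="raw_values"
avg = raw.mean()                             # analogue of multioutput="uniform_average"
print(raw, avg)
```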

The analysis can also be done at the mesh level. The following computes the score directly on the flattened vertex coordinates, which assumes Euclidean distance.

[ ]:
meshes2flatvertices = Map(ToVertices()) + np.stack + FlattenButFirst()

r2_score(
    meshes2flatvertices(meshes_),
    meshes2flatvertices(meshes_pred),
    multioutput="uniform_average",
)
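A complementary way to look at the mesh-level fit under the same Euclidean assumption is the distance between corresponding vertices. This is only a sketch on tiny made-up arrays of shape (n_meshes, n_vertices, 3); it relies on the meshes being registered, so that vertex i in one mesh corresponds to vertex i in another.

```python
import numpy as np

# Two toy "meshes" with two vertices each.
verts_true = np.array(
    [
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
        [[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]],
    ]
)
# Toy predictions: every vertex is shifted by 0.5 along z.
verts_pred = verts_true + np.array([0.0, 0.0, 0.5])

# Euclidean distance between corresponding vertices, shape (n_meshes, n_vertices).
dist = np.linalg.norm(verts_true - verts_pred, axis=-1)

# Mean vertex error per mesh.
per_mesh = dist.mean(axis=1)
print(per_mesh)
```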

To build understanding, let’s plot the data the model actually “sees”.

[ ]:
_, axes = plt.subplots(2, 2, sharex=True)

for index in range(4):
    ax = axes[index // 2][index % 2]
    ax.scatter(X[:, 0], y_true[:, index])
    ax.set_title(f"Comp {index}", fontsize=10)
    if index > 1:
        ax.set_xlabel("Gestational week")

plt.tight_layout()

Visualize predictions#

[ ]:
X_pred = np.linspace(-3, 42, num=10)[:, None]

meshes_pred = model.predict(X_pred)
[ ]:
pl = RegisteredMeshesGifPlotter(fps=3)

pl.add_meshes({int(x): mesh for x, mesh in zip(X_pred[:, 0], meshes_pred)})
pl.close()

pl.show()

Let’s check the predicted volumes.

[ ]:
volumes = [mesh.volume for mesh in meshes_pred]

plt.scatter(X_pred, volumes)
plt.xlabel("Gestational week")
plt.ylabel("Volume");
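For intuition about what a mesh volume is, the sketch below applies a standard formula for closed, consistently oriented triangle meshes: sum the signed volumes of the tetrahedra formed by each face and the origin. The function name and the test shape are illustrative, and this is not claimed to be how PyVista computes mesh.volume.

```python
import numpy as np


def closed_mesh_volume(points, faces):
    """Volume of a closed triangle mesh with consistently oriented faces."""
    p0 = points[faces[:, 0]]
    p1 = points[faces[:, 1]]
    p2 = points[faces[:, 2]]
    # Signed volume of the tetrahedron (origin, p0, p1, p2) for each face.
    signed = np.einsum("ij,ij->i", p0, np.cross(p1, p2)) / 6.0
    return abs(signed.sum())


# Right tetrahedron with unit legs; its volume is 1/6.
points = np.array(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
)
# Faces listed with outward-pointing normals.
faces = np.array([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]])

volume = closed_mesh_volume(points, faces)
volume_scaled = closed_mesh_volume(2.0 * points, faces)
print(volume, volume_scaled)
```

Doubling every coordinate multiplies the volume by eight, which is a quick sanity check when comparing predicted meshes.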

Further reading#