# V4 Digital Twins

## Set Up + Imports

In [1]:
# Project bootstrap. Per the output recorded below this cell, `setup.main()`
# reports the working directory and the directories it adds to the import path,
# so it has to run before the project-specific imports further down.
import setup

setup.main()
# Reload edited modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Load the jupyter_black extension (presumably auto-formats code cells with Black).
%load_ext jupyter_black

import torch
import os
# Model constructor used below to build each ensemble member.
from nnvision.models.ptrmodels import task_core_gauss_readout
# Container used below to combine the individual networks into one ensemble.
from mei.modules import EnsembleModel

Working directory:  /home/facosta/neurometry/neurometry
Directory added to path:  /home/facosta/neurometry
Directory added to path:  /home/facosta/neurometry/neurometry


2024-09-04 17:05:46.580091: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Load model

In [2]:
# Directory holding the pretrained digital-twin checkpoints. It is resolved
# against the current working directory, so the notebook must be run from the
# directory that contains `datasets/`.
weights_path = os.path.join(os.getcwd(), "datasets/digital_twins/weights")

In [3]:
def load_v4_model_from_weights(base_dir, n_models=5):
    """Build the task-driven V4 ensemble and load its pretrained weights.

    One ``task_core_gauss_readout`` network is created per checkpoint, its
    weights are restored from
    ``<base_dir>/task_driven_ensemble_model_XX.pth.tar`` (``XX`` = 01, 02, ...),
    and all networks are wrapped in a single ``EnsembleModel``.

    Parameters
    ----------
    base_dir : str or os.PathLike
        Directory containing the ``*.pth.tar`` checkpoint files.
    n_models : int, optional
        Number of ensemble members to load, starting at checkpoint 01.
        Defaults to 5 (checkpoints 01-05).

    Returns
    -------
    EnsembleModel
        Ensemble wrapping the loaded networks, in checkpoint order.

    Raises
    ------
    ValueError
        If ``n_models`` is smaller than 1.
    FileNotFoundError
        Propagated from ``torch.load`` when a checkpoint file is missing.
    """
    if n_models < 1:
        raise ValueError(f"n_models must be at least 1, got {n_models}")

    model_fn = task_core_gauss_readout

    # Hyperparameters handed to the model constructor. They have to describe the
    # same architecture the checkpoints were saved from; otherwise
    # `load_state_dict` (strict by default) rejects the state dict.
    model_config = {
        "input_channels": 1,
        "model_name": "resnet50_l2_eps0_1",
        "layer_name": "layer3.0",
        "pretrained": False,
        "bias": False,
        "final_batchnorm": True,
        "final_nonlinearity": True,
        "momentum": 0.1,
        "fine_tune": False,
        "init_mu_range": 0.4,
        "init_sigma_range": 0.6,
        "readout_bias": True,
        "gamma_readout": 3.0,
        "gauss_type": "isotropic",
        "elu_offset": -1,
    }

    # Dataset description passed to the constructor in place of real
    # dataloaders (`dataloaders=None` below).
    data_info = {
        "all_sessions": {
            "input_dimensions": torch.Size([64, 1, 100, 100]),
            "input_channels": 1,
            "output_dimension": 1244,
            "img_mean": 124.54466,
            "img_std": 70.28,
        },
    }

    # Checkpoint file names: task_driven_ensemble_model_01 ... _<n_models>.
    ensemble_names = [
        f"task_driven_ensemble_model_{index:02d}.pth.tar"
        for index in range(1, n_models + 1)
    ]

    ensemble_models = []

    for f in ensemble_names:
        ensemble_filename = os.path.join(base_dir, f)
        # NOTE: `torch.load` unpickles the file, so only load checkpoints from a
        # trusted source. `map_location="cpu"` lets checkpoints saved from a GPU
        # be read on CPU-only machines; `load_state_dict` then copies the values
        # into the freshly built model's own parameters.
        ensemble_state_dict = torch.load(ensemble_filename, map_location="cpu")
        ensemble_model = model_fn(
            seed=0,
            dataloaders=None,
            **model_config,
            data_info=data_info,
        )
        ensemble_model.load_state_dict(ensemble_state_dict)
        ensemble_models.append(ensemble_model)

    task_driven_ensemble = EnsembleModel(*ensemble_models)
    return task_driven_ensemble

In [4]:
# Build the ensemble from the checkpoints in `weights_path` and print its module
# structure as a quick sanity check that the weights loaded without error.
model = load_v4_model_from_weights(weights_path)
print(model)

EnsembleModel(EncoderShifter(
  (core): TaskDrivenCore3(
    (features): Sequential(
      (TaskDriven): Sequential(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): BatchNorm2d(256, eps=1e-0