From c1602b6045b85d3b803b8f4d0b7f667e5c23a522 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Sat, 15 Jul 2023 12:28:53 +0200
Subject: [PATCH 01/68] changed pring to print and fixed the printing message of
 PlnPCA.

---
 pyPLNmodels/models.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 9c2024ee..b751e9e4 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -365,7 +365,7 @@ class _model(ABC):
         verbose : bool, optional(keyword-only)
             Whether to print training progress. Defaults to False.
         """
-        self._pring_beginning_message()
+        self._print_beginning_message()
         self._beginning_time = time.time()
 
         if self._fitted is False:
@@ -1668,7 +1668,7 @@ class Pln(_model):
         covariances = components_var @ (sk_components.T.unsqueeze(0))
         return covariances
 
-    def _pring_beginning_message(self):
+    def _print_beginning_message(self):
         """
         Method for printing the beginning message.
         """
@@ -2087,7 +2087,7 @@ class PlnPCAcollection:
         """
         return [model.rank for model in self.values()]
 
-    def _pring_beginning_message(self) -> str:
+    def _print_beginning_message(self) -> str:
         """
         Method for printing the beginning message.
 
@@ -2150,7 +2150,7 @@ class PlnPCAcollection:
         verbose : bool, optional(keyword-only)
             Whether to print verbose output, by default False.
         """
-        self._pring_beginning_message()
+        self._print_beginning_message()
         for i in range(len(self.values())):
             model = self[self.ranks[i]]
             model.fit(
@@ -2912,12 +2912,12 @@ class PlnPCA(_model):
         """
         return self._rank
 
-    def _pring_beginning_message(self):
+    def _print_beginning_message(self):
         """
         Print the beginning message when fitted.
         """
         print("-" * NB_CHARACTERS_FOR_NICE_PLOT)
-        print(f"Fitting a PlnPCAcollection model with {self._rank} components")
+        print(f"Fitting a PlnPCA model with {self._rank} components")
 
     @property
     def model_parameters(self) -> Dict[str, torch.Tensor]:
-- 
GitLab


From 6bad44b89281edd50a92103bc1f3d5b14c156cd7 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 20 Jul 2023 10:14:46 +0200
Subject: [PATCH 02/68] changed initialization to compute only the leading principal
 components. The init of the components should be roughly 100x faster.

---
 pyPLNmodels/_initialization.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/pyPLNmodels/_initialization.py b/pyPLNmodels/_initialization.py
index e0c3f47e..72de90b3 100644
--- a/pyPLNmodels/_initialization.py
+++ b/pyPLNmodels/_initialization.py
@@ -2,6 +2,11 @@ import torch
 import math
 from typing import Optional
 from ._utils import _log_stirling
+import time
+from sklearn.decomposition import PCA
+import seaborn as sns
+import matplotlib.pyplot as plt
+import numpy as np
 
 if torch.cuda.is_available():
     DEVICE = torch.device("cuda")
@@ -41,7 +46,7 @@ def _init_covariance(
 
 
 def _init_components(
-    endog: torch.Tensor, exog: torch.Tensor, coef: torch.Tensor, rank: int
+    endog: torch.Tensor, exog: torch.Tensor, rank: int
 ) -> torch.Tensor:
     """
     Initialization for components for the Pln model. Get a first guess for covariance
@@ -51,12 +56,6 @@ def _init_components(
     ----------
     endog : torch.Tensor
         Samples with size (n,p)
-    offsets : torch.Tensor
-        Offset, size (n,p)
-    exog : torch.Tensor
-        Covariates, size (n,d)
-    coef : torch.Tensor
-        Coefficient of size (d,p)
     rank : int
         The dimension of the latent space, i.e. the reduced dimension.
 
@@ -65,9 +64,11 @@ def _init_components(
     torch.Tensor
         Initialization of components of size (p,rank)
     """
-    sigma_hat = _init_covariance(endog, exog, coef).detach()
-    components = _components_from_covariance(sigma_hat, rank)
-    return components
+    log_y = torch.log(endog + (endog == 0) * math.exp(-2))
+    pca = PCA(n_components=rank)
+    pca.fit(log_y)
+    pca_comp = pca.components_.T * np.sqrt(pca.explained_variance_)
+    return torch.from_numpy(pca_comp).to(DEVICE)
 
 
 def _init_latent_mean(
@@ -102,13 +103,14 @@ def _init_latent_mean(
         The learning rate of the optimizer. Default is 0.01.
     eps : float, optional
         The tolerance. The algorithm will stop as soon as the criterion is lower than the tolerance.
-        Default is 7e-3.
+        Default is 7e-1.
 
     Returns
     -------
     torch.Tensor
         The initialized latent mean with size (n,rank)
     """
+    t = time.time()
     mode = torch.randn(endog.shape[0], components.shape[1], device=DEVICE)
     mode.requires_grad_(True)
     optimizer = torch.optim.Rprop([mode], lr=lr)
@@ -127,6 +129,7 @@ def _init_latent_mean(
             keep_condition = False
         old_mode = torch.clone(mode)
         i += 1
+    print("time mean", time.time() - t)
     return mode
 
 
-- 
GitLab
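
The new initialization amounts to a PCA of the log-counts, so only the `rank` leading components are computed instead of a full covariance eigendecomposition. A minimal standalone sketch, assuming `endog` is an (n, p) count tensor (the helper name is illustrative):

    import math
    import numpy as np
    import torch
    from sklearn.decomposition import PCA

    def init_components_by_pca(endog: torch.Tensor, rank: int) -> torch.Tensor:
        # Replace zeros by exp(-2) so the log is finite, then work on log-counts.
        log_y = torch.log(endog + (endog == 0) * math.exp(-2))
        pca = PCA(n_components=rank)  # only the `rank` leading components are fitted
        pca.fit(log_y.numpy())
        # Scale each principal axis by the standard deviation it explains.
        comp = pca.components_.T * np.sqrt(pca.explained_variance_)
        return torch.from_numpy(comp)  # shape (p, rank)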


From ad28b4c6efd021ef24f706f3911f93d3f644a892 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 20 Jul 2023 10:17:37 +0200
Subject: [PATCH 03/68] GPU support

---
 pyPLNmodels/_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 805c9dca..681b5c23 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -306,12 +306,12 @@ def _format_model_param(
     exog = _format_data(exog)
     if add_const is True:
         if exog is None:
-            exog = torch.ones(endog.shape[0], 1)
+            exog = torch.ones(endog.shape[0], 1).to(DEVICE)
         else:
             if _has_null_variance(exog) is False:
                 exog = torch.concat(
                     (exog, torch.ones(endog.shape[0]).unsqueeze(1)), dim=1
-                )
+                ).to(DEVICE)
     if offsets is None:
         if offsets_formula == "logsum":
             print("Setting the offsets as the log of the sum of endog")
-- 
GitLab


From 48cb139972fdc82dcafd61b0c17091151c5c3542 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 20 Jul 2023 10:18:40 +0200
Subject: [PATCH 04/68] changed the arguments of _init_components

---
 pyPLNmodels/models.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index b751e9e4..6b47bf78 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -205,7 +205,7 @@ class _model(ABC):
         if self._get_max_components() < 2:
             raise RuntimeError("Can't perform visualization for dim < 2.")
         pca = self.sk_PCA(n_components=2)
-        proj_variables = pca.transform(self.latent_variables)
+        proj_variables = pca.transform(self.latent_variables.detach().cpu())
         x = proj_variables[:, 0]
         y = proj_variables[:, 1]
         sns.scatterplot(x=x, y=y, hue=colors, ax=ax)
@@ -380,7 +380,7 @@ class _model(ABC):
             criterion = self._compute_criterion_and_update_plotargs(loss, tol)
             if abs(criterion) < tol:
                 stop_condition = True
-            if verbose and self.nb_iteration_done % 50 == 0:
+            if verbose and self.nb_iteration_done % 50 == 1:
                 self._print_stats()
         self._print_end_of_fitting_message(stop_condition, tol)
         self._fitted = True
@@ -2938,9 +2938,7 @@ class PlnPCA(_model):
         if not hasattr(self, "_coef"):
             super()._smart_init_coef()
         if not hasattr(self, "_components"):
-            self._components = _init_components(
-                self._endog, self._exog, self._coef, self._rank
-            )
+            self._components = _init_components(self._endog, self._exog, self._rank)
 
     def _random_init_model_parameters(self):
         """
-- 
GitLab


From 970466f26d5ad94ac5788dfcaed68f9e0251000f Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 20 Jul 2023 16:05:11 +0200
Subject: [PATCH 05/68] now we can fit using mini-batches of the data.

---
 pyPLNmodels/_utils.py |   2 +-
 pyPLNmodels/models.py | 156 +++++++++++++++++++++++++++++++++++++++---
 2 files changed, 149 insertions(+), 9 deletions(-)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 681b5c23..3f9c6441 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -761,7 +761,7 @@ def get_simulated_count_data(
             pln_param.covariance,
             pln_param.coef,
         )
-    return pln_param.endog, pln_param.cov, pln_param.offsets
+    return endog, pln_param.exog, pln_param.offsets
 
 
 def get_real_count_data(
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 6b47bf78..11987d85 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -76,6 +76,7 @@ class _model(ABC):
         dict_initialization: Optional[dict] = None,
         take_log_offsets: bool = False,
         add_const: bool = True,
+        batch_size: int = None,
     ):
         """
         Initializes the model class.
@@ -97,6 +98,13 @@ class _model(ABC):
             Whether to take the log of offsets. Defaults to False.
         add_const: bool, optional(keyword-only)
             Whether to add a column of one in the exog. Defaults to True.
+        batch_size: int, optional(keyword-only)
+            The batch size when optimizing the elbo. If None,
+            batch gradient descent will be performed (i.e. batch_size = n_samples).
+        Raises
+        ------
+        ValueError
+            If the batch_size is greater than the number of samples, or not int.
         """
         (
             self._endog,
@@ -107,6 +115,7 @@ class _model(ABC):
             endog, exog, offsets, offsets_formula, take_log_offsets, add_const
         )
         self._fitted = False
+        self._batch_size = self._handle_batch_size(batch_size)
         self._plotargs = _PlotArgs(self._WINDOW)
         if dict_initialization is not None:
             self._set_init_parameters(dict_initialization)
@@ -120,6 +129,7 @@ class _model(ABC):
         offsets_formula: str = "logsum",
         dict_initialization: Optional[dict] = None,
         take_log_offsets: bool = False,
+        batch_size: int = None,
     ):
         """
         Create a model instance from a formula and data.
@@ -137,6 +147,9 @@ class _model(ABC):
             The initialization dictionary. Defaults to None.
         take_log_offsets : bool, optional(keyword-only)
             Whether to take the log of offsets. Defaults to False.
+        batch_size: int, optional(keyword-only)
+            The batch size when optimizing the elbo. If None,
+            batch gradient descent will be performed (i.e. batch_size = n_samples).
         """
         endog, exog, offsets = _extract_data_from_formula(formula, data)
         return cls(
@@ -147,6 +160,7 @@ class _model(ABC):
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
             add_const=False,
+            batch_size=batch_size,
         )
 
     def _set_init_parameters(self, dict_initialization: dict):
@@ -166,6 +180,24 @@ class _model(ABC):
             setattr(self, key, array)
         self._fitted = True
 
+    @property
+    def batch_size(self) -> int:
+        """
+        The batch size of the model. Should not be greater than the number of samples.
+        """
+        return self._batch_size
+
+    @property
+    def _current_batch_size(self) -> int:
+        return self._exog_b.shape[0]
+
+    @batch_size.setter
+    def batch_size(self, batch_size: int):
+        """
+        Setter for the batch size. Should be an integer not greater than the number of samples.
+        """
+        self._batch_size = self._handle_batch_size(batch_size)
+
     @property
     def fitted(self) -> bool:
         """
@@ -216,6 +248,17 @@ class _model(ABC):
                 _plot_ellipse(x[i], y[i], cov=covariances[i], ax=ax)
         return ax
 
+    def _handle_batch_size(self, batch_size):
+        if batch_size is None:
+            batch_size = self.n_samples
+        if batch_size > self.n_samples:
+            raise ValueError(
+                f"batch_size ({batch_size}) can not be greater than the number of samples ({self.n_samples})"
+            )
+        elif isinstance(batch_size, int) is False:
+            raise ValueError(f"batch_size should be int, got {type(batch_size)}")
+        return batch_size
+
     @property
     def nb_iteration_done(self) -> int:
         """
@@ -385,21 +428,65 @@ class _model(ABC):
         self._print_end_of_fitting_message(stop_condition, tol)
         self._fitted = True
 
+    def _get_batch(self, batch_size, shuffle=False):
+        """Get the batches required to do a  minibatch gradient ascent.
+
+        Args:
+            batch_size: int. The batch size. Should be lower than n.
+
+        Returns: A generator. Will generate n//batch_size + 1 batches of
+            size batch_size (except the last one since the rest of the
+            division is not always 0)
+        """
+        indices = np.arange(self.n_samples)
+        if shuffle:
+            np.random.shuffle(indices)
+        nb_full_batch, last_batch_size = (
+            self.n_samples // batch_size,
+            self.n_samples % batch_size,
+        )
+        self.nb_batches = nb_full_batch + (last_batch_size > 0)
+        for i in range(nb_full_batch):
+            yield self._return_batch(indices, i * batch_size, (i + 1) * batch_size)
+        # Last batch
+        if last_batch_size != 0:
+            yield self._return_batch(indices, -last_batch_size, self.n_samples)
+
+    def _return_batch(self, indices, beginning, end):
+        return (
+            self._endog[indices[beginning:end]],
+            self._exog[beginning:end],
+            self._offsets[indices[beginning:end]],
+            self._latent_mean[beginning:end],
+            self._latent_sqrt_var[beginning:end],
+        )
+
     def _trainstep(self):
         """
-        Perform a single training step.
+        Perform a single pass of the data.
 
         Returns
         -------
         torch.Tensor
             The loss value.
         """
-        self.optim.zero_grad()
-        loss = -self.compute_elbo()
-        loss.backward()
-        self.optim.step()
-        self._update_closed_forms()
-        return loss
+        elbo = 0
+        for batch in self._get_batch(self._batch_size):
+            self._extract_batch(batch)
+            self.optim.zero_grad()
+            loss = -self._compute_elbo_b()
+            loss.backward()
+            elbo += loss.item()
+            self.optim.step()
+            self._update_closed_forms()
+        return elbo / self.nb_batches
+
+    def _extract_batch(self, batch):
+        self._endog_b = batch[0]
+        self._exog_b = batch[1]
+        self._offsets_b = batch[2]
+        self._latent_mean_b = batch[3]
+        self._latent_sqrt_var_b = batch[4]
 
     def transform(self):
         """
@@ -633,7 +720,7 @@ class _model(ABC):
         float
             The computed criterion.
         """
-        self._plotargs._elbos_list.append(-loss.item())
+        self._plotargs._elbos_list.append(-loss)
         self._plotargs.running_times.append(time.time() - self._beginning_time)
         if self._plotargs.iteration_number > self._WINDOW:
             criterion = abs(
@@ -1334,6 +1421,7 @@ class Pln(_model):
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
         add_const: bool = True,
+        batch_size: int = None,
     ):
         super().__init__(
             endog=endog,
@@ -1343,6 +1431,7 @@ class Pln(_model):
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
             add_const=add_const,
+            batch_size=batch_size,
         )
 
     @classmethod
@@ -1370,6 +1459,7 @@ class Pln(_model):
         offsets_formula: str = "logsum",
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
+        batch_size: int = None,
     ):
         endog, exog, offsets = _extract_data_from_formula(formula, data)
         return cls(
@@ -1380,6 +1470,7 @@ class Pln(_model):
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
             add_const=False,
+            batch_size=batch_size,
         )
 
     @_add_doc(
@@ -1619,6 +1710,23 @@ class Pln(_model):
             self._latent_sqrt_var,
         )
 
+    def _compute_elbo_b(self):
+        """
+        Method for computing the evidence lower bound (ELBO) on the current batch.
+
+        Returns
+        -------
+        torch.Tensor
+            The computed ELBO on the current batch.
+        """
+        return profiled_elbo_pln(
+            self._endog_b,
+            self._exog_b,
+            self._offsets_b,
+            self._latent_mean_b,
+            self._latent_sqrt_var_b,
+        )
+
     def _smart_init_model_parameters(self):
         """
         Method for smartly initializing the model parameters.
@@ -1779,6 +1887,7 @@ class PlnPCAcollection:
         dict_of_dict_initialization: Optional[dict] = None,
         take_log_offsets: bool = False,
         add_const: bool = True,
+        batch_size: int = None,
     ):
         """
         Constructor for PlnPCAcollection.
@@ -1801,6 +1910,9 @@ class PlnPCAcollection:
             Whether to take the logarithm of offsets, by default False.
         add_const: bool, optional(keyword-only)
             Whether to add a column of one in the exog. Defaults to True.
+        batch_size: int, optional(keyword-only)
+            The batch size when optimizing the elbo. If None,
+            batch gradient descent will be performed (i.e. batch_size = n_samples).
         Returns
         -------
         PlnPCAcollection
@@ -1831,6 +1943,7 @@ class PlnPCAcollection:
         ranks: Iterable[int] = range(3, 5),
         dict_of_dict_initialization: Optional[dict] = None,
         take_log_offsets: bool = False,
+        batch_size: int = None,
     ) -> "PlnPCAcollection":
         """
         Create an instance of PlnPCAcollection from a formula.
@@ -1851,6 +1964,10 @@ class PlnPCAcollection:
             The dictionary of initialization, by default None.
         take_log_offsets : bool, optional(keyword-only)
             Whether to take the logarithm of offsets, by default False.
+        batch_size: int, optional(keyword-only)
+            The batch size when optimizing the elbo. If None,
+            batch gradient descent will be performed (i.e. batch_size = n_samples).
+
         Returns
         -------
         PlnPCAcollection
@@ -2583,6 +2700,7 @@ class PlnPCA(_model):
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
         add_const: bool = True,
+        batch_size: int = None,
     ):
         self._rank = rank
         super().__init__(
@@ -2593,6 +2711,7 @@ class PlnPCA(_model):
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
             add_const=add_const,
+            batch_size=batch_size,
         )
 
     @classmethod
@@ -2624,6 +2743,7 @@ class PlnPCA(_model):
         rank: int = 5,
         offsets_formula: str = "logsum",
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
+        batch_size: int = None,
     ):
         endog, exog, offsets = _extract_data_from_formula(formula, data)
         return cls(
@@ -2634,6 +2754,7 @@ class PlnPCA(_model):
             rank=rank,
             dict_initialization=dict_initialization,
             add_const=False,
+            batch_size=batch_size,
         )
 
     @_add_doc(
@@ -2991,6 +3112,25 @@ class PlnPCA(_model):
             return [self._components, self._latent_mean, self._latent_sqrt_var]
         return [self._components, self._coef, self._latent_mean, self._latent_sqrt_var]
 
+    def _compute_elbo_b(self) -> torch.Tensor:
+        """
+        Compute the evidence lower bound (ELBO) with the current batch.
+
+        Returns
+        -------
+        torch.Tensor
+            The ELBO value on the current batch.
+        """
+        return elbo_plnpca(
+            self._endog_b,
+            self._exog_b,
+            self._offsets_b,
+            self._latent_mean_b,
+            self._latent_sqrt_var_b,
+            self._components,
+            self._coef,
+        )
+
     def compute_elbo(self) -> torch.Tensor:
         """
         Compute the evidence lower bound (ELBO).
-- 
GitLab
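
The mini-batch machinery boils down to slicing a (possibly shuffled) index array into chunks of `batch_size`, with a smaller final chunk for the remainder. A simplified sketch, assuming `n_samples` rows of data (function name illustrative):

    import numpy as np

    def iterate_batches(n_samples: int, batch_size: int, shuffle: bool = False):
        indices = np.arange(n_samples)
        if shuffle:
            np.random.shuffle(indices)
        nb_full, last = n_samples // batch_size, n_samples % batch_size
        for i in range(nb_full):
            yield indices[i * batch_size:(i + 1) * batch_size]
        if last != 0:
            # The remaining samples form one last, smaller batch.
            yield indices[-last:]

    # One pass over the data then averages the per-batch losses,
    # e.g. elbo = sum(per_batch_losses) / nb_batches.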


From 899eb0fcf08b4cc0ebadb515e1f5684800e39642 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 20 Jul 2023 17:36:09 +0200
Subject: [PATCH 06/68] put paper in the .gitignore.

---
 .gitignore | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.gitignore b/.gitignore
index a0185fcb..a95ada79 100644
--- a/.gitignore
+++ b/.gitignore
@@ -153,6 +153,9 @@ tests/Pln*
 slides/
 index.html
 
+paper/*
+
+
 tests/test_models*
 tests/test_load*
 tests/test_readme*
-- 
GitLab


From bd66e25433ef28dc1c5ed4ce2c7543c060b1b234 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 21 Jul 2023 09:16:08 +0200
Subject: [PATCH 07/68] Now takes the optimizer as a string instead of the class
 itself (e.g. torch.optim.Adam). This allows choosing the optimizer without
 importing torch.

---
 pyPLNmodels/models.py | 53 ++++++++++++++++++++++++++++++++++++-------
 1 file changed, 45 insertions(+), 8 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 11987d85..f01f3972 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -2,7 +2,7 @@ import time
 from abc import ABC, abstractmethod
 import warnings
 import os
-from typing import Optional, Dict, List, Type, Any, Iterable, Union
+from typing import Optional, Dict, List, Type, Any, Iterable, Union, Literal
 
 import pandas as pd
 import torch
@@ -385,7 +385,7 @@ class _model(ABC):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        class_optimizer: torch.optim.Optimizer = torch.optim.Rprop,
+        class_optimizer: Literal["Rprop", "Adam"] = "Rprop",
         tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
@@ -399,8 +399,10 @@ class _model(ABC):
             The maximum number of iterations. Defaults to 50000.
         lr : float, optional(keyword-only)
             The learning rate. Defaults to 0.01.
-        class_optimizer : torch.optim.Optimizer, optional
-            The optimizer class. Defaults to torch.optim.Rprop.
+        class_optimizer : str, optional
+            The optimizer class. Defaults to "Rprop". If the
+            batch_size is lower than the number of samples, the Rprop
+            algorithm should not be used. A warning will be sent.
         tol : float, optional(keyword-only)
             The tolerance for convergence. Defaults to 1e-3.
         do_smart_init : bool, optional(keyword-only)
@@ -416,7 +418,7 @@ class _model(ABC):
         elif len(self._plotargs.running_times) > 0:
             self._beginning_time -= self._plotargs.running_times[-1]
         self._put_parameters_to_device()
-        self.optim = class_optimizer(self._list_of_parameters_needing_gradient, lr=lr)
+        self._handle_optimizer(class_optimizer, lr)
         stop_condition = False
         while self.nb_iteration_done < nb_max_iteration and not stop_condition:
             loss = self._trainstep()
@@ -428,6 +430,41 @@ class _model(ABC):
         self._print_end_of_fitting_message(stop_condition, tol)
         self._fitted = True
 
+    def _handle_optimizer(self, class_optimizer, lr):
+        if class_optimizer == "Rprop":
+            if self.batch_size < self.n_samples:
+                optimizer_is_set = False
+                while optimizer_is_set is False:
+                    msg = (
+                        f"The Rprop optimizer should not be used when mini batch are used"
+                        f"(i.e. batch_size ({self.batch_size}) < n_samples = {self.n_samples}). "
+                        f"Do you wish to turn to the Adam Optimizer? (y/n) "
+                    )
+                    print(msg)
+                    turn = str(input())
+                    if turn == "y":
+                        self.optim = torch.optim.Adam(
+                            self._list_of_parameters_needing_gradient, lr=lr
+                        )
+                        optimizer_is_set = True
+                    elif turn == "n":
+                        self.optim = torch.optim.Rprop(
+                            self._list_of_parameters_needing_gradient, lr=lr
+                        )
+                        optimizer_is_set = True
+            else:
+                self.optim = torch.optim.Rprop(
+                    self._list_of_parameters_needing_gradient, lr=lr
+                )
+        elif class_optimizer == "Adam":
+            self.optim = torch.optim.Adam(
+                self._list_of_parameters_needing_gradient, lr=lr
+            )
+        else:
+            raise ValueError(
+                f"Optimizer should be either 'Adam' or 'Rprop', got {class_optimizer}"
+            )
+
     def _get_batch(self, batch_size, shuffle=False):
         """Get the batches required to do a  minibatch gradient ascent.
 
@@ -1488,7 +1525,7 @@ class Pln(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        class_optimizer: torch.optim.Optimizer = torch.optim.Rprop,
+        class_optimizer: Literal["Rprop", "Adam"] = "Rprop",
         tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
@@ -2244,7 +2281,7 @@ class PlnPCAcollection:
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        class_optimizer: Type[torch.optim.Optimizer] = torch.optim.Rprop,
+        class_optimizer: Literal["Rprop", "Adam"] = "Rprop",
         tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
@@ -2772,7 +2809,7 @@ class PlnPCA(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        class_optimizer: torch.optim.Optimizer = torch.optim.Rprop,
+        class_optimizer: Literal["Rprop", "Adam"] = "Rprop",
         tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
-- 
GitLab
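
Passing the optimizer as a string keeps torch out of the user-facing API; the mapping is a simple dispatch, as in this minimal sketch (function name illustrative, `params` being the parameters needing gradients):

    import torch

    def make_optimizer(name: str, params, lr: float = 0.01) -> torch.optim.Optimizer:
        if name == "Rprop":
            return torch.optim.Rprop(params, lr=lr)
        if name == "Adam":
            return torch.optim.Adam(params, lr=lr)
        raise ValueError(f"Optimizer should be either 'Adam' or 'Rprop', got {name}")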


From 31c0ff30a58524cdcaec6858ebc7636529b5b14f Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 21 Jul 2023 17:40:48 +0200
Subject: [PATCH 08/68] put the batch size in the fit method and removed the
 possibility of choosing the optimizer. A bug remains when shuffling the dataset.

---
 pyPLNmodels/_initialization.py |  2 -
 pyPLNmodels/models.py          | 88 ++++++++--------------------------
 2 files changed, 20 insertions(+), 70 deletions(-)

diff --git a/pyPLNmodels/_initialization.py b/pyPLNmodels/_initialization.py
index 72de90b3..57cc69cf 100644
--- a/pyPLNmodels/_initialization.py
+++ b/pyPLNmodels/_initialization.py
@@ -110,7 +110,6 @@ def _init_latent_mean(
     torch.Tensor
         The initialized latent mean with size (n,rank)
     """
-    t = time.time()
     mode = torch.randn(endog.shape[0], components.shape[1], device=DEVICE)
     mode.requires_grad_(True)
     optimizer = torch.optim.Rprop([mode], lr=lr)
@@ -129,7 +128,6 @@ def _init_latent_mean(
             keep_condition = False
         old_mode = torch.clone(mode)
         i += 1
-    print("time mean", time.time() - t)
     return mode
 
 
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index f01f3972..8c45780f 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -76,7 +76,6 @@ class _model(ABC):
         dict_initialization: Optional[dict] = None,
         take_log_offsets: bool = False,
         add_const: bool = True,
-        batch_size: int = None,
     ):
         """
         Initializes the model class.
@@ -98,9 +97,6 @@ class _model(ABC):
             Whether to take the log of offsets. Defaults to False.
         add_const: bool, optional(keyword-only)
             Whether to add a column of one in the exog. Defaults to True.
-        batch_size: int, optional(keyword-only)
-            The batch size when optimizing the elbo. If None,
-            batch gradient descent will be performed (i.e. batch_size = n_samples).
         Raises
         ------
         ValueError
@@ -115,7 +111,6 @@ class _model(ABC):
             endog, exog, offsets, offsets_formula, take_log_offsets, add_const
         )
         self._fitted = False
-        self._batch_size = self._handle_batch_size(batch_size)
         self._plotargs = _PlotArgs(self._WINDOW)
         if dict_initialization is not None:
             self._set_init_parameters(dict_initialization)
@@ -129,7 +124,6 @@ class _model(ABC):
         offsets_formula: str = "logsum",
         dict_initialization: Optional[dict] = None,
         take_log_offsets: bool = False,
-        batch_size: int = None,
     ):
         """
         Create a model instance from a formula and data.
@@ -147,9 +141,6 @@ class _model(ABC):
             The initialization dictionary. Defaults to None.
         take_log_offsets : bool, optional(keyword-only)
             Whether to take the log of offsets. Defaults to False.
-        batch_size: int, optional(keyword-only)
-            The batch size when optimizing the elbo. If None,
-            batch gradient descent will be performed (i.e. batch_size = n_samples).
         """
         endog, exog, offsets = _extract_data_from_formula(formula, data)
         return cls(
@@ -160,7 +151,6 @@ class _model(ABC):
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
             add_const=False,
-            batch_size=batch_size,
         )
 
     def _set_init_parameters(self, dict_initialization: dict):
@@ -385,10 +375,10 @@ class _model(ABC):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        class_optimizer: Literal["Rprop", "Adam"] = "Rprop",
         tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
+        batch_size=None,
     ):
         """
         Fit the model. The lower tol, the more accurate the model.
@@ -399,26 +389,25 @@ class _model(ABC):
             The maximum number of iterations. Defaults to 50000.
         lr : float, optional(keyword-only)
             The learning rate. Defaults to 0.01.
-        class_optimizer : str, optional
-            The optimizer class. Defaults to "Rprop". If the
-            batch_size is lower than the number of samples, the Rprop
-            algorithm should not be used. A warning will be sent.
         tol : float, optional(keyword-only)
             The tolerance for convergence. Defaults to 1e-3.
         do_smart_init : bool, optional(keyword-only)
             Whether to perform smart initialization. Defaults to True.
         verbose : bool, optional(keyword-only)
             Whether to print training progress. Defaults to False.
+        batch_size: int, optional(keyword-only)
+            The batch size when optimizing the elbo. If None,
+            batch gradient descent will be performed (i.e. batch_size = n_samples).
         """
         self._print_beginning_message()
         self._beginning_time = time.time()
-
+        self._batch_size = self._handle_batch_size(batch_size)
         if self._fitted is False:
             self._init_parameters(do_smart_init)
         elif len(self._plotargs.running_times) > 0:
             self._beginning_time -= self._plotargs.running_times[-1]
         self._put_parameters_to_device()
-        self._handle_optimizer(class_optimizer, lr)
+        self._handle_optimizer(lr)
         stop_condition = False
         while self.nb_iteration_done < nb_max_iteration and not stop_condition:
             loss = self._trainstep()
@@ -430,39 +419,14 @@ class _model(ABC):
         self._print_end_of_fitting_message(stop_condition, tol)
         self._fitted = True
 
-    def _handle_optimizer(self, class_optimizer, lr):
-        if class_optimizer == "Rprop":
-            if self.batch_size < self.n_samples:
-                optimizer_is_set = False
-                while optimizer_is_set is False:
-                    msg = (
-                        f"The Rprop optimizer should not be used when mini batch are used"
-                        f"(i.e. batch_size ({self.batch_size}) < n_samples = {self.n_samples}). "
-                        f"Do you wish to turn to the Adam Optimizer? (y/n) "
-                    )
-                    print(msg)
-                    turn = str(input())
-                    if turn == "y":
-                        self.optim = torch.optim.Adam(
-                            self._list_of_parameters_needing_gradient, lr=lr
-                        )
-                        optimizer_is_set = True
-                    elif turn == "n":
-                        self.optim = torch.optim.Rprop(
-                            self._list_of_parameters_needing_gradient, lr=lr
-                        )
-                        optimizer_is_set = True
-            else:
-                self.optim = torch.optim.Rprop(
-                    self._list_of_parameters_needing_gradient, lr=lr
-                )
-        elif class_optimizer == "Adam":
+    def _handle_optimizer(self, lr):
+        if self.batch_size < self.n_samples:
             self.optim = torch.optim.Adam(
                 self._list_of_parameters_needing_gradient, lr=lr
             )
         else:
-            raise ValueError(
-                f"Optimizer should be either 'Adam' or 'Rprop', got {class_optimizer}"
+            self.optim = torch.optim.Rprop(
+                self._list_of_parameters_needing_gradient, lr=lr
             )
 
     def _get_batch(self, batch_size, shuffle=False):
@@ -508,8 +472,9 @@ class _model(ABC):
             The loss value.
         """
         elbo = 0
-        for batch in self._get_batch(self._batch_size):
+        for batch in self._get_batch(self._batch_size, shuffle=False):
             self._extract_batch(batch)
+            # print('current bach', self._current_batch_size)
             self.optim.zero_grad()
             loss = -self._compute_elbo_b()
             loss.backward()
@@ -1458,7 +1423,6 @@ class Pln(_model):
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
         add_const: bool = True,
-        batch_size: int = None,
     ):
         super().__init__(
             endog=endog,
@@ -1468,7 +1432,6 @@ class Pln(_model):
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
             add_const=add_const,
-            batch_size=batch_size,
         )
 
     @classmethod
@@ -1496,7 +1459,6 @@ class Pln(_model):
         offsets_formula: str = "logsum",
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
-        batch_size: int = None,
     ):
         endog, exog, offsets = _extract_data_from_formula(formula, data)
         return cls(
@@ -1507,7 +1469,6 @@ class Pln(_model):
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
             add_const=False,
-            batch_size=batch_size,
         )
 
     @_add_doc(
@@ -1525,18 +1486,18 @@ class Pln(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        class_optimizer: Literal["Rprop", "Adam"] = "Rprop",
         tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
+        batch_size: int = None,
     ):
         super().fit(
             nb_max_iteration,
             lr=lr,
-            class_optimizer=class_optimizer,
             tol=tol,
             do_smart_init=do_smart_init,
             verbose=verbose,
+            batch_size=batch_size,
         )
 
     @_add_doc(
@@ -1924,7 +1885,6 @@ class PlnPCAcollection:
         dict_of_dict_initialization: Optional[dict] = None,
         take_log_offsets: bool = False,
         add_const: bool = True,
-        batch_size: int = None,
     ):
         """
         Constructor for PlnPCAcollection.
@@ -1980,7 +1940,6 @@ class PlnPCAcollection:
         ranks: Iterable[int] = range(3, 5),
         dict_of_dict_initialization: Optional[dict] = None,
         take_log_offsets: bool = False,
-        batch_size: int = None,
     ) -> "PlnPCAcollection":
         """
         Create an instance of PlnPCAcollection from a formula.
@@ -2001,9 +1960,6 @@ class PlnPCAcollection:
             The dictionary of initialization, by default None.
         take_log_offsets : bool, optional(keyword-only)
             Whether to take the logarithm of offsets, by default False.
-        batch_size: int, optional(keyword-only)
-            The batch size when optimizing the elbo. If None,
-            batch gradient descent will be performed (i.e. batch_size = n_samples).
 
         Returns
         -------
@@ -2281,10 +2237,10 @@ class PlnPCAcollection:
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        class_optimizer: Literal["Rprop", "Adam"] = "Rprop",
         tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
+        batch_size: int = None,
     ):
         """
         Fit each model in the PlnPCAcollection.
@@ -2295,14 +2251,15 @@ class PlnPCAcollection:
             The maximum number of iterations, by default 50000.
         lr : float, optional(keyword-only)
             The learning rate, by default 0.01.
-        class_optimizer : Type[torch.optim.Optimizer], optional(keyword-only)
-            The optimizer class, by default torch.optim.Rprop.
         tol : float, optional(keyword-only)
             The tolerance, by default 1e-3.
         do_smart_init : bool, optional(keyword-only)
             Whether to do smart initialization, by default True.
         verbose : bool, optional(keyword-only)
             Whether to print verbose output, by default False.
+        batch_size: int, optional(keyword-only)
+            The batch size when optimizing the elbo. If None,
+            batch gradient descent will be performed (i.e. batch_size = n_samples).
         """
         self._print_beginning_message()
         for i in range(len(self.values())):
@@ -2310,7 +2267,6 @@ class PlnPCAcollection:
             model.fit(
                 nb_max_iteration,
                 lr=lr,
-                class_optimizer=class_optimizer,
                 tol=tol,
                 do_smart_init=do_smart_init,
                 verbose=verbose,
@@ -2737,7 +2693,6 @@ class PlnPCA(_model):
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
         add_const: bool = True,
-        batch_size: int = None,
     ):
         self._rank = rank
         super().__init__(
@@ -2748,7 +2703,6 @@ class PlnPCA(_model):
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
             add_const=add_const,
-            batch_size=batch_size,
         )
 
     @classmethod
@@ -2780,7 +2734,6 @@ class PlnPCA(_model):
         rank: int = 5,
         offsets_formula: str = "logsum",
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
-        batch_size: int = None,
     ):
         endog, exog, offsets = _extract_data_from_formula(formula, data)
         return cls(
@@ -2791,7 +2744,6 @@ class PlnPCA(_model):
             rank=rank,
             dict_initialization=dict_initialization,
             add_const=False,
-            batch_size=batch_size,
         )
 
     @_add_doc(
@@ -2809,18 +2761,18 @@ class PlnPCA(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        class_optimizer: Literal["Rprop", "Adam"] = "Rprop",
         tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
+        batch_size=None,
     ):
         super().fit(
             nb_max_iteration,
             lr=lr,
-            class_optimizer=class_optimizer,
             tol=tol,
             do_smart_init=do_smart_init,
             verbose=verbose,
+            batch_size=batch_size,
         )
 
     @_add_doc(
-- 
GitLab


From d3f8e5e281d51e64a5d00fa858136cf3207c4fdb Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 24 Jul 2023 16:05:15 +0200
Subject: [PATCH 09/68] Fixed the shuffle issue.

---
 pyPLNmodels/models.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 8c45780f..1390e37a 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -454,12 +454,13 @@ class _model(ABC):
             yield self._return_batch(indices, -last_batch_size, self.n_samples)
 
     def _return_batch(self, indices, beginning, end):
+        to_take = torch.tensor(indices[beginning:end])
         return (
-            self._endog[indices[beginning:end]],
-            self._exog[beginning:end],
-            self._offsets[indices[beginning:end]],
-            self._latent_mean[beginning:end],
-            self._latent_sqrt_var[beginning:end],
+            torch.index_select(self._endog, 0, to_take),
+            torch.index_select(self._exog, 0, to_take),
+            torch.index_select(self._offsets, 0, to_take),
+            torch.index_select(self._latent_mean, 0, to_take),
+            torch.index_select(self._latent_sqrt_var, 0, to_take),
         )
 
     def _trainstep(self):
@@ -472,9 +473,8 @@ class _model(ABC):
             The loss value.
         """
         elbo = 0
-        for batch in self._get_batch(self._batch_size, shuffle=False):
+        for batch in self._get_batch(self._batch_size, shuffle=True):
             self._extract_batch(batch)
-            # print('current bach', self._current_batch_size)
             self.optim.zero_grad()
             loss = -self._compute_elbo_b()
             loss.backward()
-- 
GitLab
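
The fix makes every tensor use the same shuffled indices (previously exog, latent_mean and latent_sqrt_var were sliced with positional ranges while endog and offsets used the shuffled indices). A small illustration with torch.index_select (tensor names illustrative):

    import torch

    n, p, q = 10, 5, 3
    endog = torch.randn(n, p)
    latent_mean = torch.randn(n, q)
    batch_idx = torch.randperm(n)[:4]  # a shuffled mini-batch of row indices

    endog_b = torch.index_select(endog, 0, batch_idx)              # rows batch_idx of endog
    latent_mean_b = torch.index_select(latent_mean, 0, batch_idx)  # the matching rows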


From ef6e8b2abf9d9411a0606b343016d1034fe77d93 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 24 Jul 2023 16:26:12 +0200
Subject: [PATCH 10/68] pass the batch size down when fitting the collection (it was previously ignored)

---
 pyPLNmodels/models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 1390e37a..cbccf767 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -2270,6 +2270,7 @@ class PlnPCAcollection:
                 tol=tol,
                 do_smart_init=do_smart_init,
                 verbose=verbose,
+                batch_size=batch_size,
             )
             if i < len(self.values()) - 1:
                 next_model = self[self.ranks[i + 1]]
-- 
GitLab


From 1dffd2358e2aeb7e27bd94246c28ecd9f6f732b9 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 25 Jul 2023 16:22:03 +0200
Subject: [PATCH 11/68] GPU support

---
 pyPLNmodels/_initialization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyPLNmodels/_initialization.py b/pyPLNmodels/_initialization.py
index 57cc69cf..02574c29 100644
--- a/pyPLNmodels/_initialization.py
+++ b/pyPLNmodels/_initialization.py
@@ -66,7 +66,7 @@ def _init_components(
     """
     log_y = torch.log(endog + (endog == 0) * math.exp(-2))
     pca = PCA(n_components=rank)
-    pca.fit(log_y)
+    pca.fit(log_y.detach().cpu())
     pca_comp = pca.components_.T * np.sqrt(pca.explained_variance_)
     return torch.from_numpy(pca_comp).to(DEVICE)
 
-- 
GitLab


From e2b3246ba3e44ee26b1c73368ada1d1ef8413f3d Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 25 Jul 2023 16:26:04 +0200
Subject: [PATCH 12/68] merge the relevant lines of the check_tol branch.

---
 pyPLNmodels/models.py | 42 ++++++++++++++++++++++++++----------------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index cbccf767..5edbcd30 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -250,7 +250,7 @@ class _model(ABC):
         return batch_size
 
     @property
-    def nb_iteration_done(self) -> int:
+    def _nb_iteration_done(self) -> int:
         """
         The number of iterations done.
 
@@ -259,7 +259,7 @@ class _model(ABC):
         int
             The number of iterations done.
         """
-        return len(self._plotargs._elbos_list)
+        return len(self._plotargs._elbos_list) * self._nb_batches
 
     @property
     def n_samples(self) -> int:
@@ -429,7 +429,7 @@ class _model(ABC):
                 self._list_of_parameters_needing_gradient, lr=lr
             )
 
-    def _get_batch(self, batch_size, shuffle=False):
+    def _get_batch(self, shuffle=False):
         """Get the batches required to do a  minibatch gradient ascent.
 
         Args:
@@ -442,19 +442,17 @@ class _model(ABC):
         indices = np.arange(self.n_samples)
         if shuffle:
             np.random.shuffle(indices)
-        nb_full_batch, last_batch_size = (
-            self.n_samples // batch_size,
-            self.n_samples % batch_size,
-        )
-        self.nb_batches = nb_full_batch + (last_batch_size > 0)
-        for i in range(nb_full_batch):
-            yield self._return_batch(indices, i * batch_size, (i + 1) * batch_size)
+
+        for i in range(self._nb_full_batch):
+            yield self._return_batch(
+                indices, i * self._batch_size, (i + 1) * self._batch_size
+            )
         # Last batch
-        if last_batch_size != 0:
-            yield self._return_batch(indices, -last_batch_size, self.n_samples)
+        if self._last_batch_size != 0:
+            yield self._return_batch(indices, -self._last_batch_size, self.n_samples)
 
     def _return_batch(self, indices, beginning, end):
-        to_take = torch.tensor(indices[beginning:end])
+        to_take = torch.tensor(indices[beginning:end]).to(DEVICE)
         return (
             torch.index_select(self._endog, 0, to_take),
             torch.index_select(self._exog, 0, to_take),
@@ -463,6 +461,18 @@ class _model(ABC):
             torch.index_select(self._latent_sqrt_var, 0, to_take),
         )
 
+    @property
+    def _nb_full_batch(self):
+        return self.n_samples // self._batch_size
+
+    @property
+    def _last_batch_size(self):
+        return self.n_samples % self._batch_size
+
+    @property
+    def _nb_batches(self):
+        return self._nb_full_batch + (self._last_batch_size > 0)
+
     def _trainstep(self):
         """
         Perform a single pass of the data.
@@ -473,7 +483,7 @@ class _model(ABC):
             The loss value.
         """
         elbo = 0
-        for batch in self._get_batch(self._batch_size, shuffle=True):
+        for batch in self._get_batch(shuffle=True):
             self._extract_batch(batch)
             self.optim.zero_grad()
             loss = -self._compute_elbo_b()
@@ -481,7 +491,7 @@ class _model(ABC):
             elbo += loss.item()
             self.optim.step()
             self._update_closed_forms()
-        return elbo / self.nb_batches
+        return elbo / self._nb_batches
 
     def _extract_batch(self, batch):
         self._endog_b = batch[0]
@@ -1232,7 +1242,7 @@ class _model(ABC):
         dict
             The dictionary of optimization parameters.
         """
-        return {"Number of iterations done": self.nb_iteration_done}
+        return {"Number of iterations done": self._nb_iteration_done}
 
     @property
     def _useful_properties_string(self):
-- 
GitLab
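
With these properties, the number of batches per pass follows directly from the sample count and the batch size; a minimal illustration:

    def nb_batches(n_samples: int, batch_size: int) -> int:
        # Full batches plus one extra batch if the division leaves a remainder.
        return n_samples // batch_size + (n_samples % batch_size > 0)

    assert nb_batches(10, 4) == 3  # two full batches of 4 and a final batch of 2
    assert nb_batches(10, 5) == 2  # exact division, no extra batch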


From f61427842572a7014c9f53bac973fd3831b98c36 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 25 Jul 2023 16:33:00 +0200
Subject: [PATCH 13/68] forgot to add the leading underscore

---
 pyPLNmodels/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 5edbcd30..80c4cbf3 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -409,7 +409,7 @@ class _model(ABC):
         self._put_parameters_to_device()
         self._handle_optimizer(lr)
         stop_condition = False
-        while self.nb_iteration_done < nb_max_iteration and not stop_condition:
+        while self._nb_iteration_done < nb_max_iteration and not stop_condition:
             loss = self._trainstep()
             criterion = self._compute_criterion_and_update_plotargs(loss, tol)
             if abs(criterion) < tol:
-- 
GitLab


From 0171a060015decc0fbd132c78cfa2f4b756d9d2a Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 9 Oct 2023 09:33:28 +0200
Subject: [PATCH 14/68] merge changes with check_tol. Change the way the
 convergence criterion is computed so that it also handles mini-batches: the
 cumulative ELBO is a good indicator.

---
 pyPLNmodels/_utils.py | 15 +++++++--------
 pyPLNmodels/models.py | 22 +++++++++++-----------
 2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 3f9c6441..d2b1aea0 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -24,19 +24,14 @@ else:
 
 
 class _PlotArgs:
-    def __init__(self, window: int):
+    def __init__(self):
         """
         Initialize the PlotArgs class.
-
-        Parameters
-        ----------
-        window : int
-            The size of the window for computing the criterion.
         """
-        self.window = window
         self.running_times = []
-        self.criterions = [1] * window  # the first window criterion won't be computed.
+        self.criterions = []
         self._elbos_list = []
+        self.cumulative_elbo_list = [0]
 
     @property
     def iteration_number(self) -> int:
@@ -50,6 +45,10 @@ class _PlotArgs:
         """
         return len(self._elbos_list)
 
+    @property
+    def cumulative_elbo(self):
+        return self.cumulative_elbo_list[-1]
+
     def _show_loss(self, ax=None):
         """
         Show the loss of the model (i.e. the negative ELBO).
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 80c4cbf3..c964fb8c 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -57,7 +57,6 @@ class _model(ABC):
     Base class for all the Pln models. Should be inherited.
     """
 
-    _WINDOW: int = 15
     _endog: torch.Tensor
     _exog: torch.Tensor
     _offsets: torch.Tensor
@@ -734,14 +733,15 @@ class _model(ABC):
         """
         self._plotargs._elbos_list.append(-loss)
         self._plotargs.running_times.append(time.time() - self._beginning_time)
-        if self._plotargs.iteration_number > self._WINDOW:
-            criterion = abs(
-                self._plotargs._elbos_list[-1]
-                - self._plotargs._elbos_list[-1 - self._WINDOW]
-            )
-            self._plotargs.criterions.append(criterion)
-            return criterion
-        return tol
+        self._plotargs.cumulative_elbo_list.append(
+            self._plotargs.cumulative_elbo_list - loss
+        )
+        criterion = (
+            self._plotargs.cumulative_elbo_list[-2]
+            - self._plotargs.cumulative_elbo_list[-1]
+        ) / self._plotargs.cumulative_elbo_list[-1]
+        self._plotargs.criterions.append(criterion)
+        return criterion
 
     def _update_closed_forms(self):
         """
@@ -2924,11 +2924,11 @@ class PlnPCA(_model):
     def _endog_predictions(self):
         covariance_a_posteriori = torch.sum(
             (self._components**2).unsqueeze(0)
-            * (self.latent_sqrt_var**2).unsqueeze(1),
+            * (self._latent_sqrt_var**2).unsqueeze(1),
             axis=2,
         )
         if self.exog is not None:
-            XB = self.exog @ self.coef
+            XB = self._exog @ self._coef
         else:
             XB = 0
         return torch.exp(
-- 
GitLab
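
The new stopping rule tracks the cumulative ELBO and uses its relative change between two passes as the criterion. A sketch of the intent, assuming `loss` is the negative averaged ELBO returned by one pass over the data and that the value appended is the last cumulative value minus the loss (rather than the list itself):

    cumulative_elbo_list = [0.0]

    def update_criterion(loss: float) -> float:
        # Accumulate the ELBO of this pass (loss is the negative ELBO).
        cumulative_elbo_list.append(cumulative_elbo_list[-1] - loss)
        # Relative change of the cumulative ELBO between the last two passes.
        return (cumulative_elbo_list[-2] - cumulative_elbo_list[-1]) / cumulative_elbo_list[-1]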


From d6375c6658a00611aef9a56bb6b76e7c2d7850b6 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 9 Oct 2023 09:37:33 +0200
Subject: [PATCH 15/68] run tests only on the main and dev branches.

---
 .gitlab-ci.yml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index d8c20b0f..c95c78f1 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -23,6 +23,9 @@ tests:
     - cd tests
     - python create_readme_and_docstrings_tests.py
     - pytest .
+  only:
+    - main
+    - dev
 
 
 build_package:
-- 
GitLab


From 41d85f976bd6a475efd6e6441b096c9c863bc3df Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 9 Oct 2023 10:05:23 +0200
Subject: [PATCH 16/68] add the ELBO of the zero-inflated (ZI) PLN model.

---
 pyPLNmodels/elbos.py | 92 ++++++++++++++++++++++++++++----------------
 1 file changed, 59 insertions(+), 33 deletions(-)

diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index 6dcda361..454cfb75 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -172,6 +172,7 @@ def elbo_plnpca(
     ) / n_samples
 
 
+## pb with trunc_log
 ## should rename some variables so that is is clearer when we see the formula
 def elbo_zi_pln(
     endog,
@@ -179,13 +180,13 @@ def elbo_zi_pln(
     offsets,
     latent_mean,
     latent_sqrt_var,
-    pi,
-    covariance,
+    latent_prob,
+    components,
     coef,
-    _coef_inflation,
+    coef_inflation,
     dirac,
 ):
-    """Compute the ELBO (Evidence LOwer Bound) for the Zero Inflated Pln model.
+    """Compute the ELBO (Evidence LOwer Bound) for the Zero Inflated PLN model.
     See the doc for more details on the computation.
 
     Args:
@@ -197,41 +198,66 @@ def elbo_zi_pln(
         pi: torch.tensor. Variational parameter with size (n,p)
         covariance: torch.tensor. Model parameter with size (p,p)
         coef: torch.tensor. Model parameter with size (d,p)
-        _coef_inflation: torch.tensor. Model parameter with size (d,p)
+        coef_inflation: torch.tensor. Model parameter with size (d,p)
     Returns:
         torch.tensor of size 1 with a gradient.
     """
-    if torch.norm(pi * dirac - pi) > 0.0001:
-        print("Bug")
-        return False
-    n_samples = endog.shape[0]
-    dim = endog.shape[1]
-    s_rond_s = torch.square(latent_sqrt_var)
-    offsets_plus_m = offsets + latent_mean
-    m_minus_xb = latent_mean - exog @ coef
-    x_coef_inflation = exog @ _coef_inflation
-    elbo = torch.sum(
-        (1 - pi)
-        * (
-            endog @ offsets_plus_m
-            - torch.exp(offsets_plus_m + s_rond_s / 2)
-            - _log_stirling(endog),
-        )
-        + pi
+    covariance = components @ (components.T)
+    if torch.norm(latent_prob * dirac - latent_prob) > 0.00000001:
+        raise RuntimeError("Latent probability is not zero when it should be.")
+    n_samples, dim = endog.shape
+    s_rond_s = torch.multiply(latent_sqrt_var, latent_sqrt_var)
+    o_plus_m = offsets + latent_mean
+    if exog is None:
+        XB = torch.zeros_like(endog)
+        xcoef_inflation = torch.zeros_like(endog)
+    else:
+        XB = exog @ coef
+        xcoef_inflation = exog @ coef_inflation
+    m_minus_xb = latent_mean - XB
+
+    A = torch.exp(o_plus_m + s_rond_s / 2)
+    inside_a = torch.multiply(
+        1 - latent_prob, torch.multiply(endog, o_plus_m) - A - _log_stirling(endog)
     )
+    a = torch.sum(inside_a)
 
-    elbo -= torch.sum(pi * _trunc_log(pi) + (1 - pi) * _trunc_log(1 - pi))
-    elbo += torch.sum(
-        pi * x_coef_inflation - torch.log(1 + torch.exp(x_coef_inflation))
+    Omega = torch.inverse(covariance)
+
+    m_moins_xb_outer = torch.mm(m_minus_xb.T, m_minus_xb)
+    un_moins_rho = 1 - latent_prob
+    un_moins_rho_m_moins_xb = un_moins_rho * m_minus_xb
+    un_moins_rho_m_moins_xb_outer = un_moins_rho_m_moins_xb.T @ un_moins_rho_m_moins_xb
+    inside_b = -1 / 2 * Omega * un_moins_rho_m_moins_xb_outer
+    b = -n_samples / 2 * torch.logdet(covariance) + torch.sum(inside_b)
+
+    inside_c = torch.multiply(latent_prob, xcoef_inflation) - torch.log(
+        1 + torch.exp(xcoef_inflation)
     )
+    c = torch.sum(inside_c)
+    log_diag = torch.log(torch.diag(covariance))
+    log_S_term = torch.sum(
+        torch.multiply(1 - latent_prob, torch.log(torch.abs(latent_sqrt_var))), axis=0
+    )
+    y = torch.sum(latent_prob, axis=0)
+    covariance_term = 1 / 2 * torch.log(torch.diag(covariance)) * y
+    inside_d = covariance_term + log_S_term
+
+    d = n_samples * dim / 2 + torch.sum(inside_d)
 
-    elbo -= 0.5 * torch.trace(
-        torch.mm(
-            torch.inverse(covariance),
-            torch.diag(torch.sum(s_rond_s, dim=0)) + m_minus_xb.T @ m_minus_xb,
-        )
+    inside_e = torch.multiply(latent_prob, _trunc_log(latent_prob)) + torch.multiply(
+        1 - latent_prob, _trunc_log(1 - latent_prob)
     )
-    elbo += 0.5 * n_samples * torch.log(torch.det(covariance))
-    elbo += 0.5 * n_samples * dim
-    elbo += 0.5 * torch.sum(torch.log(s_rond_s))
+    e = -torch.sum(inside_e)
+    sum_un_moins_rho_s2 = torch.sum(torch.multiply(1 - latent_prob, s_rond_s), axis=0)
+    diag_sig_sum_rho = torch.multiply(
+        torch.diag(covariance), torch.sum(latent_prob, axis=0)
+    )
+    new = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0)
+    K = sum_un_moins_rho_s2 + diag_sig_sum_rho + new
+    inside_f = torch.diag(Omega) * K
+    f = -1 / 2 * torch.sum(inside_f)
+    full_diag_omega = torch.diag(Omega).expand(exog.shape[0], -1)
+    elbo = a + b + c + d + e + f
+    print(" inside a shape", inside_a.shape)
     return elbo
-- 
GitLab
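
For reference, the rewritten elbo_zi_pln expects the factor `components` (C) in place of
the covariance and the variational probabilities `latent_prob` in place of `pi`. The
sketch below is illustrative only (toy data, not taken from the package): it shows the
shape conventions and the Dirac constraint enforced by the new RuntimeError guard,
namely that `latent_prob` may be non-zero only where the observed count is zero.

    import torch

    # Toy data mimicking the argument conventions of elbo_zi_pln (illustrative only).
    n, p, d = 20, 5, 2
    endog = torch.poisson(torch.ones(n, p) * 2.0)   # counts, shape (n, p)
    exog = torch.randn(n, d)                        # covariates, shape (n, d)
    dirac = (endog == 0).double()                   # 1 where the count is zero

    # Variational zero-inflation probabilities must vanish wherever endog > 0, which is
    # exactly what the guard norm(latent_prob * dirac - latent_prob) checks.
    latent_prob = torch.rand(n, p) * dirac

    # The covariance is passed through its factor: Sigma = C C^T.
    components = torch.randn(p, p)
    covariance = components @ components.T          # (p, p), symmetric PSD

    assert torch.norm(latent_prob * dirac - latent_prob) < 1e-8
    assert torch.allclose(covariance, covariance.T)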


From f985bb6be6f4a881165627db65d2b92c2f55f627 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 9 Oct 2023 10:22:35 +0200
Subject: [PATCH 17/68] remove useless sums and optimize a little.

---
 pyPLNmodels/elbos.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index 454cfb75..f7378d0f 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -220,7 +220,6 @@ def elbo_zi_pln(
     inside_a = torch.multiply(
         1 - latent_prob, torch.multiply(endog, o_plus_m) - A - _log_stirling(endog)
     )
-    a = torch.sum(inside_a)
 
     Omega = torch.inverse(covariance)
 
@@ -229,12 +228,10 @@ def elbo_zi_pln(
     un_moins_rho_m_moins_xb = un_moins_rho * m_minus_xb
     un_moins_rho_m_moins_xb_outer = un_moins_rho_m_moins_xb.T @ un_moins_rho_m_moins_xb
     inside_b = -1 / 2 * Omega * un_moins_rho_m_moins_xb_outer
-    b = -n_samples / 2 * torch.logdet(covariance) + torch.sum(inside_b)
 
     inside_c = torch.multiply(latent_prob, xcoef_inflation) - torch.log(
         1 + torch.exp(xcoef_inflation)
     )
-    c = torch.sum(inside_c)
     log_diag = torch.log(torch.diag(covariance))
     log_S_term = torch.sum(
         torch.multiply(1 - latent_prob, torch.log(torch.abs(latent_sqrt_var))), axis=0
@@ -243,21 +240,18 @@ def elbo_zi_pln(
     covariance_term = 1 / 2 * torch.log(torch.diag(covariance)) * y
     inside_d = covariance_term + log_S_term
 
-    d = n_samples * dim / 2 + torch.sum(inside_d)
-
     inside_e = torch.multiply(latent_prob, _trunc_log(latent_prob)) + torch.multiply(
         1 - latent_prob, _trunc_log(1 - latent_prob)
     )
-    e = -torch.sum(inside_e)
     sum_un_moins_rho_s2 = torch.sum(torch.multiply(1 - latent_prob, s_rond_s), axis=0)
     diag_sig_sum_rho = torch.multiply(
         torch.diag(covariance), torch.sum(latent_prob, axis=0)
     )
     new = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0)
     K = sum_un_moins_rho_s2 + diag_sig_sum_rho + new
-    inside_f = torch.diag(Omega) * K
-    f = -1 / 2 * torch.sum(inside_f)
+    inside_f = -1 / 2 * torch.diag(Omega) * K
     full_diag_omega = torch.diag(Omega).expand(exog.shape[0], -1)
-    elbo = a + b + c + d + e + f
-    print(" inside a shape", inside_a.shape)
+    elbo = torch.sum(inside_a + inside_c + inside_d)
+    elbo += torch.sum(inside_b) - n_samples / 2 * torch.logdet(covariance)
+    elbo += n_samples * dim / 2 + torch.sum(inside_d + inside_f)
     return elbo
-- 
GitLab


From 9070fc11ea27239d5ae7eb8999de499ebae6338b Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 9 Oct 2023 10:34:39 +0200
Subject: [PATCH 18/68] rewrite the elbo to avoid duplicate computations;
 torch.multiply calls are replaced by *

---
 pyPLNmodels/elbos.py | 56 ++++++++++++++++++++------------------------
 1 file changed, 26 insertions(+), 30 deletions(-)

diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index f7378d0f..ec743430 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -202,11 +202,15 @@ def elbo_zi_pln(
     Returns:
         torch.tensor of size 1 with a gradient.
     """
-    covariance = components @ (components.T)
     if torch.norm(latent_prob * dirac - latent_prob) > 0.00000001:
         raise RuntimeError("Latent probability is not zero when it should be.")
+    covariance = components @ (components.T)
+    diag_cov = torch.diag(covariance)
+    Omega = torch.inverse(covariance)
+    diag_omega = torch.diag(Omega)
+    un_moins_prob = 1 - latent_prob
     n_samples, dim = endog.shape
-    s_rond_s = torch.multiply(latent_sqrt_var, latent_sqrt_var)
+    s_rond_s = latent_sqrt_var * latent_sqrt_var
     o_plus_m = offsets + latent_mean
     if exog is None:
         XB = torch.zeros_like(endog)
@@ -217,40 +221,32 @@ def elbo_zi_pln(
     m_minus_xb = latent_mean - XB
 
     A = torch.exp(o_plus_m + s_rond_s / 2)
-    inside_a = torch.multiply(
-        1 - latent_prob, torch.multiply(endog, o_plus_m) - A - _log_stirling(endog)
-    )
-
-    Omega = torch.inverse(covariance)
-
+    inside_a = un_moins_prob * (endog * o_plus_m - A - _log_stirling(endog))
     m_moins_xb_outer = torch.mm(m_minus_xb.T, m_minus_xb)
-    un_moins_rho = 1 - latent_prob
-    un_moins_rho_m_moins_xb = un_moins_rho * m_minus_xb
-    un_moins_rho_m_moins_xb_outer = un_moins_rho_m_moins_xb.T @ un_moins_rho_m_moins_xb
-    inside_b = -1 / 2 * Omega * un_moins_rho_m_moins_xb_outer
-
-    inside_c = torch.multiply(latent_prob, xcoef_inflation) - torch.log(
-        1 + torch.exp(xcoef_inflation)
+    un_moins_prob_m_moins_xb = un_moins_prob * m_minus_xb
+    un_moins_prob_m_moins_xb_outer = (
+        un_moins_prob_m_moins_xb.T @ un_moins_prob_m_moins_xb
     )
-    log_diag = torch.log(torch.diag(covariance))
+    inside_b = -1 / 2 * Omega * un_moins_prob_m_moins_xb_outer
+
+    inside_c = latent_prob * xcoef_inflation - torch.log(1 + torch.exp(xcoef_inflation))
+    log_diag = torch.log(diag_cov)
     log_S_term = torch.sum(
-        torch.multiply(1 - latent_prob, torch.log(torch.abs(latent_sqrt_var))), axis=0
+        un_moins_prob * torch.log(torch.abs(latent_sqrt_var)), axis=0
     )
-    y = torch.sum(latent_prob, axis=0)
-    covariance_term = 1 / 2 * torch.log(torch.diag(covariance)) * y
+    sum_prob = torch.sum(latent_prob, axis=0)
+    covariance_term = 1 / 2 * torch.log(diag_cov) * sum_prob
     inside_d = covariance_term + log_S_term
 
-    inside_e = torch.multiply(latent_prob, _trunc_log(latent_prob)) + torch.multiply(
-        1 - latent_prob, _trunc_log(1 - latent_prob)
-    )
-    sum_un_moins_rho_s2 = torch.sum(torch.multiply(1 - latent_prob, s_rond_s), axis=0)
-    diag_sig_sum_rho = torch.multiply(
-        torch.diag(covariance), torch.sum(latent_prob, axis=0)
-    )
-    new = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0)
-    K = sum_un_moins_rho_s2 + diag_sig_sum_rho + new
-    inside_f = -1 / 2 * torch.diag(Omega) * K
-    full_diag_omega = torch.diag(Omega).expand(exog.shape[0], -1)
+    inside_e = torch.multiply(
+        latent_prob, _trunc_log(latent_prob)
+    ) + un_moins_prob * _trunc_log(un_moins_prob)
+    sum_un_moins_prob_s2 = torch.sum(un_moins_prob * s_rond_s, axis=0)
+    diag_sig_sum_prob = diag_cov * torch.sum(latent_prob, axis=0)
+    new = torch.sum(latent_prob * un_moins_prob * (m_minus_xb**2), axis=0)
+    K = sum_un_moins_prob_s2 + diag_sig_sum_prob + new
+    inside_f = -1 / 2 * diag_omega * K
+    full_diag_omega = diag_omega.expand(exog.shape[0], -1)
     elbo = torch.sum(inside_a + inside_c + inside_d)
     elbo += torch.sum(inside_b) - n_samples / 2 * torch.logdet(covariance)
     elbo += n_samples * dim / 2 + torch.sum(inside_d + inside_f)
-- 
GitLab


From 1784e7f94933dbd9713e081e69b27603d53a3640 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 9 Oct 2023 11:39:41 +0200
Subject: [PATCH 19/68] began to merge changes from the zi branch, up to line 3482

---
 pyPLNmodels/_initialization.py |   4 +-
 pyPLNmodels/models.py          | 190 ++++++++++++++++++++++++++++++---
 2 files changed, 177 insertions(+), 17 deletions(-)

diff --git a/pyPLNmodels/_initialization.py b/pyPLNmodels/_initialization.py
index 02574c29..f5663746 100644
--- a/pyPLNmodels/_initialization.py
+++ b/pyPLNmodels/_initialization.py
@@ -14,9 +14,7 @@ else:
     DEVICE = torch.device("cpu")
 
 
-def _init_covariance(
-    endog: torch.Tensor, exog: torch.Tensor, coef: torch.Tensor
-) -> torch.Tensor:
+def _init_covariance(endog: torch.Tensor, exog: torch.Tensor) -> torch.Tensor:
     """
     Initialization for the covariance for the Pln model. Take the log of endog
     (careful when endog=0), and computes the Maximum Likelihood
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index c964fb8c..66ea3339 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -18,7 +18,7 @@ from scipy import stats
 from ._closed_forms import (
     _closed_formula_coef,
     _closed_formula_covariance,
-    _closed_formula_pi,
+    _closed_formula_latent_prob,
 )
 from .elbos import elbo_plnpca, elbo_zi_pln, profiled_elbo_pln
 from ._utils import (
@@ -32,6 +32,7 @@ from ._utils import (
     _array2tensor,
     _handle_data,
     _add_doc,
+    _closed_form_latent_prob,
 )
 
 from ._initialization import (
@@ -110,7 +111,7 @@ class _model(ABC):
             endog, exog, offsets, offsets_formula, take_log_offsets, add_const
         )
         self._fitted = False
-        self._plotargs = _PlotArgs(self._WINDOW)
+        self._plotargs = _PlotArgs()
         if dict_initialization is not None:
             self._set_init_parameters(dict_initialization)
 
@@ -249,7 +250,7 @@ class _model(ABC):
         return batch_size
 
     @property
-    def _nb_iteration_done(self) -> int:
+    def nb_iteration_done(self) -> int:
         """
         The number of iterations done.
 
@@ -408,7 +409,7 @@ class _model(ABC):
         self._put_parameters_to_device()
         self._handle_optimizer(lr)
         stop_condition = False
-        while self._nb_iteration_done < nb_max_iteration and not stop_condition:
+        while self.nb_iteration_done < nb_max_iteration and not stop_condition:
             loss = self._trainstep()
             criterion = self._compute_criterion_and_update_plotargs(loss, tol)
             if abs(criterion) < tol:
@@ -1242,7 +1243,7 @@ class _model(ABC):
         dict
             The dictionary of optimization parameters.
         """
-        return {"Number of iterations done": self._nb_iteration_done}
+        return {"Number of iterations done": self.nb_iteration_done}
 
     @property
     def _useful_properties_string(self):
@@ -1850,7 +1851,7 @@ class Pln(_model):
         covariance : torch.Tensor
             The covariance matrix.
         """
-        pass
+        raise AttributeError("You can not set the covariance for the Pln model.")
 
 
 class PlnPCAcollection:
@@ -3190,7 +3191,7 @@ class PlnPCA(_model):
     @property
     def covariance(self) -> torch.Tensor:
         """
-        Property representing the covariance a posteriori of the latent variables.
+        Property representing the covariance of the latent variables.
 
         Returns
         -------
@@ -3326,13 +3327,137 @@ class PlnPCA(_model):
         return self.latent_variables
 
 
-class ZIPln(Pln):
+class ZIPln(_model):
     _NAME = "ZIPln"
 
-    _pi: torch.Tensor
+    _latent_prob: torch.Tensor
     _coef_inflation: torch.Tensor
     _dirac: torch.Tensor
 
+    @_add_doc(
+        _model,
+        example="""
+            >>> from pyPLNmodels import ZIPln, get_real_count_data
+            >>> endog= get_real_count_data()
+            >>> zi = ZIPln(endog, add_const = True)
+            >>> zi.fit()
+            >>> print(zi)
+        """,
+        returns="""
+            ZIPln
+        """,
+        see_also="""
+        :func:`pyPLNmodels.ZIPln.from_formula`
+        """,
+    )
+    def __init__(
+        self,
+        endog: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]],
+        *,
+        exog: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]] = None,
+        offsets: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]] = None,
+        offsets_formula: str = "logsum",
+        dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
+        take_log_offsets: bool = False,
+        add_const: bool = True,
+        use_closed_form: bool = False,
+    ):
+        super().__init__(
+            endog=endog,
+            exog=exog,
+            offsets=offsets,
+            offsets_formula=offsets_formula,
+            dict_initialization=dict_initialization,
+            take_log_offsets=take_log_offsets,
+            add_const=add_const,
+        )
+        self._use_closed_form = use_closed_form
+
+    @classmethod
+    @_add_doc(
+        _model,
+        example="""
+            >>> from pyPLNmodels import ZIPln, get_real_count_data
+            >>> endog = get_real_count_data()
+            >>> data = {"endog": endog}
+            >>> zi = ZIPln.from_formula("endog ~ 1", data = data)
+        """,
+        returns="""
+            ZIPln
+        """,
+        see_also="""
+        :class:`pyPLNmodels.ZIPln`
+        :func:`pyPLNmodels.ZIPln.__init__`
+    """,
+    )
+    def from_formula(
+        cls,
+        formula: str,
+        data: Dict[str, Union[torch.Tensor, np.ndarray, pd.DataFrame]],
+        *,
+        offsets_formula: str = "logsum",
+        dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
+        take_log_offsets: bool = False,
+        use_closed_form: bool = True,
+    ):
+        endog, exog, offsets = _extract_data_from_formula(formula, data)
+        return cls(
+            endog,
+            exog=exog,
+            offsets=offsets,
+            offsets_formula=offsets_formula,
+            dict_initialization=dict_initialization,
+            take_log_offsets=take_log_offsets,
+            add_const=False,
+            use_closed_form=use_closed_form,
+        )
+
+    @_add_doc(
+        _model,
+        example="""
+        >>> from pyPLNmodels import ZIPln, get_real_count_data
+        >>> endog = get_real_count_data()
+        >>> zi = Pln(endog,add_const = True)
+        >>> zi.fit()
+        >>> print(zi)
+        """,
+    )
+    def fit(
+        self,
+        nb_max_iteration: int = 50000,
+        *,
+        lr: float = 0.01,
+        tol: float = 1e-3,
+        do_smart_init: bool = True,
+        verbose: bool = False,
+        batch_size: int = None,
+    ):
+        super().fit(
+            nb_max_iteration,
+            lr=lr,
+            tol=tol,
+            do_smart_init=do_smart_init,
+            verbose=verbose,
+            batch_size=batch_size,
+        )
+
+    @_add_doc(
+        _model,
+        example="""
+            >>> import matplotlib.pyplot as plt
+            >>> from pyPLNmodels import ZIPln, get_real_count_data
+            >>> endog, labels = get_real_count_data(return_labels = True)
+            >>> zi = ZIPln(endog,add_const = True)
+            >>> zi.fit()
+            >>> zi.plot_expected_vs_true()
+            >>> plt.show()
+            >>> zi.plot_expected_vs_true(colors = labels)
+            >>> plt.show()
+            """,
+    )
+    def plot_expected_vs_true(self, ax=None, colors=None):
+        super().plot_expected_vs_true(ax=ax, colors=colors)
+
     @property
     def _description(self):
         return "with full covariance model and zero-inflation."
@@ -3346,7 +3471,7 @@ class ZIPln(Pln):
     def _smart_init_model_parameters(self):
         super()._smart_init_model_parameters()
         if not hasattr(self, "_covariance"):
-            self._covariance = _init_covariance(self._endog, self._exog, self._coef)
+            self._components = _init_components(self._endog, self._exog, self.dim)
         if not hasattr(self, "_coef_inflation"):
             self._coef_inflation = torch.randn(self.nb_cov, self.dim)
 
@@ -3354,11 +3479,29 @@ class ZIPln(Pln):
         self._dirac = self._endog == 0
         self._latent_mean = torch.randn(self.n_samples, self.dim)
         self._latent_sqrt_var = torch.randn(self.n_samples, self.dim)
-        self._pi = (
+        self._latent_prob = (
             torch.empty(self.n_samples, self.dim).uniform_(0, 1).to(DEVICE)
             * self._dirac
         )
 
+    @property
+    def _covariance(self):
+        return self._components @ (self._components.T)
+
+    @property
+    def covariance(self) -> torch.Tensor:
+        """
+        Property representing the covariance of the latent variables.
+
+        Returns
+        -------
+        Optional[torch.Tensor]
+            The covariance tensor or None if components are not present.
+        """
+        if hasattr(self, "_components"):
+            return self.components @ (self.components.T)
+        return None
+
     def compute_elbo(self):
         return elbo_zi_pln(
             self._endog,
@@ -3366,7 +3509,7 @@ class ZIPln(Pln):
             self._offsets,
             self._latent_mean,
             self._latent_sqrt_var,
-            self._pi,
+            self._latent_prob,
             self._covariance,
             self._coef,
             self._coef_inflation,
@@ -3375,9 +3518,19 @@ class ZIPln(Pln):
 
     @property
     def _list_of_parameters_needing_gradient(self):
-        return [self._latent_mean, self._latent_sqrt_var, self._coef_inflation]
+        list_parameters = [
+            self._latent_mean,
+            self._latent_sqrt_var,
+            self._coef_inflation,
+            self._components,
+            self._coef,
+        ]
+        if self._use_closed_form:
+            list_parameters.append(self._latent_prob)
+        return list_parameters
 
     def _update_closed_forms(self):
+        pass
         self._coef = _closed_formula_coef(self._exog, self._latent_mean)
         self._covariance = _closed_formula_covariance(
             self._exog,
@@ -3386,7 +3539,7 @@ class ZIPln(Pln):
             self._coef,
             self.n_samples,
         )
-        self._pi = _closed_formula_pi(
+        self._latent_prob = _closed_formula_latent_prob(
             self._offsets,
             self._latent_mean,
             self._latent_sqrt_var,
@@ -3395,6 +3548,15 @@ class ZIPln(Pln):
             self._coef_inflation,
         )
 
+    @property
+    def closed_form_latent_prob(self):
+        """
+        The closed form for the latent probability.
+        """
+        return closed_form_latent_prob(
+            self._exog, self._coef, self._coef_inflation, self._covariance, self._dirac
+        )
+
     @property
     def number_of_parameters(self):
         return self.dim * (2 * self.nb_cov + (self.dim + 1) / 2)
-- 
GitLab
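
The docstring examples added above already describe the intended workflow for the new
ZIPln class. Collected in one place, a minimal session would look like the sketch below.
This is illustrative: it assumes the class behaves as its doctests suggest, and the
top-level import only becomes available once ZIPln is added to __init__.py a few patches
later in this series.

    from pyPLNmodels import ZIPln, get_real_count_data

    endog = get_real_count_data()

    # Direct construction; an intercept is added by default (add_const=True).
    zi = ZIPln(endog, add_const=True)
    zi.fit()
    print(zi)

    # Formula interface; the intercept comes from the formula, so the constructor is
    # called with add_const=False under the hood.
    data = {"endog": endog}
    zi_formula = ZIPln.from_formula("endog ~ 1", data=data)
    zi_formula.fit()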


From 228db74b296e6cfb47aa42a9f970589e740ba399 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 9 Oct 2023 16:35:25 +0200
Subject: [PATCH 20/68] continue to merge changes from the zi branch.

---
 pyPLNmodels/models.py | 167 ++++++++++++++++++++++++++++--------------
 1 file changed, 111 insertions(+), 56 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 66ea3339..6289640c 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -238,6 +238,12 @@ class _model(ABC):
                 _plot_ellipse(x[i], y[i], cov=covariances[i], ax=ax)
         return ax
 
+    def _update_parameters(self):
+        """
+        Update the parameters with a gradient step and project if necessary.
+        """
+        self.optim.step()
+
     def _handle_batch_size(self, batch_size):
         if batch_size is None:
             batch_size = self.n_samples
@@ -489,7 +495,7 @@ class _model(ABC):
             loss = -self._compute_elbo_b()
             loss.backward()
             elbo += loss.item()
-            self.optim.step()
+            self._udpate_parameters()
             self._update_closed_forms()
         return elbo / self._nb_batches
 
@@ -2482,23 +2488,26 @@ class PlnPCAcollection:
         bic = self.BIC
         aic = self.AIC
         loglikes = self.loglikes
-        bic_color = "blue"
-        aic_color = "red"
-        loglikes_color = "orange"
-        plt.scatter(bic.keys(), bic.values(), label="BIC criterion", c=bic_color)
-        plt.plot(bic.keys(), bic.values(), c=bic_color)
-        plt.axvline(self.best_BIC_model_rank, c=bic_color, linestyle="dotted")
-        plt.scatter(aic.keys(), aic.values(), label="AIC criterion", c=aic_color)
-        plt.axvline(self.best_AIC_model_rank, c=aic_color, linestyle="dotted")
-        plt.plot(aic.keys(), aic.values(), c=aic_color)
-        plt.xticks(list(aic.keys()))
-        plt.scatter(
-            loglikes.keys(),
-            -np.array(list(loglikes.values())),
-            label="Negative log likelihood",
-            c=loglikes_color,
-        )
-        plt.plot(loglikes.keys(), -np.array(list(loglikes.values())), c=loglikes_color)
+        colors = {"BIC": "blue", "AIC": "red", "Negative log likelihood": "orange"}
+        for criterion, values in zip(
+            ["BIC", "AIC", "Negative log likelihood"], [bic, aic, loglikes]
+        ):
+            plt.scatter(
+                values.keys(),
+                values.values(),
+                label=f"{criterion} criterion",
+                c=colors[criterion],
+            )
+            plt.plot(values.keys(), values.values(), c=colors[criterion])
+            if criterion == "BIC":
+                plt.axvline(
+                    self.best_BIC_model_rank, c=colors[criterion], linestyle="dotted"
+                )
+            elif criterion == "AIC":
+                plt.axvline(
+                    self.best_AIC_model_rank, c=colors[criterion], linestyle="dotted"
+                )
+                plt.xticks(list(values.keys()))
         plt.legend()
         plt.show()
 
@@ -2696,7 +2705,7 @@ class PlnPCA(_model):
     )
     def __init__(
         self,
-        endog: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]],
+        endog: Union[torch.Tensor, np.ndarray, pd.DataFrame],
         *,
         exog: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]] = None,
         offsets: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]] = None,
@@ -3465,16 +3474,21 @@ class ZIPln(_model):
     def _random_init_model_parameters(self):
         super()._random_init_model_parameters()
         self._coef_inflation = torch.randn(self.nb_cov, self.dim)
-        self._covariance = torch.diag(torch.ones(self.dim)).to(DEVICE)
+        self._coef = torch.randn(self.nb_cov, self.dim)
+        self._components = torch.randn(self.nb_cov, self.dim)
 
-    # should change the good initialization, especially for _coef_inflation
+    # should find a better initialization for _coef_inflation
     def _smart_init_model_parameters(self):
+        # init of _coef.
         super()._smart_init_model_parameters()
         if not hasattr(self, "_covariance"):
             self._components = _init_components(self._endog, self._exog, self.dim)
         if not hasattr(self, "_coef_inflation"):
             self._coef_inflation = torch.randn(self.nb_cov, self.dim)
 
+    def _print_beginning_message(self):
+        print("Fitting a ZIPln model.")
+
     def _random_init_latent_parameters(self):
         self._dirac = self._endog == 0
         self._latent_mean = torch.randn(self.n_samples, self.dim)
@@ -3482,12 +3496,36 @@ class ZIPln(_model):
         self._latent_prob = (
             torch.empty(self.n_samples, self.dim).uniform_(0, 1).to(DEVICE)
             * self._dirac
-        )
+        ).double()
+
+    def _smart_init_latent_parameters(self):
+        self._random_init_latent_parameters()
 
     @property
     def _covariance(self):
         return self._components @ (self._components.T)
 
+    def latent_variables(self):
+        return self.latent_mean, self.latent_prob
+
+    def _update_parameters(self):
+        super()._update_parameters()
+        self._project_latent_prob()
+
+    def _project_latent_prob(self):
+        """
+        Project the latent probability since it must be between 0 and 1.
+        """
+        if self.use_closed_form_prob is False:
+            with torch.no_grad():
+                self._latent_prob = torch.maximum(
+                    self._latent_prob, torch.tensor([0]), out=self._latent_prob
+                )
+                self._latent_prob = torch.minimum(
+                    self._latent_prob, torch.tensor([1]), out=self._latent_prob
+                )
+                self._latent_prob *= self._dirac
+
     @property
     def covariance(self) -> torch.Tensor:
         """
@@ -3498,24 +3536,67 @@ class ZIPln(_model):
         Optional[torch.Tensor]
             The covariance tensor or None if components are not present.
         """
-        if hasattr(self, "_components"):
-            return self.components @ (self.components.T)
-        return None
+        return self._cpu_attribute_or_none("_covariance")
+
+    @property
+    def latent_prob(self):
+        return self._cpu_attribute_or_none("_latent_prob")
+
+    @property
+    def closed_form_latent_prob(self):
+        """
+        The closed form for the latent probability.
+        """
+        return closed_form_latent_prob(
+            self._exog, self._coef, self._coef_inflation, self._covariance, self._dirac
+        )
 
     def compute_elbo(self):
+        if self._use_closed_form_prob is True:
+            latent_prob = self.closed_form_latent_prob
+        else:
+            latent_prob = self._latent_prob
         return elbo_zi_pln(
             self._endog,
             self._exog,
             self._offsets,
             self._latent_mean,
             self._latent_sqrt_var,
-            self._latent_prob,
-            self._covariance,
+            latent_prob,
+            self._components,
             self._coef,
             self._coef_inflation,
             self._dirac,
         )
 
+    def _compute_elbo_b(self):
+        if self._use_closed_form_prob is True:
+            latent_prob_b = _closed_form_latent_prob(
+                self._exog_b,
+                self._coef,
+                self._coef_inflation,
+                self._covariance,
+                self._dirac_b,
+            )
+        else:
+            latent_prob_b = self._latent_prob_b
+        return elbo_zi_pln(
+            self._endog_b,
+            self._exog_b,
+            self._offsets_b,
+            self._latent_mean_b,
+            self._latent_sqrt_var_b,
+            latent_prob_b,
+            self._components,
+            self._coef,
+            self._coef_inflation,
+            self._dirac_b,
+        )
+
+    @property
+    def number_of_parameters(self):
+        return self.dim * (2 * self.nb_cov + (self.dim + 1) / 2)
+
     @property
     def _list_of_parameters_needing_gradient(self):
         list_parameters = [
@@ -3527,36 +3608,10 @@ class ZIPln(_model):
         ]
         if self._use_closed_form:
             list_parameters.append(self._latent_prob)
+        if self._exog is not None:
+            list_parameters.append(self._coef)
+            list_parameters.append(self._coef_inflation)
         return list_parameters
 
     def _update_closed_forms(self):
         pass
-        self._coef = _closed_formula_coef(self._exog, self._latent_mean)
-        self._covariance = _closed_formula_covariance(
-            self._exog,
-            self._latent_mean,
-            self._latent_sqrt_var,
-            self._coef,
-            self.n_samples,
-        )
-        self._latent_prob = _closed_formula_latent_prob(
-            self._offsets,
-            self._latent_mean,
-            self._latent_sqrt_var,
-            self._dirac,
-            self._exog,
-            self._coef_inflation,
-        )
-
-    @property
-    def closed_form_latent_prob(self):
-        """
-        The closed form for the latent probability.
-        """
-        return closed_form_latent_prob(
-            self._exog, self._coef, self._coef_inflation, self._covariance, self._dirac
-        )
-
-    @property
-    def number_of_parameters(self):
-        return self.dim * (2 * self.nb_cov + (self.dim + 1) / 2)
-- 
GitLab
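
The new _project_latent_prob method keeps the free variational probabilities feasible
after each gradient step: values are clipped into [0, 1] and zeroed wherever the count
is positive. The standalone sketch below shows the same projection using torch.clamp
instead of the paired torch.maximum / torch.minimum calls of the patch (illustrative
only, with hypothetical toy data).

    import torch

    def project_latent_prob(latent_prob: torch.Tensor, dirac: torch.Tensor) -> torch.Tensor:
        """Project onto the feasible set: entries in [0, 1], and 0 where endog > 0."""
        with torch.no_grad():
            return torch.clamp(latent_prob, 0.0, 1.0) * dirac

    endog = torch.tensor([[0.0, 3.0], [1.0, 0.0]])
    dirac = (endog == 0).double()
    raw = torch.tensor([[1.4, 0.2], [-0.3, 0.7]], dtype=torch.float64)
    print(project_latent_prob(raw, dirac))  # tensor([[1.0000, 0.0000], [0.0000, 0.7000]])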


From a5dcb70af4883c07f166103f6c2fbac21de5730a Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 9 Oct 2023 16:38:27 +0200
Subject: [PATCH 21/68] add gradients to the zero-inflated class.

---
 pyPLNmodels/models.py | 313 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 313 insertions(+)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 6289640c..6c0fd0fe 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -3615,3 +3615,316 @@ class ZIPln(_model):
 
     def _update_closed_forms(self):
         pass
+
+    def grad_M(self):
+        if self.use_closed_form_prob is True:
+            latent_prob = self.closed_form_latent_prob
+        else:
+            latent_prob = self._latent_prob
+        un_moins_prob = 1 - latent_prob
+        first = un_moins_prob * (
+            self._endog
+            - torch.exp(
+                self._offsets + self._latent_mean + self.latent_sqrt_var**2 / 2
+            )
+        )
+        MmoinsXB = self._latent_mean - self._exog @ self._coef
+        A = (un_moins_prob * MmoinsXB) @ torch.inverse(self._covariance)
+        diag_omega = torch.diag(torch.inverse(self._covariance))
+        full_diag_omega = diag_omega.expand(self.exog.shape[0], -1)
+        second = -un_moins_prob * A
+        added = -full_diag_omega * latent_prob * un_moins_prob * (MmoinsXB)
+        return first + second + added
+
+    def grad_S(self):
+        if self.use_closed_form_prob is True:
+            latent_prob = self.closed_form_latent_prob
+        else:
+            latent_prob = self._latent_prob
+        Omega = torch.inverse(self.covariance)
+        un_moins_prob = 1 - latent_prob
+        first = un_moins_prob * torch.exp(
+            self._offsets + self._latent_mean + self._latent_sqrt_var**2 / 2
+        )
+        first = -torch.multiply(first, self._latent_sqrt_var)
+        sec = un_moins_prob * 1 / self._latent_sqrt_var
+        K = un_moins_prob * (
+            torch.multiply(
+                torch.full((self.n_samples, 1), 1.0), torch.diag(Omega).unsqueeze(0)
+            )
+        )
+        third = -self._latent_sqrt_var * K
+        return first + sec + third
+
+    def grad_theta(self):
+        if self.use_closed_form_prob is True:
+            latent_prob = self.closed_form_latent_prob
+        else:
+            latent_prob = self._latent_prob
+
+        un_moins_prob = 1 - latent_prob
+        MmoinsXB = self._latent_mean - self._exog @ self._coef
+        A = (un_moins_prob * MmoinsXB) @ torch.inverse(self._covariance)
+        diag_omega = torch.diag(torch.inverse(self._covariance))
+        full_diag_omega = diag_omega.expand(self.exog.shape[0], -1)
+        added = latent_prob * (MmoinsXB) * full_diag_omega
+        A += added
+        second = -un_moins_prob * A
+        grad_no_closed_form = -self._exog.T @ second
+        if self.use_closed_form_prob is False:
+            return grad_no_closed_form
+        else:
+            XB_zero = self._exog @ self._coef_inflation
+            diag = torch.diag(self._covariance)
+            full_diag = diag.expand(self._exog.shape[0], -1)
+            XB = self._exog @ self._coef
+            derivative = d_h_x2(XB_zero, XB, full_diag, self._dirac)
+            grad_closed_form = self.gradients_closed_form_thetas(derivative)
+            return grad_closed_form + grad_no_closed_form
+
+    def gradients_closed_form_thetas(self, derivative):
+        Omega = torch.inverse(self._covariance)
+        MmoinsXB = self._latent_mean - self._exog @ self._coef
+        s_rond_s = self._latent_sqrt_var**2
+        latent_prob = self.closed_form_latent_prob
+        A = torch.exp(self._offsets + self._latent_mean + s_rond_s / 2)
+        poiss_term = (
+            self._endog * (self._offsets + self._latent_mean)
+            - A
+            - _log_stirling(self._endog)
+        )
+        a = -self._exog.T @ (derivative * poiss_term)
+        b = self._exog.T @ (
+            derivative * MmoinsXB * (((1 - latent_prob) * MmoinsXB) @ Omega)
+        )
+        c = self._exog.T @ (derivative * (self._exog @ self._coef_inflation))
+        first_d = derivative * torch.log(torch.abs(self._latent_sqrt_var))
+        second_d = (
+            1 / 2 * derivative @ (torch.diag(torch.log(torch.diag(self._covariance))))
+        )
+        d = -self._exog.T @ (first_d - second_d)
+        e = -self._exog.T @ (
+            derivative * (_trunc_log(latent_prob) - _trunc_log(1 - latent_prob))
+        )
+        first_f = (
+            +1
+            / 2
+            * self._exog.T
+            @ (derivative * (s_rond_s @ torch.diag(torch.diag(Omega))))
+        )
+        second_f = (
+            -1
+            / 2
+            * self._exog.T
+            @ derivative
+            @ torch.diag(torch.diag(Omega) * torch.diag(self._covariance))
+        )
+        full_diag_omega = torch.diag(Omega).expand(self.exog.shape[0], -1)
+        common = (MmoinsXB) ** 2 * (full_diag_omega)
+        new_f = -1 / 2 * self._exog.T @ (derivative * common * (1 - 2 * latent_prob))
+        f = first_f + second_f + new_f
+        return a + b + c + d + e + f
+
+    def grad_theta_0(self):
+        if self.use_closed_form_prob is True:
+            latent_prob = self.closed_form_latent_prob
+        else:
+            latent_prob = self._latent_prob
+        grad_no_closed_form = self._exog.T @ latent_prob - self._exog.T @ (
+            torch.exp(self._exog @ self._coef_inflation)
+            / (1 + torch.exp(self._exog @ self._coef_inflation))
+        )
+        if self.use_closed_form_prob is False:
+            return grad_no_closed_form
+        else:
+            grad_closed_form = self.gradients_closed_form_thetas(
+                latent_prob * (1 - latent_prob)
+            )
+            return grad_closed_form + grad_no_closed_form
+
+    def grad_C(self):
+        if self.use_closed_form_prob is True:
+            latent_prob = self.closed_form_latent_prob
+        else:
+            latent_prob = self._latent_prob
+        omega = torch.inverse(self._covariance)
+        if self._coef is not None:
+            m_minus_xb = self._latent_mean - torch.mm(self._exog, self._coef)
+        else:
+            m_minus_xb = self._latent_mean
+        m_moins_xb_outer = torch.mm(m_minus_xb.T, m_minus_xb)
+
+        un_moins_rho = 1 - latent_prob
+
+        un_moins_rho_m_moins_xb = un_moins_rho * m_minus_xb
+        un_moins_rho_m_moins_xb_outer = (
+            un_moins_rho_m_moins_xb.T @ un_moins_rho_m_moins_xb
+        )
+        deter = (
+            -self.n_samples
+            * torch.inverse(self._components @ (self._components.T))
+            @ self._components
+        )
+        sec_part_b_grad = (
+            omega @ (un_moins_rho_m_moins_xb_outer) @ omega @ self._components
+        )
+        b_grad = deter + sec_part_b_grad
+
+        diag = torch.diag(self.covariance)
+        rho_t_unn = torch.sum(latent_prob, axis=0)
+        omega_unp = torch.sum(omega, axis=0)
+        K = torch.sum(un_moins_rho * self._latent_sqrt_var**2, axis=0) + diag * (
+            rho_t_unn
+        )
+        added = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0)
+        K += added
+        first_part_grad = omega @ torch.diag_embed(K) @ omega @ self._components
+        x = torch.diag(omega) * rho_t_unn
+        second_part_grad = -torch.diag_embed(x) @ self._components
+        y = rho_t_unn
+        first = torch.multiply(y, 1 / torch.diag(self.covariance)).unsqueeze(1)
+        second = torch.full((1, self.dim), 1.0)
+        Diag = (first * second) * torch.eye(self.dim)
+        last_grad = Diag @ self._components
+        grad_no_closed_form = b_grad + first_part_grad + second_part_grad + last_grad
+        if self.use_closed_form_prob is False:
+            return grad_no_closed_form
+        else:
+            s_rond_s = self._latent_sqrt_var**2
+            XB_zero = self._exog @ self._coef_inflation
+            XB = self._exog @ self._coef
+            A = torch.exp(self._offsets + self._latent_mean + s_rond_s / 2)
+            poiss_term = (
+                self._endog * (self._offsets + self._latent_mean)
+                - A
+                - _log_stirling(self._endog)
+            )
+            full_diag_sigma = diag.expand(self._exog.shape[0], -1)
+            full_diag_omega = torch.diag(omega).expand(self._exog.shape[0], -1)
+            H3 = d_h_x3(XB_zero, XB, full_diag_sigma, self._dirac)
+            poiss_term_H = poiss_term * H3
+            a = (
+                -2
+                * (
+                    ((poiss_term_H.T @ torch.ones(self.n_samples, self.dim)))
+                    * (torch.eye(self.dim))
+                )
+                @ self._components
+            )
+            B_Omega = ((1 - latent_prob) * m_minus_xb) @ omega
+            K = H3 * B_Omega * m_minus_xb
+            b = (
+                2
+                * (
+                    (
+                        (m_minus_xb * B_Omega * H3).T
+                        @ torch.ones(self.n_samples, self.dim)
+                    )
+                    * torch.eye(self.dim)
+                )
+                @ self._components
+            )
+            c = (
+                2
+                * (
+                    ((XB_zero * H3).T @ torch.ones(self.n_samples, self.dim))
+                    * torch.eye(self.dim)
+                )
+                @ self._components
+            )
+            d = (
+                -2
+                * (
+                    (
+                        (torch.log(torch.abs(self._latent_sqrt_var)) * H3).T
+                        @ torch.ones(self.n_samples, self.dim)
+                    )
+                    * torch.eye(self.dim)
+                )
+                @ self._components
+            )
+            log_full_diag_sigma = torch.log(diag).expand(self._exog.shape[0], -1)
+            d += (
+                ((log_full_diag_sigma * H3).T @ torch.ones(self.n_samples, self.dim))
+                * torch.eye(self.dim)
+            ) @ self._components
+            e = (
+                -2
+                * (
+                    (
+                        ((_trunc_log(latent_prob) - _trunc_log(1 - latent_prob)) * H3).T
+                        @ torch.ones(self.n_samples, self.dim)
+                    )
+                    * torch.eye(self.dim)
+                )
+                @ self._components
+            )
+            f = (
+                -(
+                    (
+                        (full_diag_omega * (full_diag_sigma - s_rond_s) * H3).T
+                        @ torch.ones(self.n_samples, self.dim)
+                    )
+                    * torch.eye(self.dim)
+                )
+                @ self._components
+            )
+            f -= (
+                (
+                    ((1 - 2 * latent_prob) * m_minus_xb**2 * full_diag_omega * H3).T
+                    @ torch.ones(self.n_samples, self.dim)
+                )
+                * torch.eye(self.dim)
+            ) @ self._components
+            grad_closed_form = a + b + c + d + e + f
+            return grad_closed_form + grad_no_closed_form
+
+    def grad_rho(self):
+        if self.use_closed_form_prob is True:
+            latent_prob = self.closed_form_latent_prob
+        else:
+            latent_prob = self._latent_prob
+        omega = torch.inverse(self._covariance)
+        s_rond_s = self._latent_sqrt_var * self._latent_sqrt_var
+        A = torch.exp(self._offsets + self._latent_mean + s_rond_s / 2)
+        first = (
+            -self._endog * (self._offsets + self._latent_mean)
+            + A
+            + _log_stirling(self._endog)
+        )
+        un_moins_prob = 1 - latent_prob
+        MmoinsXB = self._latent_mean - self._exog @ self._coef
+        A = (un_moins_prob * MmoinsXB) @ torch.inverse(self._covariance)
+        second = MmoinsXB * A
+        third = self._exog @ self._coef_inflation
+        fourth_first = -torch.log(torch.abs(self._latent_sqrt_var))
+        fourth_second = (
+            1
+            / 2
+            * torch.multiply(
+                torch.full((self.n_samples, 1), 1.0),
+                torch.log(torch.diag(self.covariance)).unsqueeze(0),
+            )
+        )
+        fourth = fourth_first + fourth_second
+        fifth = _trunc_log(un_moins_prob) - _trunc_log(latent_prob)
+        sixth_first = (
+            1
+            / 2
+            * torch.multiply(
+                torch.full((self.n_samples, 1), 1.0), torch.diag(omega).unsqueeze(0)
+            )
+            * s_rond_s
+        )
+        sixth_second = (
+            -1
+            / 2
+            * torch.multiply(
+                torch.full((self.n_samples, 1), 1.0),
+                (torch.diag(omega) * torch.diag(self._covariance)).unsqueeze(0),
+            )
+        )
+        sixth = sixth_first + sixth_second
+        full_diag_omega = torch.diag(omega).expand(self.exog.shape[0], -1)
+        seventh = -1 / 2 * (1 - 2 * latent_prob) * (MmoinsXB) ** 2 * (full_diag_omega)
+        return first + second + third + fourth + fifth + sixth + seventh
-- 
GitLab
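
The hand-written gradients above (grad_M, grad_S, grad_theta, grad_C, grad_rho) mirror
the terms of elbo_zi_pln. A natural sanity check for such closed-form gradients is to
compare them with autograd on the same scalar ELBO. The generic pattern is sketched
below; elbo_fn, params and wrt are placeholders, not part of the package's API.

    import torch

    def check_manual_gradient(elbo_fn, params: dict, manual_grad: torch.Tensor,
                              wrt: str, atol: float = 1e-5) -> bool:
        """Compare a hand-written gradient with torch.autograd on the same scalar ELBO.

        elbo_fn is any callable returning a scalar tensor from the parameters in
        `params`; `wrt` names the entry the manual gradient is taken with respect to.
        """
        leaf = params[wrt].detach().clone().requires_grad_(True)
        params = {**params, wrt: leaf}
        value = elbo_fn(**params)
        (auto_grad,) = torch.autograd.grad(value, leaf)
        return torch.allclose(auto_grad, manual_grad, atol=atol)

    # Toy usage: the ELBO is replaced by a quadratic whose gradient is known to be 2 * M.
    M = torch.randn(4, 3)
    assert check_manual_gradient(lambda M: (M ** 2).sum(), {"M": M}, 2 * M, wrt="M")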


From 07c0e81362d2026430f6687847629664b5718faf Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 10 Oct 2023 19:16:24 +0200
Subject: [PATCH 22/68] pass the tests, but ZI and stochasticity have not been
 tested yet.

---
 pyPLNmodels/_closed_forms.py | 15 +++++++++
 pyPLNmodels/_utils.py        | 48 +++++++++++++++++++++++++++--
 pyPLNmodels/models.py        | 59 ++++++++++++++++++++++--------------
 tests/test_pln_full.py       |  2 +-
 tests/test_setters.py        | 18 ++++++-----
 5 files changed, 109 insertions(+), 33 deletions(-)

diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py
index b57e7850..3524d48d 100644
--- a/pyPLNmodels/_closed_forms.py
+++ b/pyPLNmodels/_closed_forms.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from ._utils import phi
 
 import torch  # pylint:disable=[C0114]
 
@@ -98,3 +99,17 @@ def _closed_formula_pi(
     """
     poiss_param = torch.exp(offsets + latent_mean + 0.5 * torch.square(latent_sqrt_var))
     return torch._sigmoid(poiss_param + torch.mm(exog, _coef_inflation)) * dirac
+
+
+def _closed_formula_latent_prob(exog, coef, coef_infla, cov, dirac):
+    if exog is not None:
+        XB = exog @ coef
+        XB_zero = exog @ coef_infla
+    else:
+        XB_zero = 0
+        XB = 0
+    XB_zero = exog @ coef_infla
+    pi = torch.sigmoid(XB_zero)
+    diag = torch.diag(cov)
+    full_diag = diag.expand(exog.shape[0], -1)
+    return torch.sigmoid(XB_zero - torch.log(phi(XB, full_diag))) * dirac
diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index d2b1aea0..7169b74d 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -79,8 +79,8 @@ class _PlotArgs:
         """
         ax = plt.gca() if ax is None else ax
         ax.plot(
-            self.running_times[self.window :],
-            self.criterions[self.window :],
+            self.running_times,
+            self.criterions,
             label="Delta",
         )
         ax.set_yscale("log")
@@ -1004,3 +1004,47 @@ def _add_doc(parent_class, *, params=None, example=None, returns=None, see_also=
         return fun
 
     return wrapper
+
+
+def pf_lambert(x, y):
+    return x - (1 - (y * torch.exp(-x) + 1) / (x + 1))
+
+
+def lambert(y, nb_pf=10):
+    x = torch.log(1 + y)
+    for _ in range(nb_pf):
+        x = pf_lambert(x, y)
+    return x
+
+
+def d_varpsi_x1(mu, sigma2):
+    W = lambert(sigma2 * torch.exp(mu))
+    first = phi(mu, sigma2)
+    third = 1 / sigma2 + 1 / 2 * 1 / ((1 + W) ** 2)
+    return -first * W * third
+
+
+def phi(mu, sigma2):
+    y = sigma2 * torch.exp(mu)
+    lamby = lambert(y)
+    log_num = -1 / (2 * sigma2) * (lamby**2 + 2 * lamby)
+    return torch.exp(log_num) / torch.sqrt(1 + lamby)
+
+
+def d_varpsi_x2(mu, sigma2):
+    first = d_varpsi_x1(mu, sigma2) / sigma2
+    W = lambert(sigma2 * torch.exp(mu))
+    second = (W**2 + 2 * W) / 2 / (sigma2**2) * phi(mu, sigma2)
+    return first + second
+
+
+def d_h_x2(a, x, y, dirac):
+    rho = torch.sigmoid(a - torch.log(phi(x, y))) * dirac
+    rho_prime = rho * (1 - rho)
+    return -rho_prime * d_varpsi_x1(x, y) / phi(x, y)
+
+
+def d_h_x3(a, x, y, dirac):
+    rho = torch.sigmoid(a - torch.log(phi(x, y))) * dirac
+    rho_prime = rho * (1 - rho)
+    return -rho_prime * d_varpsi_x2(x, y) / phi(x, y)
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 6c0fd0fe..64ddb104 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -32,7 +32,6 @@ from ._utils import (
     _array2tensor,
     _handle_data,
     _add_doc,
-    _closed_form_latent_prob,
 )
 
 from ._initialization import (
@@ -65,6 +64,7 @@ class _model(ABC):
     _beginning_time: float
     _latent_sqrt_var: torch.Tensor
     _latent_mean: torch.Tensor
+    _batch_size: int = None
 
     def __init__(
         self,
@@ -164,7 +164,10 @@ class _model(ABC):
         """
         if "coef" not in dict_initialization.keys():
             print("No coef is initialized.")
-            self.coef = None
+            dict_initialization["coef"] = None
+        if self._NAME == "Pln":
+            del dict_initialization["covariance"]
+            del dict_initialization["coef"]
         for key, array in dict_initialization.items():
             array = _format_data(array)
             setattr(self, key, array)
@@ -175,6 +178,8 @@ class _model(ABC):
         """
         The batch size of the model. Should not be greater than the number of samples.
         """
+        if self._batch_size is None:
+            return self.n_samples
         return self._batch_size
 
     @property
@@ -265,7 +270,7 @@ class _model(ABC):
         int
             The number of iterations done.
         """
-        return len(self._plotargs._elbos_list) * self._nb_batches
+        return len(self._plotargs._elbos_list) * self.nb_batches
 
     @property
     def n_samples(self) -> int:
@@ -359,7 +364,7 @@ class _model(ABC):
 
     def _put_parameters_to_device(self):
         """
-        Move parameters to the device.
+        Move parameters to the GPU device if present.
         """
         for parameter in self._list_of_parameters_needing_gradient:
             parameter.requires_grad_(True)
@@ -374,7 +379,6 @@ class _model(ABC):
         List[torch.Tensor]
             List of parameters needing gradient.
         """
-        ...
 
     def fit(
         self,
@@ -459,9 +463,13 @@ class _model(ABC):
 
     def _return_batch(self, indices, beginning, end):
         to_take = torch.tensor(indices[beginning:end]).to(DEVICE)
+        if self._exog is not None:
+            exog_b = torch.index_select(self._exog, 0, to_take)
+        else:
+            exog_b = None
         return (
             torch.index_select(self._endog, 0, to_take),
-            torch.index_select(self._exog, 0, to_take),
+            exog_b,
             torch.index_select(self._offsets, 0, to_take),
             torch.index_select(self._latent_mean, 0, to_take),
             torch.index_select(self._latent_sqrt_var, 0, to_take),
@@ -469,14 +477,14 @@ class _model(ABC):
 
     @property
     def _nb_full_batch(self):
-        return self.n_samples // self._batch_size
+        return self.n_samples // self.batch_size
 
     @property
     def _last_batch_size(self):
-        return self.n_samples % self._batch_size
+        return self.n_samples % self.batch_size
 
     @property
-    def _nb_batches(self):
+    def nb_batches(self):
         return self._nb_full_batch + (self._last_batch_size > 0)
 
     def _trainstep(self):
@@ -495,9 +503,9 @@ class _model(ABC):
             loss = -self._compute_elbo_b()
             loss.backward()
             elbo += loss.item()
-            self._udpate_parameters()
+            self._update_parameters()
             self._update_closed_forms()
-        return elbo / self._nb_batches
+        return elbo / self.nb_batches
 
     def _extract_batch(self, batch):
         self._endog_b = batch[0]
@@ -740,8 +748,9 @@ class _model(ABC):
         """
         self._plotargs._elbos_list.append(-loss)
         self._plotargs.running_times.append(time.time() - self._beginning_time)
+        elbo = -loss
         self._plotargs.cumulative_elbo_list.append(
-            self._plotargs.cumulative_elbo_list - loss
+            self._plotargs.cumulative_elbo + elbo
         )
         criterion = (
             self._plotargs.cumulative_elbo_list[-2]
@@ -1652,7 +1661,11 @@ class Pln(_model):
         ----------
         coef : Union[torch.Tensor, np.ndarray, pd.DataFrame]
             The regression coefficients of the gaussian latent variables.
+        Raises
+        ------
+        AttributeError, since you cannot set the coef in the Pln model.
         """
+        raise AttributeError("You can not set the coef in the Pln model.")
 
     def _endog_predictions(self):
         return torch.exp(
@@ -3543,17 +3556,17 @@ class ZIPln(_model):
         return self._cpu_attribute_or_none("_latent_prob")
 
     @property
-    def closed_form_latent_prob(self):
+    def closed_formula_latent_prob(self):
         """
         The closed form for the latent probability.
         """
-        return closed_form_latent_prob(
+        return closed_formula_latent_prob(
             self._exog, self._coef, self._coef_inflation, self._covariance, self._dirac
         )
 
     def compute_elbo(self):
         if self._use_closed_form_prob is True:
-            latent_prob = self.closed_form_latent_prob
+            latent_prob = self.closed_formula_latent_prob
         else:
             latent_prob = self._latent_prob
         return elbo_zi_pln(
@@ -3571,7 +3584,7 @@ class ZIPln(_model):
 
     def _compute_elbo_b(self):
         if self._use_closed_form_prob is True:
-            latent_prob_b = _closed_form_latent_prob(
+            latent_prob_b = _closed_formula_latent_prob(
                 self._exog_b,
                 self._coef,
                 self._coef_inflation,
@@ -3618,7 +3631,7 @@ class ZIPln(_model):
 
     def grad_M(self):
         if self.use_closed_form_prob is True:
-            latent_prob = self.closed_form_latent_prob
+            latent_prob = self.closed_formula_latent_prob
         else:
             latent_prob = self._latent_prob
         un_moins_prob = 1 - latent_prob
@@ -3638,7 +3651,7 @@ class ZIPln(_model):
 
     def grad_S(self):
         if self.use_closed_form_prob is True:
-            latent_prob = self.closed_form_latent_prob
+            latent_prob = self.closed_formula_latent_prob
         else:
             latent_prob = self._latent_prob
         Omega = torch.inverse(self.covariance)
@@ -3658,7 +3671,7 @@ class ZIPln(_model):
 
     def grad_theta(self):
         if self.use_closed_form_prob is True:
-            latent_prob = self.closed_form_latent_prob
+            latent_prob = self.closed_formula_latent_prob
         else:
             latent_prob = self._latent_prob
 
@@ -3686,7 +3699,7 @@ class ZIPln(_model):
         Omega = torch.inverse(self._covariance)
         MmoinsXB = self._latent_mean - self._exog @ self._coef
         s_rond_s = self._latent_sqrt_var**2
-        latent_prob = self.closed_form_latent_prob
+        latent_prob = self.closed_formula_latent_prob
         A = torch.exp(self._offsets + self._latent_mean + s_rond_s / 2)
         poiss_term = (
             self._endog * (self._offsets + self._latent_mean)
@@ -3727,7 +3740,7 @@ class ZIPln(_model):
 
     def grad_theta_0(self):
         if self.use_closed_form_prob is True:
-            latent_prob = self.closed_form_latent_prob
+            latent_prob = self.closed_formula_latent_prob
         else:
             latent_prob = self._latent_prob
         grad_no_closed_form = self._exog.T @ latent_prob - self._exog.T @ (
@@ -3744,7 +3757,7 @@ class ZIPln(_model):
 
     def grad_C(self):
         if self.use_closed_form_prob is True:
-            latent_prob = self.closed_form_latent_prob
+            latent_prob = self.closed_formula_latent_prob
         else:
             latent_prob = self._latent_prob
         omega = torch.inverse(self._covariance)
@@ -3881,7 +3894,7 @@ class ZIPln(_model):
 
     def grad_rho(self):
         if self.use_closed_form_prob is True:
-            latent_prob = self.closed_form_latent_prob
+            latent_prob = self.closed_formula_latent_prob
         else:
             latent_prob = self._latent_prob
         omega = torch.inverse(self._covariance)
diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py
index 2d61befd..870114a0 100644
--- a/tests/test_pln_full.py
+++ b/tests/test_pln_full.py
@@ -8,7 +8,7 @@ from tests.utils import filter_models
 @filter_models(["Pln"])
 def test_number_of_iterations_pln_full(fitted_pln):
     nb_iterations = len(fitted_pln._elbos_list)
-    assert 50 < nb_iterations < 500
+    assert 20 < nb_iterations < 1000
 
 
 @pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
diff --git a/tests/test_setters.py b/tests/test_setters.py
index 828989e8..eb7814d7 100644
--- a/tests/test_setters.py
+++ b/tests/test_setters.py
@@ -19,7 +19,8 @@ def test_data_setter_with_torch(pln):
 def test_parameters_setter_with_torch(pln):
     pln.latent_mean = pln.latent_mean
     pln.latent_sqrt_var = pln.latent_sqrt_var
-    pln.coef = pln.coef
+    if pln._NAME != "Pln":
+        pln.coef = pln.coef
     if pln._NAME == "PlnPCA":
         pln.components = pln.components
     pln.fit()
@@ -50,7 +51,8 @@ def test_parameters_setter_with_numpy(pln):
         np_coef = None
     pln.latent_mean = np_latent_mean
     pln.latent_sqrt_var = np_latent_sqrt_var
-    pln.coef = np_coef
+    if pln._NAME != "Pln":
+        pln.coef = np_coef
     if pln._NAME == "PlnPCA":
         pln.components = pln.components.numpy()
     pln.fit()
@@ -81,7 +83,8 @@ def test_parameters_setter_with_pandas(pln):
         pd_coef = None
     pln.latent_mean = pd_latent_mean
     pln.latent_sqrt_var = pd_latent_sqrt_var
-    pln.coef = pd_coef
+    if pln._NAME != "Pln":
+        pln.coef = pd_coef
     if pln._NAME == "PlnPCA":
         pln.components = pd.DataFrame(pln.components.numpy())
     pln.fit()
@@ -141,8 +144,9 @@ def test_fail_parameters_setter_with_torch(pln):
             d = 0
         else:
             d = pln.exog.shape[-1]
-        with pytest.raises(ValueError):
-            pln.coef = torch.zeros(d + 1, dim)
+        if pln._NAME != "Pln":
+            with pytest.raises(ValueError):
+                pln.coef = torch.zeros(d + 1, dim)
 
-        with pytest.raises(ValueError):
-            pln.coef = torch.zeros(d, dim + 1)
+            with pytest.raises(ValueError):
+                pln.coef = torch.zeros(d, dim + 1)
-- 
GitLab
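
The new helpers in _utils.py evaluate the Lambert W function by fixed-point iteration:
pf_lambert is one update of an iteration whose fixed point x satisfies x * exp(x) = y,
and lambert runs it from the starting point log(1 + y); phi then uses W(sigma2 * exp(mu))
to build the closed-form zero-inflation probability. The sketch below copies the two
helpers so it runs standalone and checks the fixed-point property numerically
(tolerances are illustrative).

    import torch

    def pf_lambert(x, y):
        # One fixed-point update: at the fixed point, y * exp(-x) = x, i.e. x * exp(x) = y.
        return x - (1 - (y * torch.exp(-x) + 1) / (x + 1))

    def lambert(y, nb_pf=10):
        x = torch.log(1 + y)
        for _ in range(nb_pf):
            x = pf_lambert(x, y)
        return x

    y = torch.linspace(0.1, 50.0, steps=20)
    x = lambert(y)
    assert torch.allclose(x * torch.exp(x), y, rtol=1e-4)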


From abb8a72976a6ea09bbb4fde80f424c40c7722527 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 10 Oct 2023 22:55:12 +0200
Subject: [PATCH 23/68] add ZI in the __init__

---
 pyPLNmodels/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyPLNmodels/__init__.py b/pyPLNmodels/__init__.py
index e785b288..6ed723c7 100644
--- a/pyPLNmodels/__init__.py
+++ b/pyPLNmodels/__init__.py
@@ -1,4 +1,4 @@
-from .models import PlnPCAcollection, Pln, PlnPCA  # pylint:disable=[C0114]
+from .models import PlnPCAcollection, Pln, PlnPCA, ZIPln  # pylint:disable=[C0114]
 from .oaks import load_oaks
 from .elbos import profiled_elbo_pln, elbo_plnpca, elbo_pln
 from ._utils import (
-- 
GitLab


From 86e5612501e4c982b30963e27592e32081563e5f Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Wed, 11 Oct 2023 09:03:48 +0200
Subject: [PATCH 24/68] tried to import the ZI.

---
 pyPLNmodels/models.py | 35 ++++++++++++++++++++++++++++++-----
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 64ddb104..7974a1e4 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -3382,7 +3382,7 @@ class ZIPln(_model):
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
         add_const: bool = True,
-        use_closed_form: bool = False,
+        use_closed_form_prob: bool = False,
     ):
         super().__init__(
             endog=endog,
@@ -3393,7 +3393,32 @@ class ZIPln(_model):
             take_log_offsets=take_log_offsets,
             add_const=add_const,
         )
-        self._use_closed_form = use_closed_form
+        self._use_closed_form_prob = use_closed_form_prob
+
+    def _extract_batch(self, batch):
+        super()._extract_batch(batch)
+        if self._use_closed_form_prob is False:
+            self._latent_prob_b = batch[5]
+
+    def _return_batch(self, indices, beginning, end):
+        pln_batch = super()._return_batch(indices, beginning, end)
+        if self._use_closed_form_prob is False:
+            return pln_batch + torch.index_select(self._latent_prob, 0, to_take)
+        return pln_batch
+
+    def _return_batch(self, indices, beginning, end):
+        to_take = torch.tensor(indices[beginning:end]).to(DEVICE)
+        if self._exog is not None:
+            exog_b = torch.index_select(self._exog, 0, to_take)
+        else:
+            exog_b = None
+        return (
+            torch.index_select(self._endog, 0, to_take),
+            exog_b,
+            torch.index_select(self._offsets, 0, to_take),
+            torch.index_select(self._latent_mean, 0, to_take),
+            torch.index_select(self._latent_sqrt_var, 0, to_take),
+        )
 
     @classmethod
     @_add_doc(
@@ -3439,7 +3464,7 @@ class ZIPln(_model):
         example="""
         >>> from pyPLNmodels import ZIPln, get_real_count_data
         >>> endog = get_real_count_data()
-        >>> zi = Pln(endog,add_const = True)
+        >>> zi = ZIPln(endog,add_const = True)
         >>> zi.fit()
         >>> print(zi)
         """,
@@ -3493,7 +3518,7 @@ class ZIPln(_model):
     # should change the good initialization for _coef_inflation
     def _smart_init_model_parameters(self):
         # init of _coef.
-        super()._smart_init_model_parameters()
+        super()._smart_init_coef()
         if not hasattr(self, "_covariance"):
             self._components = _init_components(self._endog, self._exog, self.dim)
         if not hasattr(self, "_coef_inflation"):
@@ -3619,7 +3644,7 @@ class ZIPln(_model):
             self._components,
             self._coef,
         ]
-        if self._use_closed_form:
+        if self._use_closed_form_prob:
             list_parameters.append(self._latent_prob)
         if self._exog is not None:
             list_parameters.append(self._coef)
-- 
GitLab
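
As committed, this patch defines `_return_batch` twice and the first definition uses `to_take` before it is assigned; both issues are resolved in patch 34 below. The pattern being aimed at is a subclass that appends its own per-sample tensors to the parent's minibatch tuple. A self-contained sketch under illustrative names (`Base`/`ZI` stand in for `_model` and `ZIPln`):

```python
import torch


class Base:
    def _return_batch(self, data, idx):
        # the parent returns a tuple of per-sample tensors for the minibatch
        return (torch.index_select(data, 0, idx),)


class ZI(Base):
    def __init__(self, latent_prob, use_closed_form_prob=False):
        self._latent_prob = latent_prob
        self._use_closed_form_prob = use_closed_form_prob

    def _return_batch(self, data, idx):
        batch = super()._return_batch(data, idx)
        if self._use_closed_form_prob is False:
            # append the batch of latent probabilities as an extra element
            batch = batch + (torch.index_select(self._latent_prob, 0, idx),)
        return batch


# Example
data = torch.arange(12.0).reshape(6, 2)
prob = torch.rand(6, 2)
idx = torch.tensor([0, 3, 5])
print(len(ZI(prob)._return_batch(data, idx)))  # 2: (endog batch, latent_prob batch)
```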


From e3c325644ca283150b4e8a94b747a96741a3ea1a Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 12 Oct 2023 10:25:12 +0200
Subject: [PATCH 25/68] finally add the right criterion

---
 pyPLNmodels/_utils.py | 35 ++++++++++++++++++++--
 pyPLNmodels/models.py | 67 ++++++++++++++++++-------------------------
 2 files changed, 61 insertions(+), 41 deletions(-)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 7169b74d..582f859f 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -23,15 +23,46 @@ else:
     DEVICE = torch.device("cpu")
 
 
-class _PlotArgs:
+BETA = 0.03
+
+
+class _CriterionArgs:
     def __init__(self):
         """
         Initialize the PlotArgs class.
+
+        Tracks the running times, ELBO values and the exponentially
+        smoothed convergence criterion (see BETA above); the lists start
+        empty or with a neutral initial value and are filled by
+        update_criterion at each iteration.
         """
         self.running_times = []
-        self.criterions = []
         self._elbos_list = []
         self.cumulative_elbo_list = [0]
+        self.new_derivative = 0
+        self.normalized_elbo_list = []
+        self.criterion_list = [1]
+
+    def update_criterion(self, elbo, running_time):
+        self._elbos_list.append(elbo)
+        self.running_times.append(running_time)
+        self.cumulative_elbo_list.append(self.cumulative_elbo + elbo)
+        self.normalized_elbo_list.append(-elbo / self.cumulative_elbo_list[-1])
+        if self.iteration_number > 1:
+            current_derivative = np.abs(
+                (self.normalized_elbo_list[-2] - self.normalized_elbo_list[-1])
+                / (self.running_times[-2] - self.running_times[-1])
+            )
+            old_derivative = self.new_derivative
+            self.new_derivative = (
+                self.new_derivative * (1 - BETA) + current_derivative * BETA
+            )
+            current_hessian = np.abs(
+                (self.new_derivative - self.old_derivative)
+                / (self.running_times[-2] - self.running_times[-1])
+            )
+            self.criterion = self.criterion * (1 - BETA) + current_hessian * BETA
+            self.criterion_list.append(self.criterion)
 
     @property
     def iteration_number(self) -> int:
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 7974a1e4..41cb12a8 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -22,7 +22,7 @@ from ._closed_forms import (
 )
 from .elbos import elbo_plnpca, elbo_zi_pln, profiled_elbo_pln
 from ._utils import (
-    _PlotArgs,
+    _CriterionArgs,
     _format_data,
     _nice_string_of_dict,
     _plot_ellipse,
@@ -111,7 +111,7 @@ class _model(ABC):
             endog, exog, offsets, offsets_formula, take_log_offsets, add_const
         )
         self._fitted = False
-        self._plotargs = _PlotArgs()
+        self._criterion_args = _CriterionArgs()
         if dict_initialization is not None:
             self._set_init_parameters(dict_initialization)
 
@@ -270,7 +270,7 @@ class _model(ABC):
         int
             The number of iterations done.
         """
-        return len(self._plotargs._elbos_list) * self.nb_batches
+        return len(self._criterion_args._elbos_list) * self.nb_batches
 
     @property
     def n_samples(self) -> int:
@@ -385,7 +385,7 @@ class _model(ABC):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-3,
+        tol: float = 1e-8,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size=None,
@@ -400,7 +400,7 @@ class _model(ABC):
         lr : float, optional(keyword-only)
             The learning rate. Defaults to 0.01.
         tol : float, optional(keyword-only)
-            The tolerance for convergence. Defaults to 1e-3.
+            The tolerance for convergence. Defaults to 1e-8.
         do_smart_init : bool, optional(keyword-only)
             Whether to perform smart initialization. Defaults to True.
         verbose : bool, optional(keyword-only)
@@ -414,14 +414,14 @@ class _model(ABC):
         self._batch_size = self._handle_batch_size(batch_size)
         if self._fitted is False:
             self._init_parameters(do_smart_init)
-        elif len(self._plotargs.running_times) > 0:
-            self._beginning_time -= self._plotargs.running_times[-1]
+        elif len(self._criterion_args.running_times) > 0:
+            self._beginning_time -= self._criterion_args.running_times[-1]
         self._put_parameters_to_device()
         self._handle_optimizer(lr)
         stop_condition = False
         while self.nb_iteration_done < nb_max_iteration and not stop_condition:
             loss = self._trainstep()
-            criterion = self._compute_criterion_and_update_plotargs(loss, tol)
+            criterion = self._update_criterion_args(loss)
             if abs(criterion) < tol:
                 stop_condition = True
             if verbose and self.nb_iteration_done % 50 == 1:
@@ -711,14 +711,14 @@ class _model(ABC):
         if stop_condition is True:
             print(
                 f"Tolerance {tol} reached "
-                f"in {self._plotargs.iteration_number} iterations"
+                f"in {self._criterion_args.iteration_number} iterations"
             )
         else:
             print(
                 "Maximum number of iterations reached : ",
-                self._plotargs.iteration_number,
+                self._criterion_args.iteration_number,
                 "last criterion = ",
-                np.round(self._plotargs.criterions[-1], 8),
+                np.round(self._criterion_args.criterions[-1], 8),
             )
 
     def _print_stats(self):
@@ -726,11 +726,11 @@ class _model(ABC):
         Print the training statistics.
         """
         print("-------UPDATE-------")
-        print("Iteration number: ", self._plotargs.iteration_number)
-        print("Criterion: ", np.round(self._plotargs.criterions[-1], 8))
-        print("ELBO:", np.round(self._plotargs._elbos_list[-1], 6))
+        print("Iteration number: ", self._criterion_args.iteration_number)
+        print("Criterion: ", np.round(self._criterion_args.criterions[-1], 8))
+        print("ELBO:", np.round(self._criterion_args._elbos_list[-1], 6))
 
-    def _compute_criterion_and_update_plotargs(self, loss, tol):
+    def _update_criterion_args(self, loss):
         """
         Compute the convergence criterion and update the plot arguments.
 
@@ -738,26 +738,15 @@ class _model(ABC):
         ----------
         loss : torch.Tensor
             The loss value.
-        tol : float
-            The tolerance for convergence.
 
         Returns
         -------
         float
             The computed criterion.
         """
-        self._plotargs._elbos_list.append(-loss)
-        self._plotargs.running_times.append(time.time() - self._beginning_time)
-        elbo = -loss
-        self._plotargs.cumulative_elbo_list.append(
-            self._plotargs.cumulative_elbo + elbo
-        )
-        criterion = (
-            self._plotargs.cumulative_elbo_list[-2]
-            - self._plotargs.cumulative_elbo_list[-1]
-        ) / self._plotargs.cumulative_elbo_list[-1]
-        self._plotargs.criterions.append(criterion)
-        return criterion
+        current_running_time = time.time() - self._beginning_time
+        self._criterion_args.update_criterion(-loss, current_running_time)
+        return self._criterion_args.criterion
 
     def _update_closed_forms(self):
         """
@@ -858,8 +847,8 @@ class _model(ABC):
         if axes is None:
             _, axes = plt.subplots(1, nb_axes, figsize=(23, 5))
         if self._fitted is True:
-            self._plotargs._show_loss(ax=axes[2])
-            self._plotargs._show_stopping_criterion(ax=axes[1])
+            self._criterion_args._show_loss(ax=axes[2])
+            self._criterion_args._show_stopping_criterion(ax=axes[1])
             self.display_covariance(ax=axes[0])
         else:
             self.display_covariance(ax=axes)
@@ -870,7 +859,7 @@ class _model(ABC):
         """
         Property representing the list of ELBO values.
         """
-        return self._plotargs._elbos_list
+        return self._criterion_args._elbos_list
 
     @property
     def loglike(self):
@@ -884,8 +873,8 @@ class _model(ABC):
         """
         if len(self._elbos_list) == 0:
             t0 = time.time()
-            self._plotargs._elbos_list.append(self.compute_elbo().item())
-            self._plotargs.running_times.append(time.time() - t0)
+            self._criterion_args._elbos_list.append(self.compute_elbo().item())
+            self._criterion_args.running_times.append(time.time() - t0)
         return self.n_samples * self._elbos_list[-1]
 
     @property
@@ -1512,7 +1501,7 @@ class Pln(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-3,
+        tol: float = 1e-8,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size: int = None,
@@ -2267,7 +2256,7 @@ class PlnPCAcollection:
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-3,
+        tol: float = 1e-8,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size: int = None,
@@ -2282,7 +2271,7 @@ class PlnPCAcollection:
         lr : float, optional(keyword-only)
             The learning rate, by default 0.01.
         tol : float, optional(keyword-only)
-            The tolerance, by default 1e-3.
+            The tolerance, by default 1e-8.
         do_smart_init : bool, optional(keyword-only)
             Whether to do smart initialization, by default True.
         verbose : bool, optional(keyword-only)
@@ -2795,7 +2784,7 @@ class PlnPCA(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-3,
+        tol: float = 1e-8,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size=None,
@@ -3474,7 +3463,7 @@ class ZIPln(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-3,
+        tol: float = 1e-8,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size: int = None,
-- 
GitLab
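
The new stopping rule replaces the raw relative change of the cumulative ELBO with a doubly smoothed quantity: the finite-difference slope of the normalized ELBO with respect to running time is tracked with an exponential moving average (factor `BETA = 0.03`), the difference of successive averages estimates its curvature, and the criterion is an exponential moving average of that curvature. Fitting stops once it falls below `tol`, whose default is accordingly lowered to 1e-8. A standalone sketch of `update_criterion` (incorporating the `old_derivative` fix from patch 26; the `cumulative_elbo` and `iteration_number` properties of the real class are inlined here for self-containment):

```python
import numpy as np

BETA = 0.03  # smoothing factor, as defined in pyPLNmodels/_utils.py


class CriterionSketch:
    """Standalone sketch of _CriterionArgs.update_criterion."""

    def __init__(self):
        self.running_times = []
        self._elbos_list = []
        self.cumulative_elbo_list = [0]
        self.new_derivative = 0
        self.normalized_elbo_list = []
        self.criterion_list = [1]
        self.criterion = 1

    def update_criterion(self, elbo, running_time):
        self._elbos_list.append(elbo)
        self.running_times.append(running_time)
        self.cumulative_elbo_list.append(self.cumulative_elbo_list[-1] + elbo)
        self.normalized_elbo_list.append(-elbo / self.cumulative_elbo_list[-1])
        if len(self._elbos_list) > 1:
            # finite-difference slope of the normalized ELBO w.r.t. running time
            current_derivative = np.abs(
                (self.normalized_elbo_list[-2] - self.normalized_elbo_list[-1])
                / (self.running_times[-2] - self.running_times[-1])
            )
            old_derivative = self.new_derivative
            self.new_derivative = (
                self.new_derivative * (1 - BETA) + current_derivative * BETA
            )
            # smoothed curvature estimate of the normalized ELBO
            current_hessian = np.abs(
                (self.new_derivative - old_derivative)
                / (self.running_times[-2] - self.running_times[-1])
            )
            self.criterion = self.criterion * (1 - BETA) + current_hessian * BETA
            self.criterion_list.append(self.criterion)
        return self.criterion
```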


From 4a22c5fe272112e070bab6be2c5882f88e930fbf Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 12 Oct 2023 10:36:39 +0200
Subject: [PATCH 26/68] minor changes.

---
 pyPLNmodels/_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 582f859f..053a6448 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -42,6 +42,7 @@ class _CriterionArgs:
         self.new_derivative = 0
         self.normalized_elbo_list = []
         self.criterion_list = [1]
+        self.criterion = 1
 
     def update_criterion(self, elbo, running_time):
         self._elbos_list.append(elbo)
@@ -58,7 +59,7 @@ class _CriterionArgs:
                 self.new_derivative * (1 - BETA) + current_derivative * BETA
             )
             current_hessian = np.abs(
-                (self.new_derivative - self.old_derivative)
+                (self.new_derivative - old_derivative)
                 / (self.running_times[-2] - self.running_times[-1])
             )
             self.criterion = self.criterion * (1 - BETA) + current_hessian * BETA
-- 
GitLab


From 07c79d4ac3b034c09431cac9f14c18d77eed6cdb Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 12 Oct 2023 10:42:16 +0200
Subject: [PATCH 27/68] fix attribute error in criterion printing

---
 pyPLNmodels/models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 41cb12a8..04ed41de 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -718,7 +718,7 @@ class _model(ABC):
                 "Maximum number of iterations reached : ",
                 self._criterion_args.iteration_number,
                 "last criterion = ",
-                np.round(self._criterion_args.criterions[-1], 8),
+                np.round(self._criterion_args.criterion_list[-1], 8),
             )
 
     def _print_stats(self):
@@ -727,7 +727,7 @@ class _model(ABC):
         """
         print("-------UPDATE-------")
         print("Iteration number: ", self._criterion_args.iteration_number)
-        print("Criterion: ", np.round(self._criterion_args.criterions[-1], 8))
+        print("Criterion: ", np.round(self._criterion_args.criterion_list[-1], 8))
         print("ELBO:", np.round(self._criterion_args._elbos_list[-1], 6))
 
     def _update_criterion_args(self, loss):
-- 
GitLab


From b9cade3447e93e74841edcccde903ed02e3dae2a Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 13 Oct 2023 09:31:54 +0200
Subject: [PATCH 28/68] add the abstract methods needed to implement a new model.

---
 pyPLNmodels/models.py | 67 ++++++++++++++++++++++++-------------------
 1 file changed, 38 insertions(+), 29 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 04ed41de..e8970fec 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -370,6 +370,7 @@ class _model(ABC):
             parameter.requires_grad_(True)
 
     @property
+    @abstractmethod
     def _list_of_parameters_needing_gradient(self):
         """
         A list containing all the parameters that need to be upgraded via a gradient step.
@@ -380,6 +381,41 @@ class _model(ABC):
             List of parameters needing gradient.
         """
 
+    def _print_beginning_message(self) -> str:
+        """
+        Method for printing the beginning message.
+        """
+        print(f"Fitting a {self._NAME} model with {self._description} \n")
+
+    @abstractmethod
+    def _endog_predictions(self):
+        pass
+
+    @abstractmethod
+    def number_of_parameters(self):
+        pass
+
+    @abstractmethod
+    def _compute_elbo_b(self):
+        pass
+
+    @property
+    @abstractmethod
+    def covariance(self):
+        pass
+
+    @covariance.setter
+    @abstractmethod
+    def covariance(self, covariance):
+        pass
+
+    @property
+    @abstractmethod
+    def _description(self):
+        """
+        Describes the model and what it does.
+        """
+
     def fit(
         self,
         nb_max_iteration: int = 50000,
@@ -1793,12 +1829,6 @@ class Pln(_model):
         covariances = components_var @ (sk_components.T.unsqueeze(0))
         return covariances
 
-    def _print_beginning_message(self):
-        """
-        Method for printing the beginning message.
-        """
-        print(f"Fitting a Pln model with {self._description}")
-
     @property
     @_add_doc(
         _model,
@@ -2216,17 +2246,6 @@ class PlnPCAcollection:
         """
         return [model.rank for model in self.values()]
 
-    def _print_beginning_message(self) -> str:
-        """
-        Method for printing the beginning message.
-
-        Returns
-        -------
-        str
-            The beginning message.
-        """
-        return f"Adjusting {len(self.ranks)} Pln models for PCA analysis \n"
-
     @property
     def dim(self) -> int:
         """
@@ -3045,13 +3064,6 @@ class PlnPCA(_model):
         """
         return self._rank
 
-    def _print_beginning_message(self):
-        """
-        Print the beginning message when fitted.
-        """
-        print("-" * NB_CHARACTERS_FOR_NICE_PLOT)
-        print(f"Fitting a PlnPCA model with {self._rank} components")
-
     @property
     def model_parameters(self) -> Dict[str, torch.Tensor]:
         """
@@ -3235,7 +3247,7 @@ class PlnPCA(_model):
     @property
     def _description(self) -> str:
         """
-        Property representing the description.
+        Description output when fitting and printing the model.
 
         Returns
         -------
@@ -3496,7 +3508,7 @@ class ZIPln(_model):
 
     @property
     def _description(self):
-        return "with full covariance model and zero-inflation."
+        return " full covariance model and zero-inflation."
 
     def _random_init_model_parameters(self):
         super()._random_init_model_parameters()
@@ -3513,9 +3525,6 @@ class ZIPln(_model):
         if not hasattr(self, "_coef_inflation"):
             self._coef_inflation = torch.randn(self.nb_cov, self.dim)
 
-    def _print_beginning_message(self):
-        print("Fitting a ZIPln model.")
-
     def _random_init_latent_parameters(self):
         self._dirac = self._endog == 0
         self._latent_mean = torch.randn(self.n_samples, self.dim)
-- 
GitLab
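
This patch turns the hooks every subclass must provide (`_compute_elbo_b`, `covariance` and its setter, `_description`, and so on) into abstract members of `_model`, so a forgotten override fails at instantiation rather than deep inside `fit()`. A minimal sketch of the abstract-property-with-setter pattern used here (class names are illustrative):

```python
from abc import ABC, abstractmethod


class _Sketch(ABC):
    @property
    @abstractmethod
    def covariance(self):
        """Abstract read access to the covariance."""

    @covariance.setter
    @abstractmethod
    def covariance(self, value):
        """Abstract write access; the property stays abstract until overridden."""


class Concrete(_Sketch):
    def __init__(self):
        self._cov = None

    # Redefining the property (getter and setter) clears the abstract flag,
    # so Concrete() can be instantiated while _Sketch() raises TypeError.
    @property
    def covariance(self):
        return self._cov

    @covariance.setter
    def covariance(self, value):
        self._cov = value


Concrete().covariance  # fine; _Sketch() would raise TypeError
```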


From fe79d40b507cbd069e36e57f9afaf14b06067a76 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 13 Oct 2023 12:42:27 +0200
Subject: [PATCH 29/68] rearrange the abstract methods.

---
 pyPLNmodels/models.py | 543 +++++++++++++++++++++---------------------
 1 file changed, 271 insertions(+), 272 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 41cb12a8..d89c9025 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -324,19 +324,7 @@ class _model(ABC):
             self._coef = None
         self._coef = torch.randn((self.nb_cov, self.dim), device=DEVICE)
 
-    @abstractmethod
-    def _random_init_model_parameters(self):
-        """
-        Abstract method to randomly initialize model parameters.
-        """
-        pass
 
-    @abstractmethod
-    def _random_init_latent_parameters(self):
-        """
-        Abstract method to randomly initialize latent parameters.
-        """
-        pass
 
     def _smart_init_latent_parameters(self):
         """
@@ -369,16 +357,6 @@ class _model(ABC):
         for parameter in self._list_of_parameters_needing_gradient:
             parameter.requires_grad_(True)
 
-    @property
-    def _list_of_parameters_needing_gradient(self):
-        """
-        A list containing all the parameters that need to be upgraded via a gradient step.
-
-        Returns
-        -------
-        List[torch.Tensor]
-            List of parameters needing gradient.
-        """
 
     def fit(
         self,
@@ -579,7 +557,7 @@ class _model(ABC):
         return pca
 
     @property
-    def latent_var(self) -> torch.Tensor:
+    def latent_variance(self) -> torch.Tensor:
         """
         Property representing the latent variance.
 
@@ -689,13 +667,18 @@ class _model(ABC):
         )
         plt.show()
 
+
     @property
-    @abstractmethod
-    def latent_variables(self):
+    def _latent_var(self) -> torch.Tensor:
         """
-        Abstract property representing the latent variables.
+        Property representing the latent variance.
+
+        Returns
+        -------
+        torch.Tensor
+            The latent variance tensor.
         """
-        pass
+        return self._latent_sqrt_var**2
 
     def _print_end_of_fitting_message(self, stop_condition: bool, tol: float):
         """
@@ -754,13 +737,6 @@ class _model(ABC):
         """
         pass
 
-    @abstractmethod
-    def compute_elbo(self):
-        """
-        Compute the Evidence Lower BOund (ELBO) that will be maximized
-        by pytorch.
-        """
-        pass
 
     def display_covariance(self, ax=None, savefig=False, name_file=""):
         """
@@ -1385,8 +1361,74 @@ class _model(ABC):
         ax.legend()
         return ax
 
+    @property
+    @abstractmethod
+    def latent_variables(self) -> torch.Tensor:
+        """
+        Property representing the latent variables.
+
+        Returns
+        -------
+        torch.Tensor
+            The latent variables of size (n_samples, dim).
+        """
+
+    @abstractmethod
+    def compute_elbo(self):
+        """
+        Compute the Evidence Lower BOund (ELBO) that will be maximized
+        by pytorch.
+
+        Returns
+        -------
+        torch.Tensor
+            The computed ELBO.
+        """
+
+    @abstractmethod
+    def _compute_elbo_b(self):
+        """
+        Compute the Evidence Lower BOund (ELBO) for the current mini-batch.
+        Returns
+        -------
+        torch.Tensor
+            The computed ELBO on the current batch.
+        """
+
+    @abstractmethod
+    def _random_init_model_parameters(self):
+        """
+        Abstract method to randomly initialize model parameters.
+        """
+
+    @abstractmethod
+    def _random_init_latent_parameters(self):
+        """
+        Abstract method to randomly initialize latent parameters.
+        """
+    @abstractmethod
+    def _smart_init_latent_parameters(self):
+        """
+        Method for smartly initializing the latent parameters.
+        """
+    @abstractmethod
+    def _smart_init_model_parameters(self):
+        """
+        Method for smartly initializing the model parameters.
+        """
+
+    @property
+    @abstractmethod
+    def _list_of_parameters_needing_gradient(self):
+        """
+        A list containing all the parameters that need to be upgraded via a gradient step.
+
+        Returns
+        -------
+        List[torch.Tensor]
+            List of parameters needing gradient.
+        """
 
-# need to do a good init for M and S
 class Pln(_model):
     """
     Pln class.
@@ -1661,34 +1703,6 @@ class Pln(_model):
             self._offsets + self._latent_mean + 1 / 2 * self._latent_sqrt_var**2
         )
 
-    def _smart_init_latent_parameters(self):
-        """
-        Method for smartly initializing the latent parameters.
-        """
-        self._random_init_latent_parameters()
-
-    def _random_init_latent_parameters(self):
-        """
-        Method for randomly initializing the latent parameters.
-        """
-        if not hasattr(self, "_latent_sqrt_var"):
-            self._latent_sqrt_var = (
-                1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE)
-            )
-        if not hasattr(self, "_latent_mean"):
-            self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE)
-
-    @property
-    def _list_of_parameters_needing_gradient(self):
-        """
-        Property representing the list of parameters needing gradient.
-
-        Returns
-        -------
-        list
-            The list of parameters needing gradient.
-        """
-        return [self._latent_mean, self._latent_sqrt_var]
 
     def _get_max_components(self):
         """
@@ -1701,60 +1715,6 @@ class Pln(_model):
         """
         return self.dim
 
-    def compute_elbo(self):
-        """
-        Method for computing the evidence lower bound (ELBO).
-
-        Returns
-        -------
-        torch.Tensor
-            The computed ELBO.
-        Examples
-        --------
-        >>> from pyPLNmodels import Pln, get_real_count_data
-        >>> endog, labels = get_real_count_data(return_labels = True)
-        >>> pln = Pln(endog,add_const = True)
-        >>> pln.fit()
-        >>> elbo = pln.compute_elbo()
-        >>> print("elbo", elbo)
-        >>> print("loglike/n", pln.loglike/pln.n_samples)
-        """
-        return profiled_elbo_pln(
-            self._endog,
-            self._exog,
-            self._offsets,
-            self._latent_mean,
-            self._latent_sqrt_var,
-        )
-
-    def _compute_elbo_b(self):
-        """
-        Method for computing the evidence lower bound (ELBO) on the current batch.
-
-        Returns
-        -------
-        torch.Tensor
-            The computed ELBO on the current batch.
-        """
-        return profiled_elbo_pln(
-            self._endog_b,
-            self._exog_b,
-            self._offsets_b,
-            self._latent_mean_b,
-            self._latent_sqrt_var_b,
-        )
-
-    def _smart_init_model_parameters(self):
-        """
-        Method for smartly initializing the model parameters.
-        """
-        # no model parameters since we are doing a profiled ELBO
-
-    def _random_init_model_parameters(self):
-        """
-        Method for randomly initializing the model parameters.
-        """
-        # no model parameters since we are doing a profiled ELBO
 
     @property
     def _coef(self):
@@ -1799,19 +1759,6 @@ class Pln(_model):
         """
         print(f"Fitting a Pln model with {self._description}")
 
-    @property
-    @_add_doc(
-        _model,
-        example="""
-        >>> from pyPLNmodels import Pln, get_real_count_data
-        >>> endog, labels = get_real_count_data(return_labels = True)
-        >>> pln = Pln(endog,add_const = True)
-        >>> pln.fit()
-        >>> print(pln.latent_variables.shape)
-        """,
-    )
-    def latent_variables(self):
-        return self.latent_mean.detach()
 
     @property
     def number_of_parameters(self):
@@ -1861,6 +1808,80 @@ class Pln(_model):
         """
         raise AttributeError("You can not set the covariance for the Pln model.")
 
+    def _random_init_latent_sqrt_var(self):
+        if not hasattr(self, "_latent_sqrt_var"):
+            self._latent_sqrt_var = (
+                1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE)
+            )
+
+    @property
+    @_add_doc(
+        _model,
+        example="""
+        >>> from pyPLNmodels import Pln, get_real_count_data
+        >>> endog, labels = get_real_count_data(return_labels = True)
+        >>> pln = Pln(endog,add_const = True)
+        >>> pln.fit()
+        >>> print(pln.latent_variables.shape)
+        """,
+    )
+    def latent_variables(self):
+        return self.latent_mean.detach()
+
+    @_add_doc(
+            _model,
+            example="""
+            >>> from pyPLNmodels import Pln, get_real_count_data
+            >>> endog, labels = get_real_count_data(return_labels = True)
+            >>> pln = Pln(endog,add_const = True)
+            >>> pln.fit()
+            >>> elbo = pln.compute_elbo()
+            >>> print("elbo", elbo)
+            >>> print("loglike/n", pln.loglike/pln.n_samples)
+            """
+            )
+    def compute_elbo(self):
+        return profiled_elbo_pln(
+            self._endog,
+            self._exog,
+            self._offsets,
+            self._latent_mean,
+            self._latent_sqrt_var,
+        )
+    @_add_doc(_model)
+    def _compute_elbo_b(self):
+        return profiled_elbo_pln(
+            self._endog_b,
+            self._exog_b,
+            self._offsets_b,
+            self._latent_mean_b,
+            self._latent_sqrt_var_b,
+        )
+    @_add_doc(_model)
+    def _smart_init_model_parameters(self):
+        pass
+        # no model parameters since we are doing a profiled ELBO
+
+    @_add_doc(_model)
+    def _random_init_model_parameters(self):
+        pass
+        # no model parameters since we are doing a profiled ELBO
+    @_add_doc(_model)
+    def _smart_init_latent_parameters(self):
+        self._random_init_latent_sqrt_var()
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = torch.log(self._endog + (self._endog == 0))
+
+    @_add_doc(_model)
+    def _random_init_latent_parameters(self):
+        self._random_init_latent_sqrt_var()
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE)
+
+    @_add_doc(_model)
+    @property
+    def _list_of_parameters_needing_gradient(self):
+        return [self._latent_mean, self._latent_sqrt_var]
 
 class PlnPCAcollection:
     """
@@ -2655,7 +2676,7 @@ class PlnPCAcollection:
         return ".BIC, .AIC, .loglikes"
 
 
-# Here, setting the value for each key in _dict_parameters
+# Here, setting the value for each key  _dict_parameters
 class PlnPCA(_model):
     """
     PlnPCA object where the covariance has low rank.
@@ -2881,19 +2902,6 @@ class PlnPCA(_model):
             variables_names=variables_names, indices_of_variables=indices_of_variables
         )
 
-    def _check_if_rank_is_too_high(self):
-        """
-        Check if the rank is too high and issue a warning if necessary.
-        """
-        if self.dim < self.rank:
-            warning_string = (
-                f"\nThe requested rank of approximation {self.rank} "
-                f"is greater than the number of variables {self.dim}. "
-                f"Setting rank to {self.dim}"
-            )
-            warnings.warn(warning_string)
-            self._rank = self.dim
-
     @property
     @_add_doc(
         _model,
@@ -2909,29 +2917,7 @@ class PlnPCA(_model):
     def latent_mean(self) -> torch.Tensor:
         return self._cpu_attribute_or_none("_latent_mean")
 
-    @property
-    def latent_sqrt_var(self) -> torch.Tensor:
-        """
-        Property representing the unsigned square root of the latent variance.
-
-        Returns
-        -------
-        torch.Tensor
-            The latent variance tensor.
-        """
-        return self._cpu_attribute_or_none("_latent_sqrt_var")
-
-    @property
-    def _latent_var(self) -> torch.Tensor:
-        """
-        Property representing the latent variance.
 
-        Returns
-        -------
-        torch.Tensor
-            The latent variance tensor.
-        """
-        return self._latent_sqrt_var**2
 
     def _endog_predictions(self):
         covariance_a_posteriori = torch.sum(
@@ -3064,103 +3050,9 @@ class PlnPCA(_model):
         """
         return {"coef": self.coef, "components": self.components}
 
-    def _smart_init_model_parameters(self):
-        """
-        Initialize the model parameters smartly.
-        """
-        if not hasattr(self, "_coef"):
-            super()._smart_init_coef()
-        if not hasattr(self, "_components"):
-            self._components = _init_components(self._endog, self._exog, self._rank)
 
-    def _random_init_model_parameters(self):
-        """
-        Randomly initialize the model parameters.
-        """
-        super()._random_init_coef()
-        self._components = torch.randn((self.dim, self._rank)).to(DEVICE)
 
-    def _random_init_latent_parameters(self):
-        """
-        Randomly initialize the latent parameters.
-        """
-        self._latent_sqrt_var = (
-            1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
-        )
-        self._latent_mean = torch.ones((self.n_samples, self._rank)).to(DEVICE)
 
-    def _smart_init_latent_parameters(self):
-        """
-        Initialize the latent parameters smartly.
-        """
-        if not hasattr(self, "_latent_mean"):
-            self._latent_mean = (
-                _init_latent_mean(
-                    self._endog,
-                    self._exog,
-                    self._offsets,
-                    self._coef,
-                    self._components,
-                )
-                .to(DEVICE)
-                .detach()
-            )
-        if not hasattr(self, "_latent_sqrt_var"):
-            self._latent_sqrt_var = (
-                1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
-            )
-
-    @property
-    def _list_of_parameters_needing_gradient(self):
-        """
-        Property representing the list of parameters needing gradient.
-
-        Returns
-        -------
-        List[torch.Tensor]
-            The list of parameters needing gradient.
-        """
-        if self._coef is None:
-            return [self._components, self._latent_mean, self._latent_sqrt_var]
-        return [self._components, self._coef, self._latent_mean, self._latent_sqrt_var]
-
-    def _compute_elbo_b(self) -> torch.Tensor:
-        """
-        Compute the evidence lower bound (ELBO) with the current batch.
-
-        Returns
-        -------
-        torch.Tensor
-            The ELBO value on the current batch.
-        """
-        return elbo_plnpca(
-            self._endog_b,
-            self._exog_b,
-            self._offsets_b,
-            self._latent_mean_b,
-            self._latent_sqrt_var_b,
-            self._components,
-            self._coef,
-        )
-
-    def compute_elbo(self) -> torch.Tensor:
-        """
-        Compute the evidence lower bound (ELBO).
-
-        Returns
-        -------
-        torch.Tensor
-            The ELBO value.
-        """
-        return elbo_plnpca(
-            self._endog,
-            self._exog,
-            self._offsets,
-            self._latent_mean,
-            self._latent_sqrt_var,
-            self._components,
-            self._coef,
-        )
 
     @property
     def number_of_parameters(self) -> int:
@@ -3244,17 +3136,6 @@ class PlnPCA(_model):
         """
         return f" {self.rank} principal component."
 
-    @property
-    def latent_variables(self) -> torch.Tensor:
-        """
-        Property representing the latent variables.
-
-        Returns
-        -------
-        torch.Tensor
-            The latent variables of size (n_samples, dim).
-        """
-        return torch.matmul(self._latent_mean, self._components.T).detach()
 
     @property
     def projected_latent_variables(self) -> torch.Tensor:
@@ -3337,6 +3218,100 @@ class PlnPCA(_model):
             return self.projected_latent_variables
         return self.latent_variables
 
+    @property
+    @_add_doc(
+        _model,
+        example="""
+        >>> from pyPLNmodels import PlnPCA, get_real_count_data
+        >>> endog = get_real_count_data(return_labels=False)
+        >>> pca = PlnPCA(endog,add_const = True)
+        >>> pca.fit()
+        >>> print(pca.latent_variables.shape)
+        """,
+    )
+    def latent_variables(self) -> torch.Tensor:
+        return torch.matmul(self._latent_mean, self._components.T).detach()
+
+    @_add_doc(
+            _model,
+            example="""
+            >>> from pyPLNmodels import PlnPCA, get_real_count_data
+            >>> endog = get_real_count_data(return_labels = False)
+            >>> pca = PlnPCA(endog,add_const = True)
+            >>> pca.fit()
+            >>> elbo = pca.compute_elbo()
+            >>> print("elbo", elbo)
+            >>> print("loglike/n", pca.loglike/pca.n_samples)
+            """
+            )
+    def compute_elbo(self) -> torch.Tensor:
+        return elbo_plnpca(
+            self._endog,
+            self._exog,
+            self._offsets,
+            self._latent_mean,
+            self._latent_sqrt_var,
+            self._components,
+            self._coef,
+        )
+    @_add_doc(_model)
+    def _compute_elbo_b(self) -> torch.Tensor:
+        return elbo_plnpca(
+            self._endog_b,
+            self._exog_b,
+            self._offsets_b,
+            self._latent_mean_b,
+            self._latent_sqrt_var_b,
+            self._components,
+            self._coef,
+        )
+    @_add_doc(_model)
+    def _random_init_model_parameters(self):
+        super()._random_init_coef()
+        self._components = torch.randn((self.dim, self._rank)).to(DEVICE)
+
+    @_add_doc(_model)
+    def _smart_init_model_parameters(self):
+        if not hasattr(self, "_coef"):
+            super()._smart_init_coef()
+        if not hasattr(self, "_components"):
+            self._components = _init_components(self._endog, self._exog, self._rank)
+
+    @_add_doc(_model)
+    def _random_init_latent_parameters(self):
+        """
+        Randomly initialize the latent parameters.
+        """
+        self._latent_sqrt_var = (
+            1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
+        )
+        self._latent_mean = torch.ones((self.n_samples, self._rank)).to(DEVICE)
+
+    @_add_doc(_model)
+    def _smart_init_latent_parameters(self):
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = (
+                _init_latent_mean(
+                    self._endog,
+                    self._exog,
+                    self._offsets,
+                    self._coef,
+                    self._components,
+                )
+                .to(DEVICE)
+                .detach()
+            )
+        if not hasattr(self, "_latent_sqrt_var"):
+            self._latent_sqrt_var = (
+                1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
+            )
+
+    @_add_doc(_model)
+    @property
+    def _list_of_parameters_needing_gradient(self):
+        if self._coef is None:
+            return [self._components, self._latent_mean, self._latent_sqrt_var]
+        return [self._components, self._coef, self._latent_mean, self._latent_sqrt_var]
 
 class ZIPln(_model):
     _NAME = "ZIPln"
@@ -3347,6 +3322,10 @@ class ZIPln(_model):
 
     @_add_doc(
         _model,
+        params= """
+        use_closed_form_prob: bool, optional
+            Whether or not use the closed formula for the latent probability
+        """
         example="""
             >>> from pyPLNmodels import ZIPln, get_real_count_data
             >>> endog= get_real_count_data()
@@ -3532,7 +3511,26 @@ class ZIPln(_model):
     def _covariance(self):
         return self._components @ (self._components.T)
 
-    def latent_variables(self):
+    def latent_variables(self) -> tuple(torch.Tensor, torch.Tensor):
+        """
+        Property representing the latent variables. Two latent
+        variables are available if exog is not None
+
+        Returns
+        -------
+        tuple(torch.Tensor, torch.Tensor)
+            The latent variables of a classic Pln model (size (n_samples, dim))
+            and zero inflated latent variables of size (n_samples, dim).
+        Examples
+        --------
+        >>> from pyPLNmodels import ZIPln, get_real_count_data
+        >>> endog, labels = get_real_count_data(return_labels = True)
+        >>> zi = ZIPln(endog,add_const = True)
+        >>> zi.fit()
+        >>> latent_mean, latent_inflated = zi.latent_variables
+        >>> print(latent_mean.shape)
+        >>> print(latent_inflated.shape)
+        """
         return self.latent_mean, self.latent_prob
 
     def _update_parameters(self):
@@ -3624,6 +3622,7 @@ class ZIPln(_model):
     def number_of_parameters(self):
         return self.dim * (2 * self.nb_cov + (self.dim + 1) / 2)
 
+    @_add_doc(_model)
     @property
     def _list_of_parameters_needing_gradient(self):
         list_parameters = [
-- 
GitLab
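
Most of the boilerplate docstrings removed here are re-attached through the `_add_doc` decorator, which copies the parent class's docstring onto the overriding attribute and optionally appends extra sections. Its real implementation lives in `pyPLNmodels/_utils.py` and is not shown in this patch series; the following is only a hypothetical sketch consistent with how it is called above (a positional parent class plus optional `params` and `example` keyword arguments):

```python
def _add_doc(parent, *, params=None, example=None):
    """Hypothetical reconstruction, for illustration only: inherit the
    docstring of the parent attribute with the same name and append the
    optional extra sections."""

    def wrapper(fun):
        doc = getattr(parent, fun.__name__).__doc__ or ""
        if params is not None:
            doc += "\n\nParameters\n----------" + params
        if example is not None:
            doc += "\n\nExamples\n--------" + example
        fun.__doc__ = doc
        return fun

    return wrapper
```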


From e3bbe240e421d5f6d8bac8a0568154894e29091d Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 13 Oct 2023 12:44:34 +0200
Subject: [PATCH 30/68] beginning of a ZIPln example script (new_model.py)

---
 pyPLNmodels/new_model.py | 9 +++++++++
 1 file changed, 9 insertions(+)
 create mode 100644 pyPLNmodels/new_model.py

diff --git a/pyPLNmodels/new_model.py b/pyPLNmodels/new_model.py
new file mode 100644
index 00000000..2d4acd45
--- /dev/null
+++ b/pyPLNmodels/new_model.py
@@ -0,0 +1,9 @@
+from pyPLNmodels import ZIPln, get_real_count_data
+
+
+endog = get_real_count_data()
+zi = ZIPln(endog, add_const = True)
+zi.fit(nb_max_iteration = 10)
+zi.show()
+
+
-- 
GitLab


From 9ff71980a0638c399a94f60242b685999cc782db Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 13 Oct 2023 13:21:08 +0200
Subject: [PATCH 31/68] fixed compilation errors.

---
 pyPLNmodels/models.py | 61 ++++++++++++++++++++++---------------------
 1 file changed, 31 insertions(+), 30 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index d89c9025..905b465b 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -1002,30 +1002,6 @@ class _model(ABC):
             )
         self._latent_mean = latent_mean
 
-    @latent_sqrt_var.setter
-    @_array2tensor
-    def latent_sqrt_var(
-        self, latent_sqrt_var: Union[torch.Tensor, np.ndarray, pd.DataFrame]
-    ):
-        """
-        Setter for the latent variance property.
-
-        Parameters
-        ----------
-        latent_sqrt_var : Union[torch.Tensor, np.ndarray, pd.DataFrame]
-            The latent variance.
-
-        Raises
-        ------
-        ValueError
-            If the shape of the latent variance is incorrect.
-        """
-        if latent_sqrt_var.shape != (self.n_samples, self.dim):
-            raise ValueError(
-                f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_sqrt_var.shape}"
-            )
-        self._latent_sqrt_var = latent_sqrt_var
-
     def _cpu_attribute_or_none(self, attribute_name):
         """
         Get the CPU attribute or return None.
@@ -1760,6 +1736,31 @@ class Pln(_model):
         print(f"Fitting a Pln model with {self._description}")
 
 
+
+    @_model.latent_sqrt_var.setter
+    @_array2tensor
+    def latent_sqrt_var(
+        self, latent_sqrt_var: Union[torch.Tensor, np.ndarray, pd.DataFrame]
+    ):
+        """
+        Setter for the latent variance property.
+
+        Parameters
+        ----------
+        latent_sqrt_var : Union[torch.Tensor, np.ndarray, pd.DataFrame]
+            The latent variance.
+
+        Raises
+        ------
+        ValueError
+            If the shape of the latent variance is incorrect.
+        """
+        if latent_sqrt_var.shape != (self.n_samples, self.dim):
+            raise ValueError(
+                f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_sqrt_var.shape}"
+            )
+        self._latent_sqrt_var = latent_sqrt_var
+
     @property
     def number_of_parameters(self):
         """
@@ -1878,8 +1879,8 @@ class Pln(_model):
         if not hasattr(self, "_latent_mean"):
             self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE)
 
-    @_add_doc(_model)
     @property
+    @_add_doc(_model)
     def _list_of_parameters_needing_gradient(self):
         return [self._latent_mean, self._latent_sqrt_var]
 
@@ -2950,7 +2951,7 @@ class PlnPCA(_model):
             )
         self._latent_mean = latent_mean
 
-    @latent_sqrt_var.setter
+    @_model.latent_sqrt_var.setter
     @_array2tensor
     def latent_sqrt_var(self, latent_sqrt_var: torch.Tensor):
         """
@@ -3306,8 +3307,8 @@ class PlnPCA(_model):
                 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
             )
 
-    @_add_doc(_model)
     @property
+    @_add_doc(_model)
     def _list_of_parameters_needing_gradient(self):
         if self._coef is None:
             return [self._components, self._latent_mean, self._latent_sqrt_var]
@@ -3325,7 +3326,7 @@ class ZIPln(_model):
         params= """
         use_closed_form_prob: bool, optional
             Whether or not use the closed formula for the latent probability
-        """
+        """,
         example="""
             >>> from pyPLNmodels import ZIPln, get_real_count_data
             >>> endog= get_real_count_data()
@@ -3511,7 +3512,7 @@ class ZIPln(_model):
     def _covariance(self):
         return self._components @ (self._components.T)
 
-    def latent_variables(self) -> tuple(torch.Tensor, torch.Tensor):
+    def latent_variables(self) -> tuple([torch.Tensor, torch.Tensor]):
         """
         Property representing the latent variables. Two latent
         variables are available if exog is not None
@@ -3622,8 +3623,8 @@ class ZIPln(_model):
     def number_of_parameters(self):
         return self.dim * (2 * self.nb_cov + (self.dim + 1) / 2)
 
-    @_add_doc(_model)
     @property
+    @_add_doc(_model)
     def _list_of_parameters_needing_gradient(self):
         list_parameters = [
             self._latent_mean,
-- 
GitLab
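
Two Python subtleties drive these fixes: `@property` must be the outermost decorator so that the attribute that finally lands on the class is the property object, and a subclass cannot write `@latent_sqrt_var.setter` directly because that name is only bound on the parent class, hence the explicit `@_model.latent_sqrt_var.setter`. A minimal sketch of the setter-override idiom (class names illustrative):

```python
class Parent:
    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, value):
        self._value = value


class Child(Parent):
    # `@value.setter` would raise NameError here: `value` is not bound in
    # Child's class body. Naming the parent's property explicitly, as
    # `@_model.latent_sqrt_var.setter` does above, builds a new property
    # that keeps Parent's getter and swaps in Child's setter.
    @Parent.value.setter
    def value(self, value):
        self._value = abs(value)


child = Child()
child.value = -3
print(child.value)  # 3
```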


From 59c117760c96b5fd604b7ab4378b73f8150b282a Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 13 Oct 2023 13:23:15 +0200
Subject: [PATCH 32/68] remove duplicated parameters from the
 _list_of_parameters_needing_gradient of the ZI

---
 pyPLNmodels/models.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 905b465b..47231871 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -3629,9 +3629,7 @@ class ZIPln(_model):
         list_parameters = [
             self._latent_mean,
             self._latent_sqrt_var,
-            self._coef_inflation,
             self._components,
-            self._coef,
         ]
         if self._use_closed_form_prob:
             list_parameters.append(self._latent_prob)
-- 
GitLab


From 7b1762af75eb0eec036f10a3e599aabf322f1e90 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 13 Oct 2023 13:26:54 +0200
Subject: [PATCH 33/68] fix error when returning batches of the ZI

---
 pyPLNmodels/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 47231871..f5d356ae 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -3372,7 +3372,7 @@ class ZIPln(_model):
     def _return_batch(self, indices, beginning, end):
         pln_batch = super()._return_batch(indices, beginning, end)
         if self._use_closed_form_prob is False:
-            return pln_batch + torch.index_select(self._latent_prob, 0, to_take)
+            return (pln_batch + torch.index_select(self._latent_prob, 0, to_take))
         return pln_batch
 
     def _return_batch(self, indices, beginning, end):
-- 
GitLab


From 9a76ec73155489697309310199b9890509d27b44 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 13 Oct 2023 13:42:49 +0200
Subject: [PATCH 34/68] zi can be fitted now.

---
 pyPLNmodels/_utils.py |  2 +-
 pyPLNmodels/models.py | 38 ++++++++++++++------------------------
 2 files changed, 15 insertions(+), 25 deletions(-)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 053a6448..f5f02942 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -112,7 +112,7 @@ class _CriterionArgs:
         ax = plt.gca() if ax is None else ax
         ax.plot(
             self.running_times,
-            self.criterions,
+            self.criterion_list,
             label="Delta",
         )
         ax.set_yscale("log")
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index f5d356ae..a5ece2aa 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -701,7 +701,7 @@ class _model(ABC):
                 "Maximum number of iterations reached : ",
                 self._criterion_args.iteration_number,
                 "last criterion = ",
-                np.round(self._criterion_args.criterions[-1], 8),
+                np.round(self._criterion_args.criterion_list[-1], 8),
             )
 
     def _print_stats(self):
@@ -710,7 +710,7 @@ class _model(ABC):
         """
         print("-------UPDATE-------")
         print("Iteration number: ", self._criterion_args.iteration_number)
-        print("Criterion: ", np.round(self._criterion_args.criterions[-1], 8))
+        print("Criterion: ", np.round(self._criterion_args.criterion_list[-1], 8))
         print("ELBO:", np.round(self._criterion_args._elbos_list[-1], 6))
 
     def _update_criterion_args(self, loss):
@@ -3366,28 +3366,18 @@ class ZIPln(_model):
 
     def _extract_batch(self, batch):
         super()._extract_batch(batch)
+        self._dirac_b = batch[5]
         if self._use_closed_form_prob is False:
-            self._latent_prob_b = batch[5]
+            self._latent_prob_b = batch[6]
 
     def _return_batch(self, indices, beginning, end):
         pln_batch = super()._return_batch(indices, beginning, end)
+        to_take = torch.tensor(indices[beginning:end]).to(DEVICE)
+        batch = pln_batch + (torch.index_select(self._dirac, 0, to_take),)
         if self._use_closed_form_prob is False:
-            return (pln_batch + torch.index_select(self._latent_prob, 0, to_take))
-        return pln_batch
+            return batch + (torch.index_select(self._latent_prob, 0, to_take),)
+        return batch
 
-    def _return_batch(self, indices, beginning, end):
-        to_take = torch.tensor(indices[beginning:end]).to(DEVICE)
-        if self._exog is not None:
-            exog_b = torch.index_select(self._exog, 0, to_take)
-        else:
-            exog_b = None
-        return (
-            torch.index_select(self._endog, 0, to_take),
-            exog_b,
-            torch.index_select(self._offsets, 0, to_take),
-            torch.index_select(self._latent_mean, 0, to_take),
-            torch.index_select(self._latent_sqrt_var, 0, to_take),
-        )
 
     @classmethod
     @_add_doc(
@@ -3542,15 +3532,15 @@ class ZIPln(_model):
         """
         Project the latent probability since it must be between 0 and 1.
         """
-        if self.use_closed_form_prob is False:
+        if self._use_closed_form_prob is False:
             with torch.no_grad():
-                self._latent_prob = torch.maximum(
-                    self._latent_prob, torch.tensor([0]), out=self._latent_prob
+                self._latent_prob_b = torch.maximum(
+                    self._latent_prob_b, torch.tensor([0]), out=self._latent_prob_b
                 )
-                self._latent_prob = torch.minimum(
-                    self._latent_prob, torch.tensor([1]), out=self._latent_prob
+                self._latent_prob_b = torch.minimum(
+                    self._latent_prob, torch.tensor([1]), out=self._latent_prob_b
                 )
-                self._latent_prob *= self._dirac
+                self._latent_prob_b *= self._dirac_b
 
     @property
     def covariance(self) -> torch.Tensor:
-- 
GitLab
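
After each gradient step, the batch of latent probabilities is projected back onto its feasible set: values are clipped into [0, 1] and forced to zero wherever the observed count is nonzero (`_dirac` is the boolean mask `endog == 0`). Note that the `torch.minimum` call above still reads `self._latent_prob` where `self._latent_prob_b` is presumably intended. A minimal standalone sketch of the same projection, written here with `clamp_` (the helper name is illustrative):

```python
import torch


def project_latent_prob_(latent_prob_b, dirac_b):
    """Illustrative helper: clamp the batch of latent probabilities into
    [0, 1] in place and zero them wherever the observed count is nonzero
    (dirac_b is the boolean mask endog == 0)."""
    with torch.no_grad():
        latent_prob_b.clamp_(min=0.0, max=1.0)
        latent_prob_b *= dirac_b
    return latent_prob_b


prob = torch.tensor([[-0.2, 0.5, 1.3]])
mask = torch.tensor([[True, True, False]])
print(project_latent_prob_(prob, mask))  # tensor([[0.0000, 0.5000, 0.0000]])
```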


From 3ad5040306d9c259a835227289a5184cd4526e49 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 13 Oct 2023 18:24:18 +0200
Subject: [PATCH 35/68] add contributing, readme and model from new_model
 branch.

---
 CONTRIBUTING.md       |  71 +++-
 README.md             |  37 ++-
 pyPLNmodels/models.py | 734 ++++++++++++++++++++----------------------
 3 files changed, 438 insertions(+), 404 deletions(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index faf13f3b..e1fb39dc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,15 +1,72 @@
-# Clone the repo
+# What to work on
 
+A public roadmap will be available soon.
+
+
+## Fork/clone/pull
+
+The typical workflow for contributing is:
+
+1. Fork the `main` branch from the [GitLab repository](https://forgemia.inra.fr/bbatardiere/pyplnmodels).
+2. Clone your fork locally.
+3. Run `pip install pre-commit` if pre-commit is not already installed.
+4. Inside the repository, run `pre-commit install`.
+5. Commit changes.
+6. Push the changes to your fork.
+7. Send a pull request from your fork back to the original `main` branch.
+
+## How to implement a new model
+You can implement a new model `newmodel` by inheriting from the abstract `_model` class in the `models` module.
+The `newmodel` class should contain at least the following code:
 ```
-git clone git@forgemia.inra.fr:bbatardiere/pyplnmodels
-```
+class newmodel(_model):
+    _NAME=""
+    def _random_init_latent_sqrt_var(self):
+        "Implement here"
+
+    @property
+    def latent_variables(self):
+        "Implement here"
 
-# Install precommit
+    def compute_elbo(self):
+        "Implement here"
 
-In the directory:
+    def _compute_elbo_b(self):
+        "Implement here"
 
+    def _smart_init_model_parameters(self):
+        "Implement here"
+
+    def _random_init_model_parameters(self):
+        "Implement here"
+
+    def _smart_init_latent_parameters(self):
+        "Implement here"
+
+    def _random_init_latent_parameters(self):
+        "Implement here"
+
+    @property
+    def _list_of_parameters_needing_gradient(self):
+        "Implement here"
+    @property
+    def _description(self):
+        "Implement here"
+
+    @property
+    def number_of_parameters(self):
+        "Implement here"
 ```
-pre-commit install
+Then, add `newmodel` in the `__init__.py` file of the pyPLNmodels module.
+If `newmodel` is well implemented, running
 ```
+from pyPLNmodels import newmodel, get_real_count_data
 
-If not found use `pip install pre-commit` before this command.
+endog = get_real_count_data()
+model = newmodel(endog, add_const = True)
+model.fit(nb_max_iteration = 10, tol = 0)
+```
+should increase the ELBO of the model. You should document your functions with
+[numpy-style
+docstrings](https://numpydoc.readthedocs.io/en/latest/format.html). You can use
+the `_add_doc` decorator to inherit the docstrings of the `_model` class.
diff --git a/README.md b/README.md
index 9401cfe6..f8adaa3f 100644
--- a/README.md
+++ b/README.md
@@ -16,22 +16,10 @@
 <!-- > slides](https://pln-team.github.io/slideshow/) for a -->
 <!-- > comprehensive introduction. -->
 
-## Getting started
-The getting started can be found [here](https://forgemia.inra.fr/bbatardiere/pyplnmodels/-/raw/dev/Getting_started.ipynb?inline=false). If you need just a quick view of the package, see next.
+##  Getting started
+The getting started notebook can be found [here](https://forgemia.inra.fr/bbatardiere/pyplnmodels/-/raw/dev/Getting_started.ipynb?inline=false). If you just need a quick view of the package, see the quickstart below.
 
-## Installation
-
-**pyPLNmodels** is available on
-[pypi](https://pypi.org/project/pyPLNmodels/). The development
-version is available on [GitHub](https://github.com/PLN-team/pyPLNmodels).
-
-### Package installation
-
-```
-pip install pyPLNmodels
-```
-
-## Usage and main fitting functions
+## ⚡️ Quickstart
 
 The package comes with an ecological data set to present the functionality
 ```
@@ -61,7 +49,24 @@ transformed_data = pln.transform()
 ```
 
 
-## References
+## 🛠 Installation
+
+**pyPLNmodels** is available on
+[pypi](https://pypi.org/project/pyPLNmodels/). The development
+version is available on [GitHub](https://github.com/PLN-team/pyPLNmodels).
+
+### Package installation
+
+```
+pip install pyPLNmodels
+```
+
+## 👐 Contributing
+
+Feel free to contribute, but read the [CONTRIBUTING.md](https://forgemia.inra.fr/bbatardiere/pyplnmodels/-/blob/main/CONTRIBUTING.md) first. A public roadmap will be available soon.
+
+
+## ⚡️ Citations
 
 Please cite our work using the following references:
 -   J. Chiquet, M. Mariadassou and S. Robin: Variational inference for
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index e8970fec..e8b316a5 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -324,20 +324,6 @@ class _model(ABC):
             self._coef = None
         self._coef = torch.randn((self.nb_cov, self.dim), device=DEVICE)
 
-    @abstractmethod
-    def _random_init_model_parameters(self):
-        """
-        Abstract method to randomly initialize model parameters.
-        """
-        pass
-
-    @abstractmethod
-    def _random_init_latent_parameters(self):
-        """
-        Abstract method to randomly initialize latent parameters.
-        """
-        pass
-
     def _smart_init_latent_parameters(self):
         """
         Initialize latent parameters smartly.
@@ -369,53 +355,6 @@ class _model(ABC):
         for parameter in self._list_of_parameters_needing_gradient:
             parameter.requires_grad_(True)
 
-    @property
-    @abstractmethod
-    def _list_of_parameters_needing_gradient(self):
-        """
-        A list containing all the parameters that need to be upgraded via a gradient step.
-
-        Returns
-        -------
-        List[torch.Tensor]
-            List of parameters needing gradient.
-        """
-
-    def _print_beginning_message(self) -> str:
-        """
-        Method for printing the beginning message.
-        """
-        print(f"Fitting a {self._NAME} model with {self._description} \n")
-
-    @abstractmethod
-    def _endog_predictions(self):
-        pass
-
-    @abstractmethod
-    def number_of_parameters(self):
-        pass
-
-    @abstractmethod
-    def _compute_elbo_b(self):
-        pass
-
-    @property
-    @abstractmethod
-    def covariance(self):
-        pass
-
-    @covariance.setter
-    @abstractmethod
-    def covariance(self, covariance):
-        pass
-
-    @property
-    @abstractmethod
-    def _description(self):
-        """
-        Describes the model and what it does.
-        """
-
     def fit(
         self,
         nb_max_iteration: int = 50000,
@@ -615,7 +554,7 @@ class _model(ABC):
         return pca
 
     @property
-    def latent_var(self) -> torch.Tensor:
+    def latent_variance(self) -> torch.Tensor:
         """
         Property representing the latent variance.
 
@@ -726,12 +665,16 @@ class _model(ABC):
         plt.show()
 
     @property
-    @abstractmethod
-    def latent_variables(self):
+    def _latent_var(self) -> torch.Tensor:
         """
-        Abstract property representing the latent variables.
+        Property representing the latent variance.
+
+        Returns
+        -------
+        torch.Tensor
+            The latent variance tensor.
         """
-        pass
+        return self._latent_sqrt_var**2
 
     def _print_end_of_fitting_message(self, stop_condition: bool, tol: float):
         """
@@ -790,14 +733,6 @@ class _model(ABC):
         """
         pass
 
-    @abstractmethod
-    def compute_elbo(self):
-        """
-        Compute the Evidence Lower BOund (ELBO) that will be maximized
-        by pytorch.
-        """
-        pass
-
     def display_covariance(self, ax=None, savefig=False, name_file=""):
         """
         Display the covariance matrix.
@@ -1062,30 +997,6 @@ class _model(ABC):
             )
         self._latent_mean = latent_mean
 
-    @latent_sqrt_var.setter
-    @_array2tensor
-    def latent_sqrt_var(
-        self, latent_sqrt_var: Union[torch.Tensor, np.ndarray, pd.DataFrame]
-    ):
-        """
-        Setter for the latent variance property.
-
-        Parameters
-        ----------
-        latent_sqrt_var : Union[torch.Tensor, np.ndarray, pd.DataFrame]
-            The latent variance.
-
-        Raises
-        ------
-        ValueError
-            If the shape of the latent variance is incorrect.
-        """
-        if latent_sqrt_var.shape != (self.n_samples, self.dim):
-            raise ValueError(
-                f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_sqrt_var.shape}"
-            )
-        self._latent_sqrt_var = latent_sqrt_var
-
     def _cpu_attribute_or_none(self, attribute_name):
         """
         Get the CPU attribute or return None.
@@ -1421,8 +1332,95 @@ class _model(ABC):
         ax.legend()
         return ax
 
+    def _print_beginning_message(self):
+        """
+        Method for printing the beginning message.
+        """
+        print(f"Fitting a {self._NAME} model with {self._description}")
+
+    @property
+    @abstractmethod
+    def latent_variables(self) -> torch.Tensor:
+        """
+        Property representing the latent variables.
+
+        Returns
+        -------
+        torch.Tensor
+            The latent variables of size (n_samples, dim).
+        """
+
+    @abstractmethod
+    def compute_elbo(self):
+        """
+        Compute the Evidence Lower BOund (ELBO) that will be maximized
+        by pytorch.
+
+        Returns
+        -------
+        torch.Tensor
+            The computed ELBO.
+        """
+
+    @abstractmethod
+    def _compute_elbo_b(self):
+        """
+        Compute the Evidence Lower BOund (ELBO) for the current mini-batch.
+        Returns
+        -------
+        torch.Tensor
+            The computed ELBO on the current batch.
+        """
+
+    @abstractmethod
+    def _random_init_model_parameters(self):
+        """
+        Abstract method to randomly initialize model parameters.
+        """
+
+    @abstractmethod
+    def _random_init_latent_parameters(self):
+        """
+        Abstract method to randomly initialize latent parameters.
+        """
+
+    @abstractmethod
+    def _smart_init_latent_parameters(self):
+        """
+        Method for smartly initializing the latent parameters.
+        """
+
+    @abstractmethod
+    def _smart_init_model_parameters(self):
+        """
+        Method for smartly initializing the model parameters.
+        """
+
+    @property
+    @abstractmethod
+    def _list_of_parameters_needing_gradient(self):
+        """
+        A list containing all the parameters that need to be upgraded via a gradient step.
+
+        Returns
+        -------
+        List[torch.Tensor]
+            List of parameters needing gradient.
+        """
+
+    @property
+    @abstractmethod
+    def _description(self):
+        """
+        Describes the model and what it does.
+        """
+
+    @property
+    @abstractmethod
+    def number_of_parameters(self):
+        """
+        Number of parameters of the model.
+        """
+
 
-# need to do a good init for M and S
 class Pln(_model):
     """
     Pln class.
@@ -1511,15 +1509,13 @@ class Pln(_model):
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
     ):
-        endog, exog, offsets = _extract_data_from_formula(formula, data)
-        return cls(
-            endog,
-            exog=exog,
-            offsets=offsets,
+        super().from_formula(
+            cls=cls,
+            formula=formula,
+            data=data,
             offsets_formula=offsets_formula,
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
-            add_const=False,
         )
 
     @_add_doc(
@@ -1697,35 +1693,6 @@ class Pln(_model):
             self._offsets + self._latent_mean + 1 / 2 * self._latent_sqrt_var**2
         )
 
-    def _smart_init_latent_parameters(self):
-        """
-        Method for smartly initializing the latent parameters.
-        """
-        self._random_init_latent_parameters()
-
-    def _random_init_latent_parameters(self):
-        """
-        Method for randomly initializing the latent parameters.
-        """
-        if not hasattr(self, "_latent_sqrt_var"):
-            self._latent_sqrt_var = (
-                1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE)
-            )
-        if not hasattr(self, "_latent_mean"):
-            self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE)
-
-    @property
-    def _list_of_parameters_needing_gradient(self):
-        """
-        Property representing the list of parameters needing gradient.
-
-        Returns
-        -------
-        list
-            The list of parameters needing gradient.
-        """
-        return [self._latent_mean, self._latent_sqrt_var]
-
     def _get_max_components(self):
         """
         Method for getting the maximum number of components.
@@ -1737,61 +1704,6 @@ class Pln(_model):
         """
         return self.dim
 
-    def compute_elbo(self):
-        """
-        Method for computing the evidence lower bound (ELBO).
-
-        Returns
-        -------
-        torch.Tensor
-            The computed ELBO.
-        Examples
-        --------
-        >>> from pyPLNmodels import Pln, get_real_count_data
-        >>> endog, labels = get_real_count_data(return_labels = True)
-        >>> pln = Pln(endog,add_const = True)
-        >>> pln.fit()
-        >>> elbo = pln.compute_elbo()
-        >>> print("elbo", elbo)
-        >>> print("loglike/n", pln.loglike/pln.n_samples)
-        """
-        return profiled_elbo_pln(
-            self._endog,
-            self._exog,
-            self._offsets,
-            self._latent_mean,
-            self._latent_sqrt_var,
-        )
-
-    def _compute_elbo_b(self):
-        """
-        Method for computing the evidence lower bound (ELBO) on the current batch.
-
-        Returns
-        -------
-        torch.Tensor
-            The computed ELBO on the current batch.
-        """
-        return profiled_elbo_pln(
-            self._endog_b,
-            self._exog_b,
-            self._offsets_b,
-            self._latent_mean_b,
-            self._latent_sqrt_var_b,
-        )
-
-    def _smart_init_model_parameters(self):
-        """
-        Method for smartly initializing the model parameters.
-        """
-        # no model parameters since we are doing a profiled ELBO
-
-    def _random_init_model_parameters(self):
-        """
-        Method for randomly initializing the model parameters.
-        """
-        # no model parameters since we are doing a profiled ELBO
-
     @property
     def _coef(self):
         """
@@ -1829,19 +1741,29 @@ class Pln(_model):
         covariances = components_var @ (sk_components.T.unsqueeze(0))
         return covariances
 
-    @property
-    @_add_doc(
-        _model,
-        example="""
-        >>> from pyPLNmodels import Pln, get_real_count_data
-        >>> endog, labels = get_real_count_data(return_labels = True)
-        >>> pln = Pln(endog,add_const = True)
-        >>> pln.fit()
-        >>> print(pln.latent_variables.shape)
-        """,
-    )
-    def latent_variables(self):
-        return self.latent_mean.detach()
+    @_model.latent_sqrt_var.setter
+    @_array2tensor
+    def latent_sqrt_var(
+        self, latent_sqrt_var: Union[torch.Tensor, np.ndarray, pd.DataFrame]
+    ):
+        """
+        Setter for the latent variance property.
+
+        Parameters
+        ----------
+        latent_sqrt_var : Union[torch.Tensor, np.ndarray, pd.DataFrame]
+            The latent variance.
+
+        Raises
+        ------
+        ValueError
+            If the shape of the latent variance is incorrect.
+        """
+        if latent_sqrt_var.shape != (self.n_samples, self.dim):
+            raise ValueError(
+                f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_sqrt_var.shape}"
+            )
+        self._latent_sqrt_var = latent_sqrt_var
 
     @property
     def number_of_parameters(self):
@@ -1891,6 +1813,84 @@ class Pln(_model):
         """
         raise AttributeError("You can not set the covariance for the Pln model.")
 
+    def _random_init_latent_sqrt_var(self):
+        if not hasattr(self, "_latent_sqrt_var"):
+            self._latent_sqrt_var = (
+                1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE)
+            )
+
+    @property
+    @_add_doc(
+        _model,
+        example="""
+        >>> from pyPLNmodels import Pln, get_real_count_data
+        >>> endog, labels = get_real_count_data(return_labels = True)
+        >>> pln = Pln(endog,add_const = True)
+        >>> pln.fit()
+        >>> print(pln.latent_variables.shape)
+        """,
+    )
+    def latent_variables(self):
+        return self.latent_mean.detach()
+
+    @_add_doc(
+        _model,
+        example="""
+            >>> from pyPLNmodels import Pln, get_real_count_data
+            >>> endog, labels = get_real_count_data(return_labels = True)
+            >>> pln = Pln(endog,add_const = True)
+            >>> pln.fit()
+            >>> elbo = pln.compute_elbo()
+            >>> print("elbo", elbo)
+            >>> print("loglike/n", pln.loglike/pln.n_samples)
+            """,
+    )
+    def compute_elbo(self):
+        return profiled_elbo_pln(
+            self._endog,
+            self._exog,
+            self._offsets,
+            self._latent_mean,
+            self._latent_sqrt_var,
+        )
+
+    @_add_doc(_model)
+    def _compute_elbo_b(self):
+        return profiled_elbo_pln(
+            self._endog_b,
+            self._exog_b,
+            self._offsets_b,
+            self._latent_mean_b,
+            self._latent_sqrt_var_b,
+        )
+
+    @_add_doc(_model)
+    def _smart_init_model_parameters(self):
+        pass
+        # no model parameters since we are doing a profiled ELBO
+
+    @_add_doc(_model)
+    def _random_init_model_parameters(self):
+        pass
+        # no model parameters since we are doing a profiled ELBO
+
+    @_add_doc(_model)
+    def _smart_init_latent_parameters(self):
+        self._random_init_latent_sqrt_var()
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = torch.log(self._endog + (self._endog == 0))
+
+    @_add_doc(_model)
+    def _random_init_latent_parameters(self):
+        self._random_init_latent_sqrt_var()
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE)
+
+    @property
+    @_add_doc(_model)
+    def _list_of_parameters_needing_gradient(self):
+        return [self._latent_mean, self._latent_sqrt_var]
+
 
 class PlnPCAcollection:
     """
@@ -2246,6 +2246,17 @@ class PlnPCAcollection:
         """
         return [model.rank for model in self.values()]
 
+    def _print_beginning_message(self) -> str:
+        """
+        Method for printing the beginning message.
+
+        Returns
+        -------
+        str
+            The beginning message.
+        """
+        return f"Adjusting {len(self.ranks)} Pln models for PCA analysis \n"
+
     @property
     def dim(self) -> int:
         """
@@ -2674,7 +2685,7 @@ class PlnPCAcollection:
         return ".BIC, .AIC, .loglikes"
 
 
-# Here, setting the value for each key in _dict_parameters
+# Here, setting the value for each key  _dict_parameters
 class PlnPCA(_model):
     """
     PlnPCA object where the covariance has low rank.
@@ -2900,19 +2911,6 @@ class PlnPCA(_model):
             variables_names=variables_names, indices_of_variables=indices_of_variables
         )
 
-    def _check_if_rank_is_too_high(self):
-        """
-        Check if the rank is too high and issue a warning if necessary.
-        """
-        if self.dim < self.rank:
-            warning_string = (
-                f"\nThe requested rank of approximation {self.rank} "
-                f"is greater than the number of variables {self.dim}. "
-                f"Setting rank to {self.dim}"
-            )
-            warnings.warn(warning_string)
-            self._rank = self.dim
-
     @property
     @_add_doc(
         _model,
@@ -2928,30 +2926,6 @@ class PlnPCA(_model):
     def latent_mean(self) -> torch.Tensor:
         return self._cpu_attribute_or_none("_latent_mean")
 
-    @property
-    def latent_sqrt_var(self) -> torch.Tensor:
-        """
-        Property representing the unsigned square root of the latent variance.
-
-        Returns
-        -------
-        torch.Tensor
-            The latent variance tensor.
-        """
-        return self._cpu_attribute_or_none("_latent_sqrt_var")
-
-    @property
-    def _latent_var(self) -> torch.Tensor:
-        """
-        Property representing the latent variance.
-
-        Returns
-        -------
-        torch.Tensor
-            The latent variance tensor.
-        """
-        return self._latent_sqrt_var**2
-
     def _endog_predictions(self):
         covariance_a_posteriori = torch.sum(
             (self._components**2).unsqueeze(0)
@@ -2983,7 +2957,7 @@ class PlnPCA(_model):
             )
         self._latent_mean = latent_mean
 
-    @latent_sqrt_var.setter
+    @_model.latent_sqrt_var.setter
     @_array2tensor
     def latent_sqrt_var(self, latent_sqrt_var: torch.Tensor):
         """
@@ -3076,104 +3050,6 @@ class PlnPCA(_model):
         """
         return {"coef": self.coef, "components": self.components}
 
-    def _smart_init_model_parameters(self):
-        """
-        Initialize the model parameters smartly.
-        """
-        if not hasattr(self, "_coef"):
-            super()._smart_init_coef()
-        if not hasattr(self, "_components"):
-            self._components = _init_components(self._endog, self._exog, self._rank)
-
-    def _random_init_model_parameters(self):
-        """
-        Randomly initialize the model parameters.
-        """
-        super()._random_init_coef()
-        self._components = torch.randn((self.dim, self._rank)).to(DEVICE)
-
-    def _random_init_latent_parameters(self):
-        """
-        Randomly initialize the latent parameters.
-        """
-        self._latent_sqrt_var = (
-            1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
-        )
-        self._latent_mean = torch.ones((self.n_samples, self._rank)).to(DEVICE)
-
-    def _smart_init_latent_parameters(self):
-        """
-        Initialize the latent parameters smartly.
-        """
-        if not hasattr(self, "_latent_mean"):
-            self._latent_mean = (
-                _init_latent_mean(
-                    self._endog,
-                    self._exog,
-                    self._offsets,
-                    self._coef,
-                    self._components,
-                )
-                .to(DEVICE)
-                .detach()
-            )
-        if not hasattr(self, "_latent_sqrt_var"):
-            self._latent_sqrt_var = (
-                1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
-            )
-
-    @property
-    def _list_of_parameters_needing_gradient(self):
-        """
-        Property representing the list of parameters needing gradient.
-
-        Returns
-        -------
-        List[torch.Tensor]
-            The list of parameters needing gradient.
-        """
-        if self._coef is None:
-            return [self._components, self._latent_mean, self._latent_sqrt_var]
-        return [self._components, self._coef, self._latent_mean, self._latent_sqrt_var]
-
-    def _compute_elbo_b(self) -> torch.Tensor:
-        """
-        Compute the evidence lower bound (ELBO) with the current batch.
-
-        Returns
-        -------
-        torch.Tensor
-            The ELBO value on the current batch.
-        """
-        return elbo_plnpca(
-            self._endog_b,
-            self._exog_b,
-            self._offsets_b,
-            self._latent_mean_b,
-            self._latent_sqrt_var_b,
-            self._components,
-            self._coef,
-        )
-
-    def compute_elbo(self) -> torch.Tensor:
-        """
-        Compute the evidence lower bound (ELBO).
-
-        Returns
-        -------
-        torch.Tensor
-            The ELBO value.
-        """
-        return elbo_plnpca(
-            self._endog,
-            self._exog,
-            self._offsets,
-            self._latent_mean,
-            self._latent_sqrt_var,
-            self._components,
-            self._coef,
-        )
-
     @property
     def number_of_parameters(self) -> int:
         """
@@ -3247,7 +3123,7 @@ class PlnPCA(_model):
     @property
     def _description(self) -> str:
         """
-        Description output when fitting and printing the model.
+        Property representing the description.
 
         Returns
         -------
@@ -3256,18 +3132,6 @@ class PlnPCA(_model):
         """
         return f" {self.rank} principal component."
 
-    @property
-    def latent_variables(self) -> torch.Tensor:
-        """
-        Property representing the latent variables.
-
-        Returns
-        -------
-        torch.Tensor
-            The latent variables of size (n_samples, dim).
-        """
-        return torch.matmul(self._latent_mean, self._components.T).detach()
-
     @property
     def projected_latent_variables(self) -> torch.Tensor:
         """
@@ -3349,6 +3213,103 @@ class PlnPCA(_model):
             return self.projected_latent_variables
         return self.latent_variables
 
+    @property
+    @_add_doc(
+        _model,
+        example="""
+        >>> from pyPLNmodels import PlnPCA, get_real_count_data
+        >>> endog = get_real_count_data(return_labels=False)
+        >>> pca = PlnPCA(endog,add_const = True)
+        >>> pca.fit()
+        >>> print(pca.latent_variables.shape)
+        """,
+    )
+    def latent_variables(self) -> torch.Tensor:
+        return torch.matmul(self._latent_mean, self._components.T).detach()
+
+    @_add_doc(
+        _model,
+        example="""
+            >>> from pyPLNmodels import PlnPCA, get_real_count_data
+            >>> endog = get_real_count_data(return_labels = False)
+            >>> pca = PlnPCA(endog,add_const = True)
+            >>> pca.fit()
+            >>> elbo = pca.compute_elbo()
+            >>> print("elbo", elbo)
+            >>> print("loglike/n", pln.loglike/pln.n_samples)
+            """,
+    )
+    def compute_elbo(self) -> torch.Tensor:
+        return elbo_plnpca(
+            self._endog,
+            self._exog,
+            self._offsets,
+            self._latent_mean,
+            self._latent_sqrt_var,
+            self._components,
+            self._coef,
+        )
+
+    @_add_doc(_model)
+    def _compute_elbo_b(self) -> torch.Tensor:
+        return elbo_plnpca(
+            self._endog_b,
+            self._exog_b,
+            self._offsets_b,
+            self._latent_mean_b,
+            self._latent_sqrt_var_b,
+            self._components,
+            self._coef,
+        )
+
+    @_add_doc(_model)
+    def _random_init_model_parameters(self):
+        super()._random_init_coef()
+        self._components = torch.randn((self.dim, self._rank)).to(DEVICE)
+
+    @_add_doc(_model)
+    def _smart_init_model_parameters(self):
+        if not hasattr(self, "_coef"):
+            super()._smart_init_coef()
+        if not hasattr(self, "_components"):
+            self._components = _init_components(self._endog, self._exog, self._rank)
+
+    @_add_doc(_model)
+    def _random_init_latent_parameters(self):
+        """
+        Randomly initialize the latent parameters.
+        """
+        self._latent_sqrt_var = (
+            1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
+        )
+        self._latent_mean = torch.ones((self.n_samples, self._rank)).to(DEVICE)
+
+    @_add_doc(_model)
+    def _smart_init_latent_parameters(self):
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = (
+                _init_latent_mean(
+                    self._endog,
+                    self._exog,
+                    self._offsets,
+                    self._coef,
+                    self._components,
+                )
+                .to(DEVICE)
+                .detach()
+            )
+        if not hasattr(self, "_latent_sqrt_var"):
+            self._latent_sqrt_var = (
+                1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
+            )
+
+    @property
+    @_add_doc(_model)
+    def _list_of_parameters_needing_gradient(self):
+        if self._coef is None:
+            return [self._components, self._latent_mean, self._latent_sqrt_var]
+        return [self._components, self._coef, self._latent_mean, self._latent_sqrt_var]
+
 
 class ZIPln(_model):
     _NAME = "ZIPln"
@@ -3359,6 +3320,10 @@ class ZIPln(_model):
 
     @_add_doc(
         _model,
+        params="""
+        use_closed_form_prob: bool, optional
+            Whether or not to use the closed formula for the latent probability
+        """,
         example="""
             >>> from pyPLNmodels import ZIPln, get_real_count_data
             >>> endog= get_real_count_data()
@@ -3398,28 +3363,17 @@ class ZIPln(_model):
 
     def _extract_batch(self, batch):
         super()._extract_batch(batch)
+        self._dirac_b = batch[5]
         if self._use_closed_form_prob is False:
-            self._latent_prob_b = batch[5]
+            self._latent_prob_b = batch[6]
 
     def _return_batch(self, indices, beginning, end):
         pln_batch = super()._return_batch(indices, beginning, end)
-        if self._use_closed_form_prob is False:
-            return pln_batch + torch.index_select(self._latent_prob, 0, to_take)
-        return pln_batch
-
-    def _return_batch(self, indices, beginning, end):
         to_take = torch.tensor(indices[beginning:end]).to(DEVICE)
-        if self._exog is not None:
-            exog_b = torch.index_select(self._exog, 0, to_take)
-        else:
-            exog_b = None
-        return (
-            torch.index_select(self._endog, 0, to_take),
-            exog_b,
-            torch.index_select(self._offsets, 0, to_take),
-            torch.index_select(self._latent_mean, 0, to_take),
-            torch.index_select(self._latent_sqrt_var, 0, to_take),
-        )
+        batch = pln_batch + (torch.index_select(self._dirac, 0, to_take),)
+        if self._use_closed_form_prob is False:
+            return batch + (torch.index_select(self._latent_prob, 0, to_take),)
+        return batch
 
     @classmethod
     @_add_doc(
@@ -3446,7 +3400,7 @@ class ZIPln(_model):
         offsets_formula: str = "logsum",
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
-        use_closed_form: bool = True,
+        use_closed_form_prob: bool = True,
     ):
         endog, exog, offsets = _extract_data_from_formula(formula, data)
         return cls(
@@ -3457,7 +3411,7 @@ class ZIPln(_model):
             dict_initialization=dict_initialization,
             take_log_offsets=take_log_offsets,
             add_const=False,
-            use_closed_form=use_closed_form,
+            use_closed_form_prob=use_closed_form_prob,
         )
 
     @_add_doc(
@@ -3508,7 +3462,7 @@ class ZIPln(_model):
 
     @property
     def _description(self):
-        return " full covariance model and zero-inflation."
+        return "with full covariance model and zero-inflation."
 
     def _random_init_model_parameters(self):
         super()._random_init_model_parameters()
@@ -3541,7 +3495,26 @@ class ZIPln(_model):
     def _covariance(self):
         return self._components @ (self._components.T)
 
-    def latent_variables(self):
+    def latent_variables(self) -> tuple([torch.Tensor, torch.Tensor]):
+        """
+        Property representing the latent variables. Two latent
+        variables are available if exog is not None
+
+        Returns
+        -------
+        tuple(torch.Tensor, torch.Tensor)
+            The latent variables of a classic Pln model (size (n_samples, dim))
+            and zero inflated latent variables of size (n_samples, dim).
+        Examples
+        --------
+        >>> from pyPLNmodels import ZIPln, get_real_count_data
+        >>> endog, labels = get_real_count_data(return_labels = True)
+        >>> zi = ZIPln(endog,add_const = True)
+        >>> zi.fit()
+        >>> latent_mean, latent_inflated = zi.latent_variables
+        >>> print(latent_mean.shape)
+        >>> print(latent_inflated.shape)
+        """
         return self.latent_mean, self.latent_prob
 
     def _update_parameters(self):
@@ -3552,15 +3525,15 @@ class ZIPln(_model):
         """
         Project the latent probability since it must be between 0 and 1.
         """
-        if self.use_closed_form_prob is False:
+        if self._use_closed_form_prob is False:
             with torch.no_grad():
-                self._latent_prob = torch.maximum(
-                    self._latent_prob, torch.tensor([0]), out=self._latent_prob
+                self._latent_prob_b = torch.maximum(
+                    self._latent_prob_b, torch.tensor([0]), out=self._latent_prob_b
                 )
-                self._latent_prob = torch.minimum(
-                    self._latent_prob, torch.tensor([1]), out=self._latent_prob
+                self._latent_prob_b = torch.minimum(
+                    self._latent_prob_b, torch.tensor([1]), out=self._latent_prob_b
                 )
-                self._latent_prob *= self._dirac
+                self._latent_prob_b *= self._dirac_b
 
     @property
     def covariance(self) -> torch.Tensor:
@@ -3634,13 +3607,12 @@ class ZIPln(_model):
         return self.dim * (2 * self.nb_cov + (self.dim + 1) / 2)
 
     @property
+    @_add_doc(_model)
     def _list_of_parameters_needing_gradient(self):
         list_parameters = [
             self._latent_mean,
             self._latent_sqrt_var,
-            self._coef_inflation,
             self._components,
-            self._coef,
         ]
         if self._use_closed_form_prob:
             list_parameters.append(self._latent_prob)
-- 
GitLab
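
A note on the `@_model.latent_sqrt_var.setter` decorators introduced above for `Pln` and `PlnPCA`: Python lets a subclass keep an inherited property's getter and override only its setter by decorating the new setter with `@Parent.prop.setter`. A minimal, self-contained illustration with generic names (not pyPLNmodels code):

```py
import numpy as np


class Base:
    @property
    def latent_sqrt_var(self):
        """Unsigned square root of the latent variance."""
        return self._latent_sqrt_var


class Child(Base):
    expected_shape = (3,)

    # Reuse Base's getter, replace only the setter -- the same mechanism as
    # `@_model.latent_sqrt_var.setter` in the hunks above.
    @Base.latent_sqrt_var.setter
    def latent_sqrt_var(self, value):
        if value.shape != self.expected_shape:
            raise ValueError(
                f"Wrong shape. Expected {self.expected_shape}, got {value.shape}"
            )
        self._latent_sqrt_var = value


obj = Child()
obj.latent_sqrt_var = np.zeros(3)  # accepted by the shape check
print(obj.latent_sqrt_var.shape)   # (3,), served by the getter inherited from Base
```
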


From c25e4af324567c8b21a8fa8782cffbf89b33db02 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 13 Oct 2023 18:28:27 +0200
Subject: [PATCH 36/68] typo in the contributing

---
 CONTRIBUTING.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e1fb39dc..530ced7b 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -63,8 +63,8 @@ If `newmodel` is well implemented, running
 from pyPLNmodels import newmodel, get_real_count_data
 
 endog = get_real_count_data()
-zi = newmodel(endog, add_const = True)
-zi.fit(nb_max_iteration = 10, tol = 0)
+model = newmodel(endog, add_const = True)
+model.fit(nb_max_iteration = 10, tol = 0)
 ```
 should increase the elbo of the model. You should document your functions with
 [numpy-style
-- 
GitLab


From 3be6b936d76c4d8551c63e863aa3cb8ff67b54b9 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 16 Oct 2023 08:16:32 +0200
Subject: [PATCH 37/68] continue adding to the contributing guide and create
 the tests for the ZI.

---
 CONTRIBUTING.md                             |  75 +++++++++--
 pyPLNmodels/models.py                       | 140 +++++++++++---------
 tests/conftest.py                           |  12 +-
 tests/create_readme_and_docstrings_tests.py |   2 +-
 4 files changed, 153 insertions(+), 76 deletions(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 530ced7b..2f718217 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -21,42 +21,91 @@ The `newmodel` class should contains at least the following code:
 ```
 class newmodel(_model):
     _NAME=""
-    def _random_init_latent_sqrt_var(self):
+    @property
+    def latent_variables(self) -> torch.Tensor:
         "Implement here"
 
-    @property
-    def latent_variables(self):
+    def compute_elbo(self) -> torch.Tensor:
         "Implement here"
 
-    def compute_elbo(self):
+    def _compute_elbo_b(self) -> torch.Tensor:
         "Implement here"
 
-    def _compute_elbo_b(self):
+    def _smart_init_model_parameters(self)-> None:
         "Implement here"
 
-    def _smart_init_model_parameters(self):
+    def _random_init_model_parameters(self)-> None:
         "Implement here"
 
-    def _random_init_model_parameters(self):
+    def _smart_init_latent_parameters(self)-> None:
         "Implement here"
 
-    def _smart_init_latent_parameters(self):
+    def _random_init_latent_parameters(self)-> None:
         "Implement here"
 
-    def _random_init_latent_parameters(self):
+    @property
+    def _list_of_parameters_needing_gradient(self)-> list:
+        "Implement here"
+    @property
+    def _description(self)-> str:
         "Implement here"
 
     @property
-    def _list_of_parameters_needing_gradient(self):
+    def number_of_parameters(self) -> int:
         "Implement here"
+
     @property
-    def _description(self):
+    def model_parameters(self)-> Dict[str, torch.Tensor]:
         "Implement here"
 
     @property
-    def number_of_parameters(self):
+    def latent_parameters(self)-> Dict[str, torch.Tensor]:
         "Implement here"
 ```
+Each value of the `latent_parameters` dict should be implemented (and protected) both in
+`_random_init_latent_parameters` and `_smart_init_latent_parameters`.
+Each value of the `model_parameters` dict should be implemented (and protected) both in
+`_random_init_model_parameters` and `_smart_init_model_parameters`.
+For example, if you have one model parameter `coef` and latent parameters `latent_mean` and `latent_var`, you should implement them as follows:
+```py
+class newmodel(_model):
+    @property
+    def model_parameters(self) -> Dict[str, torch.Tensor]:
+        return {"coef":self.coef}
+    @property
+    def latent_parameters(self) -> Dict[str, torch.Tensor]:
+        return {"latent_mean":self.latent_mean, "latent_var":self.latent_var}
+
+    def _random_init_latent_parameters(self):
+        self._latent_mean = random_init_latent_mean()
+        self._latent_var = random_init_latent_var()
+
+    def _smart_init_latent_parameters(self):
+        self._latent_mean = init_latent_mean()
+        self._latent_var = init_latent_var()
+
+    @property
+    def latent_var(self):
+        return self._latent_var
+
+    @property
+    def latent_mean(self):
+        return self._latent_mean
+
+    def _random_init_model_parameters(self):
+        self._coef = random_init_coef()
+
+    def _smart_init_model_parameters(self):
+        self._coef = init_coef()
+
+    @property
+    def coef(self):
+        return self._coef
+```
+
+
+
 Then, add `newmodel` in the `__init__.py` file of the pyPLNmodels module.
 If `newmodel` is well implemented, running
 ```
@@ -69,4 +118,4 @@ model.fit(nb_max_iteration = 10, tol = 0)
 should increase the elbo of the model. You should document your functions with
 [numpy-style
 docstrings](https://numpydoc.readthedocs.io/en/latest/format.html). You can use
-the `_add_doc` decorator to inherit the docstrings of the `_model` class.
+the `_add_doc` decorator (implemented in the `_utils` module) to inherit the docstrings of the `_model` class.
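
As a rough, hypothetical sanity check for the "fitting should increase the elbo" requirement above (sketch only: `newmodel` is the placeholder class from this guide, and `fit` is assumed to resume from the current state when called a second time):

```py
# Sketch of a check, assuming `newmodel` follows the template above.
from pyPLNmodels import newmodel, get_real_count_data

endog = get_real_count_data()
model = newmodel(endog, add_const=True)

model.fit(nb_max_iteration=1, tol=0)
elbo_after_one_step = model.compute_elbo()

model.fit(nb_max_iteration=10, tol=0)  # assumed to continue from the current state
assert model.compute_elbo() > elbo_after_one_step
```
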
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index e8b316a5..9228de82 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -360,7 +360,7 @@ class _model(ABC):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-8,
+        tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size=None,
@@ -872,33 +872,6 @@ class _model(ABC):
         """
         return -self.loglike + self.number_of_parameters
 
-    @property
-    def latent_parameters(self):
-        """
-        Property representing the latent parameters.
-
-        Returns
-        -------
-        dict
-            The dictionary of latent parameters.
-        """
-        return {
-            "latent_sqrt_var": self.latent_sqrt_var,
-            "latent_mean": self.latent_mean,
-        }
-
-    @property
-    def model_parameters(self):
-        """
-        Property representing the model parameters.
-
-        Returns
-        -------
-        dict
-            The dictionary of model parameters.
-        """
-        return {"coef": self.coef, "covariance": self.covariance}
-
     @property
     def dict_data(self):
         """
@@ -1284,18 +1257,6 @@ class _model(ABC):
         """
         return f"{self._NAME}_nbcov_{self.nb_cov}_dim_{self.dim}"
 
-    @property
-    def _path_to_directory(self):
-        """
-        Property representing the path to the directory.
-
-        Returns
-        -------
-        str
-            The path to the directory.
-        """
-        return ""
-
     def plot_expected_vs_true(self, ax=None, colors=None):
         """
         Plot the predicted value of the endog against the endog.
@@ -1420,6 +1381,30 @@ class _model(ABC):
         Number of parameters of the model.
         """
 
+    @property
+    @abstractmethod
+    def model_parameters(self) -> Dict[str, torch.Tensor]:
+        """
+        Property representing the model parameters.
+
+        Returns
+        -------
+        dict
+            The dictionary of model parameters.
+        """
+
+    @property
+    @abstractmethod
+    def latent_parameters(self) -> Dict[str, torch.Tensor]:
+        """
+        Property representing the latent parameters.
+
+        Returns
+        -------
+        dict
+            The dictionary of latent parameters.
+        """
+
 
 class Pln(_model):
     """
@@ -1509,8 +1494,7 @@ class Pln(_model):
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
     ):
-        super().from_formula(
-            cls=cls,
+        return super().from_formula(
             formula=formula,
             data=data,
             offsets_formula=offsets_formula,
@@ -1533,7 +1517,7 @@ class Pln(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-8,
+        tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size: int = None,
@@ -1686,7 +1670,8 @@ class Pln(_model):
         ------
         AttributeError since you can not set the coef in the Pln model.
         """
-        raise AttributeError("You can not set the coef in the Pln model.")
+        msg = "You can not set the coef in the Pln model."
+        warnings.warn(msg)
 
     def _endog_predictions(self):
         return torch.exp(
@@ -1811,7 +1796,7 @@ class Pln(_model):
         covariance : torch.Tensor
             The covariance matrix.
         """
-        raise AttributeError("You can not set the covariance for the Pln model.")
+        warnings.warn("You can not set the covariance for the Pln model.")
 
     def _random_init_latent_sqrt_var(self):
         if not hasattr(self, "_latent_sqrt_var"):
@@ -1891,6 +1876,19 @@ class Pln(_model):
     def _list_of_parameters_needing_gradient(self):
         return [self._latent_mean, self._latent_sqrt_var]
 
+    @property
+    @_add_doc(_model)
+    def model_parameters(self) -> Dict[str, torch.Tensor]:
+        return {"coef": self.coef, "covariance": self.covariance}
+
+    @property
+    @_add_doc(_model)
+    def latent_parameters(self):
+        return {
+            "latent_sqrt_var": self.latent_sqrt_var,
+            "latent_mean": self.latent_mean,
+        }
+
 
 class PlnPCAcollection:
     """
@@ -2286,7 +2284,7 @@ class PlnPCAcollection:
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-8,
+        tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size: int = None,
@@ -2814,7 +2812,7 @@ class PlnPCA(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-8,
+        tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size=None,
@@ -3038,18 +3036,6 @@ class PlnPCA(_model):
         """
         return self._rank
 
-    @property
-    def model_parameters(self) -> Dict[str, torch.Tensor]:
-        """
-        Property representing the model parameters.
-
-        Returns
-        -------
-        Dict[str, torch.Tensor]
-            The model parameters.
-        """
-        return {"coef": self.coef, "components": self.components}
-
     @property
     def number_of_parameters(self) -> int:
         """
@@ -3310,6 +3296,19 @@ class PlnPCA(_model):
             return [self._components, self._latent_mean, self._latent_sqrt_var]
         return [self._components, self._coef, self._latent_mean, self._latent_sqrt_var]
 
+    @property
+    @_add_doc(_model)
+    def model_parameters(self) -> Dict[str, torch.Tensor]:
+        return {"coef": self.coef, "components": self.components}
+
+    @property
+    @_add_doc(_model)
+    def latent_parameters(self):
+        return {
+            "latent_sqrt_var": self.latent_sqrt_var,
+            "latent_mean": self.latent_mean,
+        }
+
 
 class ZIPln(_model):
     _NAME = "ZIPln"
@@ -3429,7 +3428,7 @@ class ZIPln(_model):
         nb_max_iteration: int = 50000,
         *,
         lr: float = 0.01,
-        tol: float = 1e-8,
+        tol: float = 1e-3,
         do_smart_init: bool = True,
         verbose: bool = False,
         batch_size: int = None,
@@ -3495,6 +3494,7 @@ class ZIPln(_model):
     def _covariance(self):
         return self._components @ (self._components.T)
 
+    @property
     def latent_variables(self) -> tuple([torch.Tensor, torch.Tensor]):
         """
         Property representing the latent variables. Two latent
@@ -3624,6 +3624,26 @@ class ZIPln(_model):
     def _update_closed_forms(self):
         pass
 
+    @property
+    @_add_doc(_model)
+    def model_parameters(self) -> Dict[str, torch.Tensor]:
+        return {
+            "coef": self.coef,
+            "components": self.components,
+            "coef_inflation": self.coef_inflation,
+        }
+
+    @property
+    @_add_doc(_model)
+    def latent_parameters(self):
+        latent_param = {
+            "latent_sqrt_var": self.latent_sqrt_var,
+            "latent_mean": self.latent_mean,
+        }
+        if self._use_closed_form_prob is True:
+            latent_param["latent_prob"] = self.latent_prob
+        return latent_param
+
     def grad_M(self):
         if self.use_closed_form_prob is True:
             latent_prob = self.closed_formula_latent_prob
diff --git a/tests/conftest.py b/tests/conftest.py
index 3a072f20..588b3e4e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,7 +7,7 @@ from pytest_lazyfixture import lazy_fixture as lf
 import pandas as pd
 
 from pyPLNmodels import load_model, load_plnpcacollection
-from pyPLNmodels.models import Pln, PlnPCA, PlnPCAcollection
+from pyPLNmodels.models import Pln, PlnPCA, PlnPCAcollection, ZIPln
 
 
 sys.path.append("../")
@@ -78,14 +78,20 @@ def convenient_PlnPCAcollection(*args, **kwargs):
 
 
 def convenientpln(*args, **kwargs):
+    # no need to dict init since we do not have ranks
     if isinstance(args[0], str):
         return Pln.from_formula(*args, **kwargs)
     return Pln(*args, **kwargs)
 
 
+def convenientzi(*args, **kwargs):
+    if isinstance(args[0], str):
+        return ZIPln.from_formula(*args, **kwargs)
+    return ZIPln(*args, **kwargs)
+
+
 def generate_new_model(model, *args, **kwargs):
     name_dir = model._directory_name
-    print("directory name", name_dir)
     name = model._NAME
     if name in ("Pln", "PlnPCA"):
         path = model._path_to_directory + name_dir
@@ -94,6 +100,8 @@ def generate_new_model(model, *args, **kwargs):
             new = convenientpln(*args, **kwargs, dict_initialization=init)
         if name == "PlnPCA":
             new = convenient_PlnPCA(*args, **kwargs, dict_initialization=init)
+        if name == "ZIPln":
+            new = convenientzi(*args, **kwargs, dict_initialization=init)
     if name == "PlnPCAcollection":
         init = load_plnpcacollection(name_dir)
         new = convenient_PlnPCAcollection(*args, **kwargs, dict_initialization=init)
diff --git a/tests/create_readme_and_docstrings_tests.py b/tests/create_readme_and_docstrings_tests.py
index d9f27aeb..63aecf9d 100644
--- a/tests/create_readme_and_docstrings_tests.py
+++ b/tests/create_readme_and_docstrings_tests.py
@@ -43,7 +43,7 @@ def get_example_readme(lines):
                     in_example = False
             elif in_example is True:
                 example.append(line)
-    example.pop(0)  # The first is pip install pyPLNmodels which is not python code.
+    example.pop()  # The last line is pip install pyPLNmodels which is not python code.
     return [example]
 
 
-- 
GitLab


From c367d3b4406b886630a878b31a821e4dc6159a91 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 16 Oct 2023 11:41:59 +0200
Subject: [PATCH 38/68] renamed pln to model in fixtures and fixed some
 tests for the pln. Now we cannot set exog to None and add_const to False for
 the ZIPln.

---
 pyPLNmodels/models.py  | 142 ++++++++++++++++------
 tests/conftest.py      | 259 +++++++++++++++++++++--------------------
 tests/test_common.py   |  95 +++++++--------
 tests/test_pln_full.py |   6 +-
 4 files changed, 289 insertions(+), 213 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 9228de82..0588c469 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -397,6 +397,7 @@ class _model(ABC):
         while self.nb_iteration_done < nb_max_iteration and not stop_condition:
             loss = self._trainstep()
             criterion = self._update_criterion_args(loss)
+            print("criterion", criterion)
             if abs(criterion) < tol:
                 stop_condition = True
             if verbose and self.nb_iteration_done % 50 == 1:
@@ -3317,26 +3318,6 @@ class ZIPln(_model):
     _coef_inflation: torch.Tensor
     _dirac: torch.Tensor
 
-    @_add_doc(
-        _model,
-        params="""
-        use_closed_form_prob: bool, optional
-            Whether or not use the closed formula for the latent probability
-        """,
-        example="""
-            >>> from pyPLNmodels import ZIPln, get_real_count_data
-            >>> endog= get_real_count_data()
-            >>> zi = ZIPln(endog, add_const = True)
-            >>> zi.fit()
-            >>> print(zi)
-        """,
-        returns="""
-            ZIPln
-        """,
-        see_also="""
-        :func:`pyPLNmodels.ZIPln.from_formula`
-        """,
-    )
     def __init__(
         self,
         endog: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]],
@@ -3349,6 +3330,54 @@ class ZIPln(_model):
         add_const: bool = True,
         use_closed_form_prob: bool = False,
     ):
+        """
+        Initializes the ZIPln class.
+
+        Parameters
+        ----------
+        endog : Union[torch.Tensor, np.ndarray, pd.DataFrame]
+            The count data.
+        exog : Union[torch.Tensor, np.ndarray, pd.DataFrame], optional(keyword-only)
+            The covariate data. Defaults to None.
+        offsets : Union[torch.Tensor, np.ndarray, pd.DataFrame], optional(keyword-only)
+            The offsets data. Defaults to None.
+        offsets_formula : str, optional(keyword-only)
+            The formula for offsets. Defaults to "logsum". Overriden if
+            offsets is not None.
+        dict_initialization : dict, optional(keyword-only)
+            The initialization dictionary. Defaults to None.
+        take_log_offsets : bool, optional(keyword-only)
+            Whether to take the log of offsets. Defaults to False.
+        add_const : bool, optional(keyword-only)
+            Whether to add a column of ones to the exog. Defaults to True.
+            If exog is None, add_const is set to True anyway and a warning
+            is raised.
+        use_closed_form_prob : bool, optional
+            Whether or not to use the closed formula for the latent probability.
+            Default is False.
+        Raises
+        ------
+        ValueError
+            If the batch_size is greater than the number of samples, or not int.
+        Returns
+        -------
+        A ZIPln object
+        See also
+        --------
+        :func:`pyPLNmodels.ZIPln.from_formula`
+        Examples
+        --------
+        >>> from pyPLNmodels import ZIPln, get_real_count_data
+        >>> endog= get_real_count_data()
+        >>> zi = ZIPln(endog, add_const = True)
+        >>> zi.fit()
+        >>> print(zi)
+        """
+        if exog is None and add_const is False:
+            msg = "No covariates has been given. An intercept is added since "
+            msg += "a ZIPln must have at least an intercept."
+            warnings.warn(msg)
+            add_const = True
         super().__init__(
             endog=endog,
             exog=exog,
@@ -3375,22 +3404,6 @@ class ZIPln(_model):
         return batch
 
     @classmethod
-    @_add_doc(
-        _model,
-        example="""
-            >>> from pyPLNmodels import ZIPln, get_real_count_data
-            >>> endog = get_real_count_data()
-            >>> data = {"endog": endog}
-            >>> zi = ZIPln.from_formula("endog ~ 1", data = data)
-        """,
-        returns="""
-            ZIPln
-        """,
-        see_also="""
-        :class:`pyPLNmodels.ZIPln`
-        :func:`pyPLNmodels.ZIPln.__init__`
-    """,
-    )
     def from_formula(
         cls,
         formula: str,
@@ -3401,6 +3414,39 @@ class ZIPln(_model):
         take_log_offsets: bool = False,
         use_closed_form_prob: bool = True,
     ):
+        """
+        Create a model instance from a formula and data.
+
+        Parameters
+        ----------
+        formula : str
+            The formula.
+        data : dict
+            The data dictionary. Each value can be either a torch.Tensor,
+            a np.ndarray or pd.DataFrame
+        offsets_formula : str, optional(keyword-only)
+            The formula for offsets. Defaults to "logsum".
+        dict_initialization : dict, optional(keyword-only)
+            The initialization dictionary. Defaults to None.
+        take_log_offsets : bool, optional(keyword-only)
+            Whether to take the log of offsets. Defaults to False.
+        use_closed_form_prob : bool, optional
+            Whether or not to use the closed formula for the latent probability.
+            Default is False.
+        Returns
+        -------
+        A ZIPln object
+        See also
+        --------
+        :class:`pyPLNmodels.ZIPln`
+        :func:`pyPLNmodels.ZIPln.__init__`
+        Examples
+        --------
+        >>> from pyPLNmodels import ZIPln, get_real_count_data
+        >>> endog = get_real_count_data()
+        >>> data = {"endog": endog}
+        >>> zi = ZIPln.from_formula("endog ~ 1", data = data)
+        """
         endog, exog, offsets = _extract_data_from_formula(formula, data)
         return cls(
             endog,
@@ -3494,6 +3540,18 @@ class ZIPln(_model):
     def _covariance(self):
         return self._components @ (self._components.T)
 
+    @property
+    def components(self) -> torch.Tensor:
+        """
+        Property representing the components.
+
+        Returns
+        -------
+        torch.Tensor
+            The components.
+        """
+        return self._cpu_attribute_or_none("_components")
+
     @property
     def latent_variables(self) -> tuple([torch.Tensor, torch.Tensor]):
         """
@@ -3517,6 +3575,18 @@ class ZIPln(_model):
         """
         return self.latent_mean, self.latent_prob
 
+    @property
+    def coef_inflation(self):
+        """
+        Property representing the coefficients of the zero inflated model.
+
+        Returns
+        -------
+        torch.Tensor or None
+            The coefficients or None.
+        """
+        return self._cpu_attribute_or_none("_coef_inflation")
+
     def _update_parameters(self):
         super()._update_parameters()
         self._project_latent_prob()
diff --git a/tests/conftest.py b/tests/conftest.py
index 588b3e4e..e40558dd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,10 +12,10 @@ from pyPLNmodels.models import Pln, PlnPCA, PlnPCAcollection, ZIPln
 
 sys.path.append("../")
 
-pytest_plugins = [
-    fixture_file.replace("/", ".").replace(".py", "")
-    for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True)
-]
+# pytest_plugins = [
+#     fixture_file.replace("/", ".").replace(".py", "")
+#     for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True)
+# ]
 
 
 from tests.import_data import (
@@ -78,7 +78,6 @@ def convenient_PlnPCAcollection(*args, **kwargs):
 
 
 def convenientpln(*args, **kwargs):
-    # no need to dict init since we do not have ranks
     if isinstance(args[0], str):
         return Pln.from_formula(*args, **kwargs)
     return Pln(*args, **kwargs)
@@ -93,8 +92,8 @@ def convenientzi(*args, **kwargs):
 def generate_new_model(model, *args, **kwargs):
     name_dir = model._directory_name
     name = model._NAME
-    if name in ("Pln", "PlnPCA"):
-        path = model._path_to_directory + name_dir
+    if name in ("Pln", "PlnPCA", "ZIPln"):
+        path = model._directory_name
         init = load_model(path)
         if name == "Pln":
             new = convenientpln(*args, **kwargs, dict_initialization=init)
@@ -103,7 +102,7 @@ def generate_new_model(model, *args, **kwargs):
         if name == "ZIPln":
             new = convenientzi(*args, **kwargs, dict_initialization=init)
     if name == "PlnPCAcollection":
-        init = load_plnpcacollection(name_dir)
+        init = load_plnpcacollection(model._directory_name)
         new = convenient_PlnPCAcollection(*args, **kwargs, dict_initialization=init)
     return new
 
@@ -119,67 +118,67 @@ def cache(func):
     return new_func
 
 
-params = [convenientpln, convenient_PlnPCA, convenient_PlnPCAcollection]
+params = [convenientpln, convenient_PlnPCA, convenient_PlnPCAcollection, convenientzi]
 dict_fixtures = {}
 
 
 @pytest.fixture(params=params)
-def simulated_pln_0cov_array(request):
+def simulated_model_0cov_array(request):
     cls = request.param
-    pln = cls(
+    model = cls(
         endog_sim_0cov,
         exog=exog_sim_0cov,
         offsets=offsets_sim_0cov,
         add_const=False,
     )
-    return pln
+    return model
 
 
 @pytest.fixture(params=params)
 @cache
-def simulated_fitted_pln_0cov_array(request):
+def simulated_fitted_model_0cov_array(request):
     cls = request.param
-    pln = cls(
+    model = cls(
         endog_sim_0cov,
         exog=exog_sim_0cov,
         offsets=offsets_sim_0cov,
         add_const=False,
     )
-    pln.fit()
-    return pln
+    model.fit()
+    return model
 
 
 @pytest.fixture(params=params)
-def simulated_pln_0cov_formula(request):
+def simulated_model_0cov_formula(request):
     cls = request.param
-    pln = cls("endog ~ 0", data_sim_0cov)
-    return pln
+    model = cls("endog ~ 0", data_sim_0cov)
+    return model
 
 
 @pytest.fixture(params=params)
 @cache
-def simulated_fitted_pln_0cov_formula(request):
+def simulated_fitted_model_0cov_formula(request):
     cls = request.param
-    pln = cls("endog ~ 0", data_sim_0cov)
-    pln.fit()
-    return pln
+    model = cls("endog ~ 0", data_sim_0cov)
+    model.fit()
+    return model
 
 
 @pytest.fixture
-def simulated_loaded_pln_0cov_formula(simulated_fitted_pln_0cov_formula):
-    simulated_fitted_pln_0cov_formula.save()
+def simulated_loaded_model_0cov_formula(simulated_fitted_model_0cov_formula):
+    simulated_fitted_model_0cov_formula.save()
     return generate_new_model(
-        simulated_fitted_pln_0cov_formula,
+        simulated_fitted_model_0cov_formula,
         "endog ~ 0",
         data_sim_0cov,
     )
 
 
 @pytest.fixture
-def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array):
-    simulated_fitted_pln_0cov_array.save()
+def simulated_loaded_model_0cov_array(simulated_fitted_model_0cov_array):
+    simulated_fitted_model_0cov_array.save()
     return generate_new_model(
-        simulated_fitted_pln_0cov_array,
+        simulated_fitted_model_0cov_array,
         endog_sim_0cov,
         exog=exog_sim_0cov,
         offsets=offsets_sim_0cov,
@@ -187,87 +186,89 @@ def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array):
     )
 
 
-sim_pln_0cov_instance = [
-    "simulated_pln_0cov_array",
-    "simulated_pln_0cov_formula",
+sim_model_0cov_instance = [
+    "simulated_model_0cov_array",
+    "simulated_model_0cov_formula",
 ]
 
-instances = sim_pln_0cov_instance + instances
+instances = sim_model_0cov_instance + instances
 
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "sim_pln_0cov_instance", sim_pln_0cov_instance
+    dict_fixtures, "sim_model_0cov_instance", sim_model_0cov_instance
 )
 
-sim_pln_0cov_fitted = [
-    "simulated_fitted_pln_0cov_array",
-    "simulated_fitted_pln_0cov_formula",
+sim_model_0cov_fitted = [
+    "simulated_fitted_model_0cov_array",
+    "simulated_fitted_model_0cov_formula",
 ]
 
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "sim_pln_0cov_fitted", sim_pln_0cov_fitted
+    dict_fixtures, "sim_model_0cov_fitted", sim_model_0cov_fitted
 )
 
-sim_pln_0cov_loaded = [
-    "simulated_loaded_pln_0cov_array",
-    "simulated_loaded_pln_0cov_formula",
+sim_model_0cov_loaded = [
+    "simulated_loaded_model_0cov_array",
+    "simulated_loaded_model_0cov_formula",
 ]
 
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "sim_pln_0cov_loaded", sim_pln_0cov_loaded
+    dict_fixtures, "sim_model_0cov_loaded", sim_model_0cov_loaded
 )
 
-sim_pln_0cov = sim_pln_0cov_instance + sim_pln_0cov_fitted + sim_pln_0cov_loaded
-dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_0cov", sim_pln_0cov)
+sim_model_0cov = sim_model_0cov_instance + sim_model_0cov_fitted + sim_model_0cov_loaded
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_model_0cov", sim_model_0cov
+)
 
 
 @pytest.fixture(params=params)
 @cache
-def simulated_pln_2cov_array(request):
+def simulated_model_2cov_array(request):
     cls = request.param
-    pln_full = cls(
+    model = cls(
         endog_sim_2cov,
         exog=exog_sim_2cov,
         offsets=offsets_sim_2cov,
         add_const=False,
     )
-    return pln_full
+    return model
 
 
 @pytest.fixture
-def simulated_fitted_pln_2cov_array(simulated_pln_2cov_array):
-    simulated_pln_2cov_array.fit()
-    return simulated_pln_2cov_array
+def simulated_fitted_model_2cov_array(simulated_model_2cov_array):
+    simulated_model_2cov_array.fit()
+    return simulated_model_2cov_array
 
 
 @pytest.fixture(params=params)
 @cache
-def simulated_pln_2cov_formula(request):
+def simulated_model_2cov_formula(request):
     cls = request.param
-    pln_full = cls("endog ~ 0 + exog", data_sim_2cov)
-    return pln_full
+    model = cls("endog ~ 0 + exog", data_sim_2cov)
+    return model
 
 
 @pytest.fixture
-def simulated_fitted_pln_2cov_formula(simulated_pln_2cov_formula):
-    simulated_pln_2cov_formula.fit()
-    return simulated_pln_2cov_formula
+def simulated_fitted_model_2cov_formula(simulated_model_2cov_formula):
+    simulated_model_2cov_formula.fit()
+    return simulated_model_2cov_formula
 
 
 @pytest.fixture
-def simulated_loaded_pln_2cov_formula(simulated_fitted_pln_2cov_formula):
-    simulated_fitted_pln_2cov_formula.save()
+def simulated_loaded_model_2cov_formula(simulated_fitted_model_2cov_formula):
+    simulated_fitted_model_2cov_formula.save()
     return generate_new_model(
-        simulated_fitted_pln_2cov_formula,
+        simulated_fitted_model_2cov_formula,
         "endog ~0 + exog",
         data_sim_2cov,
     )
 
 
 @pytest.fixture
-def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array):
-    simulated_fitted_pln_2cov_array.save()
+def simulated_loaded_model_2cov_array(simulated_fitted_model_2cov_array):
+    simulated_fitted_model_2cov_array.save()
     return generate_new_model(
-        simulated_fitted_pln_2cov_array,
+        simulated_fitted_model_2cov_array,
         endog_sim_2cov,
         exog=exog_sim_2cov,
         offsets=offsets_sim_2cov,
@@ -275,147 +276,149 @@ def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array):
     )
 
 
-sim_pln_2cov_instance = [
-    "simulated_pln_2cov_array",
-    "simulated_pln_2cov_formula",
+sim_model_2cov_instance = [
+    "simulated_model_2cov_array",
+    "simulated_model_2cov_formula",
 ]
-instances = sim_pln_2cov_instance + instances
+instances = sim_model_2cov_instance + instances
 
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "sim_pln_2cov_instance", sim_pln_2cov_instance
+    dict_fixtures, "sim_model_2cov_instance", sim_model_2cov_instance
 )
 
-sim_pln_2cov_fitted = [
-    "simulated_fitted_pln_2cov_array",
-    "simulated_fitted_pln_2cov_formula",
+sim_model_2cov_fitted = [
+    "simulated_fitted_model_2cov_array",
+    "simulated_fitted_model_2cov_formula",
 ]
 
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "sim_pln_2cov_fitted", sim_pln_2cov_fitted
+    dict_fixtures, "sim_model_2cov_fitted", sim_model_2cov_fitted
 )
 
-sim_pln_2cov_loaded = [
-    "simulated_loaded_pln_2cov_array",
-    "simulated_loaded_pln_2cov_formula",
+sim_model_2cov_loaded = [
+    "simulated_loaded_model_2cov_array",
+    "simulated_loaded_model_2cov_formula",
 ]
 
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "sim_pln_2cov_loaded", sim_pln_2cov_loaded
+    dict_fixtures, "sim_model_2cov_loaded", sim_model_2cov_loaded
 )
 
-sim_pln_2cov = sim_pln_2cov_instance + sim_pln_2cov_fitted + sim_pln_2cov_loaded
-dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_2cov", sim_pln_2cov)
+sim_model_2cov = sim_model_2cov_instance + sim_model_2cov_fitted + sim_model_2cov_loaded
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_model_2cov", sim_model_2cov
+)
 
 
 @pytest.fixture(params=params)
 @cache
-def real_pln_intercept_array(request):
+def real_model_intercept_array(request):
     cls = request.param
-    pln_full = cls(endog_real, add_const=True)
-    return pln_full
+    model = cls(endog_real, add_const=True)
+    return model
 
 
 @pytest.fixture
-def real_fitted_pln_intercept_array(real_pln_intercept_array):
-    real_pln_intercept_array.fit()
-    return real_pln_intercept_array
+def real_fitted_model_intercept_array(real_model_intercept_array):
+    real_model_intercept_array.fit()
+    return real_model_intercept_array
 
 
 @pytest.fixture(params=params)
 @cache
-def real_pln_intercept_formula(request):
+def real_model_intercept_formula(request):
     cls = request.param
-    pln_full = cls("endog ~ 1", data_real)
-    return pln_full
+    model = cls("endog ~ 1", data_real)
+    return model
 
 
 @pytest.fixture
-def real_fitted_pln_intercept_formula(real_pln_intercept_formula):
-    real_pln_intercept_formula.fit()
-    return real_pln_intercept_formula
+def real_fitted_model_intercept_formula(real_model_intercept_formula):
+    real_model_intercept_formula.fit()
+    return real_model_intercept_formula
 
 
 @pytest.fixture
-def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula):
-    real_fitted_pln_intercept_formula.save()
+def real_loaded_model_intercept_formula(real_fitted_model_intercept_formula):
+    real_fitted_model_intercept_formula.save()
     return generate_new_model(
-        real_fitted_pln_intercept_formula, "endog ~ 1", data=data_real
+        real_fitted_model_intercept_formula, "endog ~ 1", data=data_real
     )
 
 
 @pytest.fixture
-def real_loaded_pln_intercept_array(real_fitted_pln_intercept_array):
-    real_fitted_pln_intercept_array.save()
+def real_loaded_model_intercept_array(real_fitted_model_intercept_array):
+    real_fitted_model_intercept_array.save()
     return generate_new_model(
-        real_fitted_pln_intercept_array,
+        real_fitted_model_intercept_array,
         endog_real,
         add_const=True,
     )
 
 
-real_pln_instance = [
-    "real_pln_intercept_array",
-    "real_pln_intercept_formula",
+real_model_instance = [
+    "real_model_intercept_array",
+    "real_model_intercept_formula",
 ]
-instances = real_pln_instance + instances
+instances = real_model_instance + instances
 
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "real_pln_instance", real_pln_instance
+    dict_fixtures, "real_model_instance", real_model_instance
 )
 
-real_pln_fitted = [
-    "real_fitted_pln_intercept_array",
-    "real_fitted_pln_intercept_formula",
+real_model_fitted = [
+    "real_fitted_model_intercept_array",
+    "real_fitted_model_intercept_formula",
 ]
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "real_pln_fitted", real_pln_fitted
+    dict_fixtures, "real_model_fitted", real_model_fitted
 )
 
-real_pln_loaded = [
-    "real_loaded_pln_intercept_array",
-    "real_loaded_pln_intercept_formula",
+real_model_loaded = [
+    "real_loaded_model_intercept_array",
+    "real_loaded_model_intercept_formula",
 ]
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "real_pln_loaded", real_pln_loaded
+    dict_fixtures, "real_model_loaded", real_model_loaded
 )
 
-sim_loaded_pln = sim_pln_0cov_loaded + sim_pln_2cov_loaded
+sim_loaded_model = sim_model_0cov_loaded + sim_model_2cov_loaded
 
-loaded_pln = real_pln_loaded + sim_loaded_pln
-dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "loaded_pln", loaded_pln)
+loaded_model = real_model_loaded + sim_loaded_model
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "loaded_model", loaded_model)
 
-simulated_pln_fitted = sim_pln_0cov_fitted + sim_pln_2cov_fitted
+simulated_model_fitted = sim_model_0cov_fitted + sim_model_2cov_fitted
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "simulated_pln_fitted", simulated_pln_fitted
+    dict_fixtures, "simulated_model_fitted", simulated_model_fitted
 )
-fitted_pln = real_pln_fitted + simulated_pln_fitted
-dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_pln", fitted_pln)
+fitted_model = real_model_fitted + simulated_model_fitted
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_model", fitted_model)
 
 
-loaded_and_fitted_sim_pln = simulated_pln_fitted + sim_loaded_pln
-loaded_and_fitted_real_pln = real_pln_fitted + real_pln_loaded
+loaded_and_fitted_sim_model = simulated_model_fitted + sim_loaded_model
+loaded_and_fitted_real_model = real_model_fitted + real_model_loaded
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "loaded_and_fitted_real_pln", loaded_and_fitted_real_pln
+    dict_fixtures, "loaded_and_fitted_real_model", loaded_and_fitted_real_model
 )
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "loaded_and_fitted_sim_pln", loaded_and_fitted_sim_pln
+    dict_fixtures, "loaded_and_fitted_sim_model", loaded_and_fitted_sim_model
 )
-loaded_and_fitted_pln = fitted_pln + loaded_pln
+loaded_and_fitted_model = fitted_model + loaded_model
 dict_fixtures = add_list_of_fixture_to_dict(
-    dict_fixtures, "loaded_and_fitted_pln", loaded_and_fitted_pln
+    dict_fixtures, "loaded_and_fitted_model", loaded_and_fitted_model
 )
 
-real_pln = real_pln_instance + real_pln_fitted + real_pln_loaded
-dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_pln", real_pln)
+real_model = real_model_instance + real_model_fitted + real_model_loaded
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_model", real_model)
 
-sim_pln = sim_pln_2cov + sim_pln_0cov
-dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln", sim_pln)
+sim_model = sim_model_2cov + sim_model_0cov
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_model", sim_model)
 
-all_pln = real_pln + sim_pln + instances
+all_model = real_model + sim_model + instances
 dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "instances", instances)
-dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_pln", all_pln)
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_model", all_model)
 
 
-for string_fixture in all_pln:
+for string_fixture in all_model:
     print("string_fixture", string_fixture)
     dict_fixtures = add_fixture_to_dict(dict_fixtures, string_fixture)
diff --git a/tests/test_common.py b/tests/test_common.py
index b1a6837c..cec97a72 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -8,82 +8,85 @@ from tests.utils import MSE, filter_models
 
 from tests.import_data import true_sim_0cov, true_sim_2cov, endog_real
 
+single_models = ["Pln", "PlnPCA", "ZIPln"]
+pln_and_plnpca = ["Pln", "PlnPCA"]
 
-@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_properties(any_pln):
-    assert hasattr(any_pln, "latent_parameters")
-    assert hasattr(any_pln, "latent_variables")
-    assert hasattr(any_pln, "optim_parameters")
-    assert hasattr(any_pln, "model_parameters")
 
+@pytest.mark.parametrize("any_model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(single_models)
+def test_properties(any_model):
+    assert hasattr(any_model, "latent_parameters")
+    assert hasattr(any_model, "latent_variables")
+    assert hasattr(any_model, "optim_parameters")
+    assert hasattr(any_model, "model_parameters")
 
-@pytest.mark.parametrize("sim_pln", dict_fixtures["loaded_and_fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_predict_simulated(sim_pln):
-    if sim_pln.nb_cov == 0:
-        assert sim_pln.predict() is None
+
+@pytest.mark.parametrize("sim_model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(pln_and_plnpca)
+def test_predict_simulated(sim_model):
+    if sim_model.nb_cov == 0:
+        assert sim_model.predict() is None
         with pytest.raises(AttributeError):
-            sim_pln.predict(1)
+            sim_model.predict(1)
     else:
-        X = torch.randn((sim_pln.n_samples, sim_pln.nb_cov))
-        prediction = sim_pln.predict(X)
-        expected = X @ sim_pln.coef
+        X = torch.randn((sim_model.n_samples, sim_model.nb_cov))
+        prediction = sim_model.predict(X)
+        expected = X @ sim_model.coef
         assert torch.all(torch.eq(expected, prediction))
 
 
-@pytest.mark.parametrize("any_instance_pln", dict_fixtures["instances"])
-def test_verbose(any_instance_pln):
-    any_instance_pln.fit(verbose=True, tol=0.1)
+@pytest.mark.parametrize("any_instance_model", dict_fixtures["instances"])
+def test_verbose(any_instance_model):
+    any_instance_model.fit(verbose=True, tol=0.1)
 
 
 @pytest.mark.parametrize(
-    "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_sim_pln"]
+    "simulated_fitted_any_model", dict_fixtures["loaded_and_fitted_sim_model"]
 )
-@filter_models(["Pln", "PlnPCA"])
-def test_find_right_covariance(simulated_fitted_any_pln):
-    if simulated_fitted_any_pln.nb_cov == 0:
+@filter_models(pln_and_plnpca)
+def test_find_right_covariance(simulated_fitted_any_model):
+    if simulated_fitted_any_model.nb_cov == 0:
         true_covariance = true_sim_0cov["Sigma"]
-    elif simulated_fitted_any_pln.nb_cov == 2:
+    elif simulated_fitted_any_model.nb_cov == 2:
         true_covariance = true_sim_2cov["Sigma"]
     else:
         raise ValueError(
-            f"Not the right numbers of covariance({simulated_fitted_any_pln.nb_cov})"
+            f"Not the right numbers of covariance({simulated_fitted_any_model.nb_cov})"
         )
-    mse_covariance = MSE(simulated_fitted_any_pln.covariance - true_covariance)
+    mse_covariance = MSE(simulated_fitted_any_model.covariance - true_covariance)
     assert mse_covariance < 0.05
 
 
 @pytest.mark.parametrize(
-    "real_fitted_and_loaded_pln", dict_fixtures["loaded_and_fitted_real_pln"]
+    "real_fitted_and_loaded_model", dict_fixtures["loaded_and_fitted_real_model"]
 )
-@filter_models(["Pln", "PlnPCA"])
-def test_right_covariance_shape(real_fitted_and_loaded_pln):
-    assert real_fitted_and_loaded_pln.covariance.shape == (
+@filter_models(single_models)
+def test_right_covariance_shape(real_fitted_and_loaded_model):
+    assert real_fitted_and_loaded_model.covariance.shape == (
         endog_real.shape[1],
         endog_real.shape[1],
     )
 
 
 @pytest.mark.parametrize(
-    "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_pln"]
+    "simulated_fitted_any_model", dict_fixtures["loaded_and_fitted_model"]
 )
-@filter_models(["Pln", "PlnPCA"])
-def test_find_right_coef(simulated_fitted_any_pln):
-    if simulated_fitted_any_pln.nb_cov == 2:
+@filter_models(pln_and_plnpca)
+def test_find_right_coef(simulated_fitted_any_model):
+    if simulated_fitted_any_model.nb_cov == 2:
         true_coef = true_sim_2cov["beta"]
-        mse_coef = MSE(simulated_fitted_any_pln.coef - true_coef)
+        mse_coef = MSE(simulated_fitted_any_model.coef - true_coef)
         assert mse_coef < 0.1
-    elif simulated_fitted_any_pln.nb_cov == 0:
-        assert simulated_fitted_any_pln.coef is None
+    elif simulated_fitted_any_model.nb_cov == 0:
+        assert simulated_fitted_any_model.coef is None
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_fail_count_setter(pln):
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(single_models)
+def test_fail_count_setter(model):
     wrong_endog = torch.randint(size=(10, 5), low=0, high=10)
     with pytest.raises(Exception):
-        pln.endog = wrong_endog
+        model.endog = wrong_endog
 
 
 @pytest.mark.parametrize("instance", dict_fixtures["instances"])
@@ -96,9 +99,9 @@ def test__print_end_of_fitting_message(instance):
     instance.fit(nb_max_iteration=4)
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_fail_wrong_exog_prediction(pln):
-    X = torch.randn(pln.n_samples, pln.nb_cov + 1)
+@pytest.mark.parametrize("model", dict_fixtures["fitted_model"])
+@filter_models(single_models)
+def test_fail_wrong_exog_prediction(model):
+    X = torch.randn(model.n_samples, model.nb_cov + 1)
     with pytest.raises(Exception):
-        pln.predict(X)
+        model.predict(X)
diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py
index 870114a0..1115e1ec 100644
--- a/tests/test_pln_full.py
+++ b/tests/test_pln_full.py
@@ -4,14 +4,14 @@ from tests.conftest import dict_fixtures
 from tests.utils import filter_models
 
 
-@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"])
+@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_model"])
 @filter_models(["Pln"])
 def test_number_of_iterations_pln_full(fitted_pln):
-    nb_iterations = len(fitted_pln._elbos_list)
+    nb_iterations = len(fitted_pln.elbos_list)
     assert 20 < nb_iterations < 1000
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
+@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["Pln"])
 def test_latent_var_full(pln):
     assert pln.transform().shape == pln.endog.shape
-- 
GitLab
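
The renamed fixtures above are all collected in `dict_fixtures` through two small helpers. The sketch below shows how they are presumably implemented on top of `pytest_lazyfixture`: `add_fixture_to_dict` matches the definition visible later in `tests/conftest.py`, while `add_list_of_fixture_to_dict` is an assumption inferred from how it is called.

    from pytest_lazyfixture import lazy_fixture as lf

    def add_fixture_to_dict(my_dict, string_fixture):
        # a single fixture name becomes a one-element list of lazy references
        my_dict[string_fixture] = [lf(string_fixture)]
        return my_dict

    def add_list_of_fixture_to_dict(my_dict, name_of_list, list_of_fixtures):
        # a named group of fixture names, usable directly in pytest.mark.parametrize
        my_dict[name_of_list] = [lf(fixture) for fixture in list_of_fixtures]
        return my_dict

With such a registry, `@pytest.mark.parametrize("any_model", dict_fixtures["loaded_and_fitted_model"])` resolves each lazy reference to the corresponding instantiated model at test setup.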


From 00564b40079ebea8b80f3171779b76bdea0f64fb Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 16 Oct 2023 15:10:08 +0200
Subject: [PATCH 39/68] fix the ELBO; the previous computation was wrong.

---
 pyPLNmodels/elbos.py | 203 ++++++++++++++++++++++++-------------------
 1 file changed, 113 insertions(+), 90 deletions(-)

diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index ec743430..cf235ec4 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -5,63 +5,6 @@ from ._closed_forms import _closed_formula_covariance, _closed_formula_coef
 from typing import Optional
 
 
-def elbo_pln(
-    endog: torch.Tensor,
-    offsets: torch.Tensor,
-    exog: Optional[torch.Tensor],
-    latent_mean: torch.Tensor,
-    latent_sqrt_var: torch.Tensor,
-    covariance: torch.Tensor,
-    coef: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute the ELBO (Evidence Lower Bound) for the Pln model.
-
-    Parameters:
-    ----------
-    endog : torch.Tensor
-        Counts with size (n, p).
-    offsets : torch.Tensor
-        Offset with size (n, p).
-    exog : torch.Tensor, optional
-        Covariates with size (n, d).
-    latent_mean : torch.Tensor
-        Variational parameter with size (n, p).
-    latent_sqrt_var : torch.Tensor
-        Variational parameter with size (n, p).
-    covariance : torch.Tensor
-        Model parameter with size (p, p).
-    coef : torch.Tensor
-        Model parameter with size (d, p).
-
-    Returns:
-    -------
-    torch.Tensor
-        The ELBO (Evidence Lower Bound), of size one.
-    """
-    n_samples, dim = endog.shape
-    s_rond_s = torch.square(latent_sqrt_var)
-    offsets_plus_m = offsets + latent_mean
-    if exog is None:
-        XB = torch.zeros_like(endog)
-    else:
-        XB = exog @ coef
-    m_minus_xb = latent_mean - XB
-    d_plus_minus_xb2 = (
-        torch.diag(torch.sum(s_rond_s, dim=0)) + m_minus_xb.T @ m_minus_xb
-    )
-    elbo = -0.5 * n_samples * torch.logdet(covariance)
-    elbo += torch.sum(
-        endog * offsets_plus_m
-        - 0.5 * torch.exp(offsets_plus_m + s_rond_s)
-        + 0.5 * torch.log(s_rond_s)
-    )
-    elbo -= 0.5 * torch.trace(torch.inverse(covariance) @ d_plus_minus_xb2)
-    elbo -= torch.sum(_log_stirling(endog))
-    elbo += 0.5 * n_samples * dim
-    return elbo / n_samples
-
-
 def profiled_elbo_pln(
     endog: torch.Tensor,
     exog: torch.Tensor,
@@ -172,6 +115,78 @@ def elbo_plnpca(
     ) / n_samples
 
 
+def log1pexp(x):
+    # more stable version of log(1 + exp(x))
+    return torch.where(x < 50, torch.log1p(torch.exp(x)), x)
+
+
+def elbo_pln(
+    endog: torch.Tensor,
+    exog: Optional[torch.Tensor],
+    offsets: torch.Tensor,
+    latent_mean: torch.Tensor,
+    latent_sqrt_var: torch.Tensor,
+    covariance: torch.Tensor,
+    coef: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute the ELBO (Evidence Lower Bound) for the Pln model.
+
+    Parameters:
+    ----------
+    endog : torch.Tensor
+        Counts with size (n, p).
+    offsets : torch.Tensor
+        Offset with size (n, p).
+    exog : torch.Tensor, optional
+        Covariates with size (n, d).
+    latent_mean : torch.Tensor
+        Variational parameter with size (n, p).
+    latent_sqrt_var : torch.Tensor
+        Variational parameter with size (n, p).
+    covariance : torch.Tensor
+        Model parameter with size (p, p).
+    coef : torch.Tensor
+        Model parameter with size (d, p).
+
+    Returns:
+    -------
+    torch.Tensor
+        The ELBO (Evidence Lower Bound), of size one.
+    """
+    n_samples, dim = endog.shape
+    s_rond_s = torch.square(latent_sqrt_var)
+    offsets_plus_m = offsets + latent_mean
+    Omega = torch.inverse(covariance)
+    if exog is None:
+        XB = torch.zeros_like(endog)
+    else:
+        XB = exog @ coef
+    # print('XB:', XB)
+    m_minus_xb = latent_mean - XB
+    m_moins_xb_outer = torch.mm(m_minus_xb.T, m_minus_xb)
+    A = torch.exp(offsets_plus_m + s_rond_s / 2)
+    first_a = torch.sum(endog * offsets_plus_m)
+    sec_a = -torch.sum(A)
+    third_a = -torch.sum(_log_stirling(endog))
+    a = first_a + sec_a + third_a
+    diag = torch.diag(torch.sum(s_rond_s, dim=0))
+    elbo = torch.clone(a)
+    b = -0.5 * n_samples * torch.logdet(covariance) + torch.sum(
+        -1 / 2 * Omega * m_moins_xb_outer
+    )
+    elbo += b
+    d = n_samples * dim / 2 + torch.sum(+0.5 * torch.log(s_rond_s))
+    elbo += d
+    f = -0.5 * torch.trace(torch.inverse(covariance) @ diag)
+    elbo += f
+    # print("a pln", a)
+    # print("b pln", b)
+    # print("d pln", d)
+    # print("f pln", f)
+    return elbo  # / n_samples
+
+
 ## pb with trunc_log
 ## should rename some variables so that is is clearer when we see the formula
 def elbo_zi_pln(
@@ -194,7 +209,7 @@ def elbo_zi_pln(
         0: torch.tensor. Offset, size (n,p)
         exog: torch.tensor. Covariates, size (n,d)
         latent_mean: torch.tensor. Variational parameter with size (n,p)
-        latent_sqrt_var: torch.tensor. Variational parameter with size (n,p)
+        latent_var: torch.tensor. Variational parameter with size (n,p)
         pi: torch.tensor. Variational parameter with size (n,p)
         covariance: torch.tensor. Model parameter with size (p,p)
         coef: torch.tensor. Model parameter with size (d,p)
@@ -202,52 +217,60 @@ def elbo_zi_pln(
     Returns:
         torch.tensor of size 1 with a gradient.
     """
-    if torch.norm(latent_prob * dirac - latent_prob) > 0.00000001:
-        raise RuntimeError("Latent probability is not zero when it should be.")
     covariance = components @ (components.T)
-    diag_cov = torch.diag(covariance)
-    Omega = torch.inverse(covariance)
-    diag_omega = torch.diag(Omega)
-    un_moins_prob = 1 - latent_prob
+    if torch.norm(latent_prob * dirac - latent_prob) > 0.00000001:
+        print("Bug")
+        raise RuntimeError("rho error")
     n_samples, dim = endog.shape
-    s_rond_s = latent_sqrt_var * latent_sqrt_var
+    s_rond_s = torch.multiply(latent_sqrt_var, latent_sqrt_var)
     o_plus_m = offsets + latent_mean
     if exog is None:
         XB = torch.zeros_like(endog)
-        xcoef_inflation = torch.zeros_like(endog)
+        x_coef_inflation = torch.zeros_like(endog)
     else:
         XB = exog @ coef
-        xcoef_inflation = exog @ coef_inflation
+        x_coef_inflation = exog @ coef_inflation
+
     m_minus_xb = latent_mean - XB
 
     A = torch.exp(o_plus_m + s_rond_s / 2)
-    inside_a = un_moins_prob * (endog * o_plus_m - A - _log_stirling(endog))
+    inside_a = torch.multiply(
+        1 - latent_prob, torch.multiply(endog, o_plus_m) - A - _log_stirling(endog)
+    )
+    Omega = torch.inverse(covariance)
+
     m_moins_xb_outer = torch.mm(m_minus_xb.T, m_minus_xb)
-    un_moins_prob_m_moins_xb = un_moins_prob * m_minus_xb
-    un_moins_prob_m_moins_xb_outer = (
-        un_moins_prob_m_moins_xb.T @ un_moins_prob_m_moins_xb
+    un_moins_rho = 1 - latent_prob
+    un_moins_rho_m_moins_xb = un_moins_rho * m_minus_xb
+    un_moins_rho_m_moins_xb_outer = un_moins_rho_m_moins_xb.T @ un_moins_rho_m_moins_xb
+    inside_b = -1 / 2 * Omega * un_moins_rho_m_moins_xb_outer
+
+    inside_c = torch.multiply(latent_prob, x_coef_inflation) - torch.log(
+        1 + torch.exp(x_coef_inflation)
     )
-    inside_b = -1 / 2 * Omega * un_moins_prob_m_moins_xb_outer
 
-    inside_c = latent_prob * xcoef_inflation - torch.log(1 + torch.exp(xcoef_inflation))
-    log_diag = torch.log(diag_cov)
+    log_diag = torch.log(torch.diag(covariance))
     log_S_term = torch.sum(
-        un_moins_prob * torch.log(torch.abs(latent_sqrt_var)), axis=0
+        torch.multiply(1 - latent_prob, torch.log(torch.abs(latent_sqrt_var))), axis=0
     )
-    sum_prob = torch.sum(latent_prob, axis=0)
-    covariance_term = 1 / 2 * torch.log(diag_cov) * sum_prob
+    y = torch.sum(latent_prob, axis=0)
+    covariance_term = 1 / 2 * torch.log(torch.diag(covariance)) * y
     inside_d = covariance_term + log_S_term
 
-    inside_e = torch.multiply(
-        latent_prob, _trunc_log(latent_prob)
-    ) + un_moins_prob * _trunc_log(un_moins_prob)
-    sum_un_moins_prob_s2 = torch.sum(un_moins_prob * s_rond_s, axis=0)
-    diag_sig_sum_prob = diag_cov * torch.sum(latent_prob, axis=0)
-    new = torch.sum(latent_prob * un_moins_prob * (m_minus_xb**2), axis=0)
-    K = sum_un_moins_prob_s2 + diag_sig_sum_prob + new
-    inside_f = -1 / 2 * diag_omega * K
-    full_diag_omega = diag_omega.expand(exog.shape[0], -1)
-    elbo = torch.sum(inside_a + inside_c + inside_d)
-    elbo += torch.sum(inside_b) - n_samples / 2 * torch.logdet(covariance)
-    elbo += n_samples * dim / 2 + torch.sum(inside_d + inside_f)
-    return elbo
+    inside_e = -torch.multiply(latent_prob, _trunc_log(latent_prob)) - torch.multiply(
+        1 - latent_prob, _trunc_log(1 - latent_prob)
+    )
+    sum_un_moins_rho_s2 = torch.sum(torch.multiply(1 - latent_prob, s_rond_s), axis=0)
+    diag_sig_sum_rho = torch.multiply(
+        torch.diag(covariance), torch.sum(latent_prob, axis=0)
+    )
+    new = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0)
+    K = sum_un_moins_rho_s2 + diag_sig_sum_rho + new
+    inside_f =-1 / 2 *  torch.diag(Omega) * K
+    first = torch.sum(inside_a + inside_c + inside_e)
+    second = torch.sum(inside_b)
+    second -= n_samples / 2 * torch.logdet(covariance)
+    third = torch.sum(inside_d + inside_f)
+    third += n_samples*dim/2
+    res = first + second + third
+    return res
-- 
GitLab
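
The rewritten `elbo_pln` above rests on the log-normal moment: if Z ~ N(m, s²), then E[exp(Z)] = exp(m + s²/2), which is exactly the `A = torch.exp(offsets_plus_m + s_rond_s / 2)` term. A quick Monte Carlo check of that identity (illustrative only, independent of the package):

    import torch

    torch.manual_seed(0)
    m, s = torch.tensor(0.3), torch.tensor(0.7)
    z = m + s * torch.randn(1_000_000)
    print(torch.exp(z).mean())        # close to exp(0.3 + 0.7**2 / 2) ~= 1.72
    print(torch.exp(m + s ** 2 / 2))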


From 2c0b48e005ceec9f58b10c3c360ea66684b33e74 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 16 Oct 2023 17:06:43 +0200
Subject: [PATCH 40/68] investigate why the ZI model fails on real data: the
 components have zero determinant.

---
 pyPLNmodels/_initialization.py |  4 ++--
 pyPLNmodels/elbos.py           | 14 +++++++++++---
 pyPLNmodels/models.py          | 13 ++++++++-----
 3 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/pyPLNmodels/_initialization.py b/pyPLNmodels/_initialization.py
index f5663746..ac4472a6 100644
--- a/pyPLNmodels/_initialization.py
+++ b/pyPLNmodels/_initialization.py
@@ -44,7 +44,7 @@ def _init_covariance(endog: torch.Tensor, exog: torch.Tensor) -> torch.Tensor:
 
 
 def _init_components(
-    endog: torch.Tensor, exog: torch.Tensor, rank: int
+    endog: torch.Tensor, rank: int
 ) -> torch.Tensor:
     """
     Initialization for components for the Pln model. Get a first guess for covariance
@@ -65,7 +65,7 @@ def _init_components(
     log_y = torch.log(endog + (endog == 0) * math.exp(-2))
     pca = PCA(n_components=rank)
     pca.fit(log_y.detach().cpu())
-    pca_comp = pca.components_.T * np.sqrt(pca.explained_variance_)
+    pca_comp = pca.components_.T * np.sqrt(pca.explained_variance_ + 0.001)
     return torch.from_numpy(pca_comp).to(DEVICE)
 
 
diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index cf235ec4..3862de9f 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -238,17 +238,16 @@ def elbo_zi_pln(
         1 - latent_prob, torch.multiply(endog, o_plus_m) - A - _log_stirling(endog)
     )
     Omega = torch.inverse(covariance)
-
     m_moins_xb_outer = torch.mm(m_minus_xb.T, m_minus_xb)
     un_moins_rho = 1 - latent_prob
     un_moins_rho_m_moins_xb = un_moins_rho * m_minus_xb
     un_moins_rho_m_moins_xb_outer = un_moins_rho_m_moins_xb.T @ un_moins_rho_m_moins_xb
     inside_b = -1 / 2 * Omega * un_moins_rho_m_moins_xb_outer
-
     inside_c = torch.multiply(latent_prob, x_coef_inflation) - torch.log(
         1 + torch.exp(x_coef_inflation)
     )
 
+
     log_diag = torch.log(torch.diag(covariance))
     log_S_term = torch.sum(
         torch.multiply(1 - latent_prob, torch.log(torch.abs(latent_sqrt_var))), axis=0
@@ -267,10 +266,19 @@ def elbo_zi_pln(
     new = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0)
     K = sum_un_moins_rho_s2 + diag_sig_sum_rho + new
     inside_f =-1 / 2 *  torch.diag(Omega) * K
+    print("inside_a",torch.sum(inside_a))
+    print("inside_b",torch.sum(inside_b))
+    print("inside_c",torch.sum(inside_c))
+    print("inside_d",torch.sum(inside_d))
+    print("inside_e",torch.sum(inside_e))
+    print("inside_f",torch.sum(inside_f))
     first = torch.sum(inside_a + inside_c + inside_e)
+    print('first', first)
     second = torch.sum(inside_b)
-    second -= n_samples / 2 * torch.logdet(covariance)
+    second -= n_samples * torch.logdet(components)
+    print('logdet', torch.logdet(components))
     third = torch.sum(inside_d + inside_f)
     third += n_samples*dim/2
+    print('third', third)
     res = first + second + third
     return res
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index e9bcfd01..49298b4f 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -397,7 +397,6 @@ class _model(ABC):
         while self.nb_iteration_done < nb_max_iteration and not stop_condition:
             loss = self._trainstep()
             criterion = self._update_criterion_args(loss)
-            print("criterion", criterion)
             if abs(criterion) < tol:
                 stop_condition = True
             if verbose and self.nb_iteration_done % 50 == 1:
@@ -477,6 +476,8 @@ class _model(ABC):
             self._extract_batch(batch)
             self.optim.zero_grad()
             loss = -self._compute_elbo_b()
+            if torch.sum(torch.isnan(loss)):
+                raise ValueError("test")
             loss.backward()
             elbo += loss.item()
             self._update_parameters()
@@ -3334,7 +3335,7 @@ class PlnPCA(_model):
         if not hasattr(self, "_coef"):
             super()._smart_init_coef()
         if not hasattr(self, "_components"):
-            self._components = _init_components(self._endog, self._exog, self._rank)
+            self._components = _init_components(self._endog, self._rank)
 
     @_add_doc(_model)
     def _random_init_latent_parameters(self):
@@ -3490,7 +3491,7 @@ class ZIPln(_model):
         use_closed_form_prob: bool = True,
     ):
         """
-        Create a model instance from a formula and data.
+        Create a ZIPln instance from a formula and data.
 
         Parameters
         ----------
@@ -3585,7 +3586,6 @@ class ZIPln(_model):
         return "with full covariance model and zero-inflation."
 
     def _random_init_model_parameters(self):
-        super()._random_init_model_parameters()
         self._coef_inflation = torch.randn(self.nb_cov, self.dim)
         self._coef = torch.randn(self.nb_cov, self.dim)
         self._components = torch.randn(self.nb_cov, self.dim)
@@ -3595,7 +3595,10 @@ class ZIPln(_model):
         # init of _coef.
         super()._smart_init_coef()
         if not hasattr(self, "_covariance"):
-            self._components = _init_components(self._endog, self._exog, self.dim)
+            self._components = _init_components(self._endog, self.dim)
+            print('sum components', torch.sum(self._components))
+            print('sum endog', torch.sum(self._endog))
+            print('log det ', torch.logdet(self._components))
         if not hasattr(self, "_coef_inflation"):
             self._coef_inflation = torch.randn(self.nb_cov, self.dim)
 
-- 
GitLab
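
The PCA-based initializer touched above scales the principal axes by the square roots of the explained variances, so that `components @ components.T` reproduces the sample covariance of the log counts (exactly at full rank, truncated otherwise). A small sanity check of that construction, on synthetic data rather than the package's own initializer:

    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.default_rng(0)
    log_y = rng.normal(size=(200, 10))
    pca = PCA(n_components=10).fit(log_y)
    C = pca.components_.T * np.sqrt(pca.explained_variance_)
    print(np.allclose(C @ C.T, np.cov(log_y, rowvar=False)))  # True at full rank

When the smallest explained variances are numerically zero, det(C) vanishes, which is consistent with the zero-determinant components this commit is chasing.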


From 6cfb27cee2e221730d9bfbf116050bc8f84205cb Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 16 Oct 2023 19:46:36 +0200
Subject: [PATCH 41/68] fixed the ELBO issue

---
 pyPLNmodels/_initialization.py |  2 +-
 pyPLNmodels/_utils.py          | 13 +++++++++++++
 pyPLNmodels/elbos.py           | 12 ++----------
 pyPLNmodels/models.py          |  6 +++---
 tests/conftest.py              |  1 -
 5 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/pyPLNmodels/_initialization.py b/pyPLNmodels/_initialization.py
index ac4472a6..410283bf 100644
--- a/pyPLNmodels/_initialization.py
+++ b/pyPLNmodels/_initialization.py
@@ -65,7 +65,7 @@ def _init_components(
     log_y = torch.log(endog + (endog == 0) * math.exp(-2))
     pca = PCA(n_components=rank)
     pca.fit(log_y.detach().cpu())
-    pca_comp = pca.components_.T * np.sqrt(pca.explained_variance_ + 0.001)
+    pca_comp = pca.components_.T * np.sqrt(pca.explained_variance_)
     return torch.from_numpy(pca_comp).to(DEVICE)
 
 
diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index f5f02942..802e5f2e 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -1080,3 +1080,16 @@ def d_h_x3(a, x, y, dirac):
     rho = torch.sigmoid(a - torch.log(phi(x, y))) * dirac
     rho_prime = rho * (1 - rho)
     return -rho_prime * d_varpsi_x2(x, y) / phi(x, y)
+
+def vec_to_mat(C, p, q):
+    c = torch.zeros(p, q)
+    c[torch.tril_indices(p, q, offset=0).tolist()] = C
+    # c = C.reshape(p,q)
+    return c
+
+
+def mat_to_vec(matc, p, q):
+    tril = torch.tril(matc)
+    # tril = matc.reshape(-1,1).squeeze()
+    return tril[torch.tril_indices(p, q, offset=0).tolist()]
+
diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index 3862de9f..585b423d 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -266,19 +266,11 @@ def elbo_zi_pln(
     new = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0)
     K = sum_un_moins_rho_s2 + diag_sig_sum_rho + new
     inside_f =-1 / 2 *  torch.diag(Omega) * K
-    print("inside_a",torch.sum(inside_a))
-    print("inside_b",torch.sum(inside_b))
-    print("inside_c",torch.sum(inside_c))
-    print("inside_d",torch.sum(inside_d))
-    print("inside_e",torch.sum(inside_e))
-    print("inside_f",torch.sum(inside_f))
     first = torch.sum(inside_a + inside_c + inside_e)
-    print('first', first)
     second = torch.sum(inside_b)
-    second -= n_samples * torch.logdet(components)
-    print('logdet', torch.logdet(components))
+    _, logdet = torch.slogdet(components)
+    second -= n_samples *logdet
     third = torch.sum(inside_d + inside_f)
     third += n_samples*dim/2
-    print('third', third)
     res = first + second + third
     return res
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 49298b4f..db723dba 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -32,6 +32,8 @@ from ._utils import (
     _array2tensor,
     _handle_data,
     _add_doc,
+    vec_to_mat,
+    mat_to_vec,
 )
 
 from ._initialization import (
@@ -3596,9 +3598,7 @@ class ZIPln(_model):
         super()._smart_init_coef()
         if not hasattr(self, "_covariance"):
             self._components = _init_components(self._endog, self.dim)
-            print('sum components', torch.sum(self._components))
-            print('sum endog', torch.sum(self._endog))
-            print('log det ', torch.logdet(self._components))
+
         if not hasattr(self, "_coef_inflation"):
             self._coef_inflation = torch.randn(self.nb_cov, self.dim)
 
diff --git a/tests/conftest.py b/tests/conftest.py
index e40558dd..85ca2aaf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -37,7 +37,6 @@ endog_real = data_real["endog"]
 endog_real = pd.DataFrame(endog_real)
 endog_real.columns = [f"var_{i}" for i in range(endog_real.shape[1])]
 
-
 def add_fixture_to_dict(my_dict, string_fixture):
     my_dict[string_fixture] = [lf(string_fixture)]
     return my_dict
-- 
GitLab
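
The `slogdet` change above relies on the identity that, for a square `components` matrix, logdet(components @ components.T) = 2·log|det(components)|, so the (n/2)·logdet(covariance) penalty equals n·log|det(components)|. `torch.slogdet` returns (sign, log|det|) and stays well defined when det(components) is negative, where `torch.logdet(components)` would return nan. A quick check:

    import torch

    torch.manual_seed(0)
    C = torch.randn(5, 5, dtype=torch.float64)
    sign, logabsdet = torch.slogdet(C)
    print(torch.logdet(C @ C.T))  # equals 2 * logabsdet up to rounding
    print(2 * logabsdet)
    print(sign)                   # if -1, torch.logdet(C) itself would be nan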


From cda1ef413b9b26746b843f882b1f4a82cc571184 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 17 Oct 2023 11:07:58 +0200
Subject: [PATCH 42/68] add tests for the ZI model and implement the features
 needed to make them pass

---
 pyPLNmodels/_utils.py |   2 +-
 pyPLNmodels/elbos.py  |   7 +-
 pyPLNmodels/models.py | 144 +++++++++++++++++++++++++++++++++++++++---
 tests/conftest.py     |   6 ++
 tests/test_common.py  |   5 +-
 tests/test_zi.py      |  73 +++++++++++++++++++++
 6 files changed, 222 insertions(+), 15 deletions(-)
 create mode 100644 tests/test_zi.py

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 802e5f2e..35b02806 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -1081,6 +1081,7 @@ def d_h_x3(a, x, y, dirac):
     rho_prime = rho * (1 - rho)
     return -rho_prime * d_varpsi_x2(x, y) / phi(x, y)
 
+
 def vec_to_mat(C, p, q):
     c = torch.zeros(p, q)
     c[torch.tril_indices(p, q, offset=0).tolist()] = C
@@ -1092,4 +1093,3 @@ def mat_to_vec(matc, p, q):
     tril = torch.tril(matc)
     # tril = matc.reshape(-1,1).squeeze()
     return tril[torch.tril_indices(p, q, offset=0).tolist()]
-
diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index 585b423d..5a56bc3d 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -247,7 +247,6 @@ def elbo_zi_pln(
         1 + torch.exp(x_coef_inflation)
     )
 
-
     log_diag = torch.log(torch.diag(covariance))
     log_S_term = torch.sum(
         torch.multiply(1 - latent_prob, torch.log(torch.abs(latent_sqrt_var))), axis=0
@@ -265,12 +264,12 @@ def elbo_zi_pln(
     )
     new = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0)
     K = sum_un_moins_rho_s2 + diag_sig_sum_rho + new
-    inside_f =-1 / 2 *  torch.diag(Omega) * K
+    inside_f = -1 / 2 * torch.diag(Omega) * K
     first = torch.sum(inside_a + inside_c + inside_e)
     second = torch.sum(inside_b)
     _, logdet = torch.slogdet(components)
-    second -= n_samples *logdet
+    second -= n_samples * logdet
     third = torch.sum(inside_d + inside_f)
-    third += n_samples*dim/2
+    third += n_samples * dim / 2
     res = first + second + third
     return res
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index db723dba..81d732b2 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -668,7 +668,6 @@ class _model(ABC):
         )
         plt.show()
 
-
     @property
     def _latent_var(self) -> torch.Tensor:
         """
@@ -1915,8 +1914,8 @@ class Pln(_model):
         return self.latent_mean.detach()
 
     @_add_doc(
-            _model,
-            example="""
+        _model,
+        example="""
             >>> from pyPLNmodels import Pln, get_real_count_data
             >>> endog, labels = get_real_count_data(return_labels = True)
             >>> pln = Pln(endog,add_const = True)
@@ -1924,8 +1923,8 @@ class Pln(_model):
             >>> elbo = pln.compute_elbo()
             >>> print("elbo", elbo)
             >>> print("loglike/n", pln.loglike/pln.n_samples)
-            """
-            )
+            """,
+    )
     def compute_elbo(self):
         return profiled_elbo_pln(
             self._endog,
@@ -1934,6 +1933,7 @@ class Pln(_model):
             self._latent_mean,
             self._latent_sqrt_var,
         )
+
     @_add_doc(_model)
     def _compute_elbo_b(self):
         return profiled_elbo_pln(
@@ -1943,6 +1943,7 @@ class Pln(_model):
             self._latent_mean_b,
             self._latent_sqrt_var_b,
         )
+
     @_add_doc(_model)
     def _smart_init_model_parameters(self):
         pass
@@ -1952,6 +1953,7 @@ class Pln(_model):
     def _random_init_model_parameters(self):
         pass
         # no model parameters since we are doing a profiled ELBO
+
     @_add_doc(_model)
     def _smart_init_latent_parameters(self):
         self._random_init_latent_sqrt_var()
@@ -1969,6 +1971,7 @@ class Pln(_model):
     def _list_of_parameters_needing_gradient(self):
         return [self._latent_mean, self._latent_sqrt_var]
 
+
 class PlnPCAcollection:
     """
     A collection where value q corresponds to a PlnPCA object with rank q.
@@ -3451,6 +3454,7 @@ class ZIPln(_model):
         >>> zi.fit()
         >>> print(zi)
         """
+        self._use_closed_form_prob = use_closed_form_prob
         if exog is None and add_const is False:
             msg = "No covariates has been given. An intercept is added since "
             msg += "a ZIPln must have at least an intercept."
@@ -3465,7 +3469,6 @@ class ZIPln(_model):
             take_log_offsets=take_log_offsets,
             add_const=add_const,
         )
-        self._use_closed_form_prob = use_closed_form_prob
 
     def _extract_batch(self, batch):
         super()._extract_batch(batch)
@@ -3490,7 +3493,7 @@ class ZIPln(_model):
         offsets_formula: str = "logsum",
         dict_initialization: Optional[Dict[str, torch.Tensor]] = None,
         take_log_offsets: bool = False,
-        use_closed_form_prob: bool = True,
+        use_closed_form_prob: bool = False,
     ):
         """
         Create a ZIPln instance from a formula and data.
@@ -3590,7 +3593,7 @@ class ZIPln(_model):
     def _random_init_model_parameters(self):
         self._coef_inflation = torch.randn(self.nb_cov, self.dim)
         self._coef = torch.randn(self.nb_cov, self.dim)
-        self._components = torch.randn(self.nb_cov, self.dim)
+        self._components = torch.randn(self.dim, self.dim)
 
     # should change the good initialization for _coef_inflation
     def _smart_init_model_parameters(self):
@@ -3656,7 +3659,7 @@ class ZIPln(_model):
     @property
     def coef_inflation(self):
         """
-        Property representing the coefficients of the zero inflated model.
+        Property representing the coefficients of the inflation.
 
         Returns
         -------
@@ -3665,6 +3668,54 @@ class ZIPln(_model):
         """
         return self._cpu_attribute_or_none("_coef_inflation")
 
+    @coef_inflation.setter
+    @_array2tensor
+    def coef_inflation(
+        self, coef_inflation: Union[torch.Tensor, np.ndarray, pd.DataFrame]
+    ):
+        """
+        Setter for the coef_inflation property.
+
+        Parameters
+        ----------
+        coef : Union[torch.Tensor, np.ndarray, pd.DataFrame]
+            The coefficients.
+
+        Raises
+        ------
+        ValueError
+            If the shape of the coef is incorrect.
+        """
+        if coef_inflation.shape != (self.nb_cov, self.dim):
+            raise ValueError(
+                f"Wrong shape for the coef. Expected {(self.nb_cov, self.dim)}, got {coef_inflation.shape}"
+            )
+        self._coef_inflation = coef_inflation
+
+    @_model.latent_sqrt_var.setter
+    @_array2tensor
+    def latent_sqrt_var(
+        self, latent_sqrt_var: Union[torch.Tensor, np.ndarray, pd.DataFrame]
+    ):
+        """
+        Setter for the latent variance property.
+
+        Parameters
+        ----------
+        latent_sqrt_var : Union[torch.Tensor, np.ndarray, pd.DataFrame]
+            The latent square root of the variance.
+
+        Raises
+        ------
+        ValueError
+            If the shape of the latent variance is incorrect.
+        """
+        if latent_sqrt_var.shape != (self.n_samples, self.dim):
+            raise ValueError(
+                f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_sqrt_var.shape}"
+            )
+        self._latent_sqrt_var = latent_sqrt_var
+
     def _update_parameters(self):
         super()._update_parameters()
         self._project_latent_prob()
@@ -3695,10 +3746,51 @@ class ZIPln(_model):
         """
         return self._cpu_attribute_or_none("_covariance")
 
+    @components.setter
+    @_array2tensor
+    def components(self, components: torch.Tensor):
+        """
+        Setter for the components.
+
+        Parameters
+        ----------
+        components : torch.Tensor
+            The components to set.
+
+        Raises
+        ------
+        ValueError
+            If the components have an invalid shape.
+        """
+        if components.shape != (self.dim, self.dim):
+            raise ValueError(
+                f"Wrong shape. Expected {self.dim, self.dim}, got {components.shape}"
+            )
+        self._components = components
+
     @property
     def latent_prob(self):
         return self._cpu_attribute_or_none("_latent_prob")
 
+    @latent_prob.setter
+    @_array2tensor
+    def latent_prob(self, latent_prob: Union[torch.Tensor, np.ndarray, pd.DataFrame]):
+        if self._use_closed_form_prob is True:
+            raise ValueError(
+                "Can not set the latent prob when the closed form is used."
+            )
+        if latent_prob.shape != (self.n_samples, self.dim):
+            raise ValueError(
+                f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_prob.shape}"
+            )
+        if torch.max(latent_prob) > 1 or torch.min(latent_prob) < 0:
+            raise ValueError(f"Wrong value. All values should be between 0 and 1.")
+        if torch.norm(latent_prob * (self._endog == 0) - latent_prob) > 0.00000001:
+            raise ValueError(
+                "You can not assign non zeros inflation probabilities to non zero counts."
+            )
+        self._latent_prob = latent_prob
+
     @property
     def closed_formula_latent_prob(self):
         """
@@ -3781,6 +3873,40 @@ class ZIPln(_model):
             "coef_inflation": self.coef_inflation,
         }
 
+    def predict_prob_inflation(
+        self, exog: Union[torch.Tensor, np.ndarray, pd.DataFrame]
+    ):
+        """
+        Method for estimating the probability of a zero coming from the zero inflated component.
+
+        Parameters
+        ----------
+        exog : Union[torch.Tensor, np.ndarray, pd.DataFrame]
+            The exog.
+
+        Returns
+        -------
+        torch.Tensor
+            The predicted values.
+
+        Raises
+        ------
+        RuntimeError
+            If the shape of the exog is incorrect.
+
+        Notes
+        -----
+        - The mean sigmoid(exog @ coef_inflation) is returned.
+        - `exog` should have the shape `(_, nb_cov)`, where `nb_cov` is the number of exog variables.
+        """
+        if exog is not None and self.nb_cov == 0:
+            raise AttributeError("No exog in the model, can't predict")
+        if exog.shape[-1] != self.nb_cov:
+            error_string = f"X has wrong shape ({exog.shape}). Should"
+            error_string += f" be (_, {self.nb_cov})."
+            raise RuntimeError(error_string)
+        return torch.sigmoid(exog @ self.coef_inflation)
+
     @property
     @_add_doc(_model)
     def latent_parameters(self):
diff --git a/tests/conftest.py b/tests/conftest.py
index 85ca2aaf..93e50ab5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -37,6 +37,7 @@ endog_real = data_real["endog"]
 endog_real = pd.DataFrame(endog_real)
 endog_real.columns = [f"var_{i}" for i in range(endog_real.shape[1])]
 
+
 def add_fixture_to_dict(my_dict, string_fixture):
     my_dict[string_fixture] = [lf(string_fixture)]
     return my_dict
@@ -219,6 +220,11 @@ dict_fixtures = add_list_of_fixture_to_dict(
     dict_fixtures, "sim_model_0cov", sim_model_0cov
 )
 
+sim_model_0cov_fitted_and_loaded = sim_model_0cov_fitted + sim_model_0cov_loaded
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_model_0cov_fitted_and_loaded", sim_model_0cov_fitted_and_loaded
+)
+
 
 @pytest.fixture(params=params)
 @cache
diff --git a/tests/test_common.py b/tests/test_common.py
index cec97a72..0aa81d54 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -85,8 +85,11 @@ def test_find_right_coef(simulated_fitted_any_model):
 @filter_models(single_models)
 def test_fail_count_setter(model):
     wrong_endog = torch.randint(size=(10, 5), low=0, high=10)
-    with pytest.raises(Exception):
+    negative_endog = -model._endog
+    with pytest.raises(ValueError):
         model.endog = wrong_endog
+    with pytest.raises(ValueError):
+        model.endog = negative_endog
 
 
 @pytest.mark.parametrize("instance", dict_fixtures["instances"])
diff --git a/tests/test_zi.py b/tests/test_zi.py
new file mode 100644
index 00000000..4ba5af04
--- /dev/null
+++ b/tests/test_zi.py
@@ -0,0 +1,73 @@
+import pytest
+import torch
+
+from pyPLNmodels import get_simulation_parameters, sample_pln, ZIPln
+from tests.conftest import dict_fixtures
+from tests.utils import filter_models, MSE
+
+
+@pytest.mark.parametrize("zi", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(["ZIPln"])
+def test_properties(zi):
+    assert hasattr(zi, "latent_prob")
+    assert hasattr(zi, "coef_inflation")
+
+
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(["ZIPln"])
+def test_predict(model):
+    X = torch.randn((model.n_samples, model.nb_cov))
+    prediction = model.predict(X)
+    expected = X @ model.coef
+    assert torch.all(torch.eq(expected, prediction))
+
+
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(["ZIPln"])
+def test_predict_prob(model):
+    X = torch.randn((model.n_samples, model.nb_cov))
+    prediction = model.predict_prob_inflation(X)
+    expected = torch.sigmoid(X @ model.coef_inflation)
+    assert torch.all(torch.eq(expected, prediction))
+
+
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(["ZIPln"])
+def test_fail_predict_prob(model):
+    X1 = torch.randn((model.n_samples, model.nb_cov + 1))
+    X2 = torch.randn((model.n_samples, model.nb_cov - 1))
+    with pytest.raises(RuntimeError):
+        model.predict_prob_inflation(X1)
+    with pytest.raises(RuntimeError):
+        model.predict_prob_inflation(X2)
+
+
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(["ZIPln"])
+def test_fail_predict(model):
+    X1 = torch.randn((model.n_samples, model.nb_cov + 1))
+    X2 = torch.randn((model.n_samples, model.nb_cov - 1))
+    with pytest.raises(RuntimeError):
+        model.predict(X1)
+    with pytest.raises(RuntimeError):
+        model.predict(X2)
+
+
+@pytest.mark.parametrize("model", dict_fixtures["sim_model_0cov_fitted_and_loaded"])
+@filter_models(["ZIPln"])
+def test_no_exog_not_possible(model):
+    assert model.nb_cov == 1
+    assert model._coef_inflation.shape[0] == 1
+
+
+def test_find_right_covariance_and_coef():
+    pln_param = get_simulation_parameters(
+        n_samples=300, dim=50, nb_cov=2, rank=5, add_const=True
+    )
+    pln_param._coef += 5
+    endog = sample_pln(pln_param, seed=0, return_latent=False)
+    zi = ZIPln(endog, exog=pln_param.exog, offsets=pln_param.offsets)
+    zi.fit()
+    mse_covariance = MSE(zi.covariance - pln_param.covariance)
+    mse_coef = MSE(zi.coef)
+    assert mse_covariance < 0.5
-- 
GitLab
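
The new `predict_prob_inflation` above returns sigmoid(exog @ coef_inflation), i.e. the estimated probability that an entry is a structural zero. The sketch below illustrates the zero-inflated Poisson-lognormal mechanism this assumes, with hand-rolled tensors; it is an assumption about the generative model, not the package's `sample_pln` implementation.

    import torch

    torch.manual_seed(0)
    n, p, d = 100, 5, 2
    exog = torch.randn(n, d)
    coef = torch.randn(d, p)
    coef_inflation = torch.randn(d, p)
    components = torch.randn(p, p) / p

    Z = exog @ coef + torch.randn(n, p) @ components.T   # latent Gaussian layer
    counts = torch.poisson(torch.exp(Z))                 # Poisson-lognormal counts
    pi = torch.sigmoid(exog @ coef_inflation)            # inflation probabilities
    endog = counts * (1 - torch.bernoulli(pi))           # zero out inflated entries
    print(endog.shape, (endog == 0).float().mean())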


From 5a30f0f0b14ffdd34202c914846c8a949ddc2f50 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 17 Oct 2023 19:14:29 +0200
Subject: [PATCH 43/68] write the tests for the ZI model and the batch size, and
 rewrite the tests for the collection. The Getting Started tests still need to pass.

---
 .gitignore                     |   1 +
 pyPLNmodels/_utils.py          |  76 +++++++++--
 pyPLNmodels/models.py          | 142 +++++++++++---------
 tests/conftest.py              |  10 +-
 tests/test_common.py           |  16 ++-
 tests/test_pln_full.py         |   2 +-
 tests/test_plnpcacollection.py |  47 +++++--
 tests/test_setters.py          | 231 ++++++++++++++++++---------------
 tests/test_viz.py              | 108 +++++++--------
 tests/test_zi.py               |  72 ++++++++--
 10 files changed, 452 insertions(+), 253 deletions(-)

diff --git a/.gitignore b/.gitignore
index a95ada79..c00f1395 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,3 +159,4 @@ paper/*
 tests/test_models*
 tests/test_load*
 tests/test_readme*
+Getting_started.py
diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 35b02806..1082259e 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -494,8 +494,12 @@ def _get_simulation_components(dim: int, rank: int) -> torch.Tensor:
     return components.to("cpu")
 
 
-def _get_simulation_coef_cov_offsets(
-    n_samples: int, nb_cov: int, dim: int, add_const: bool
+def _get_simulation_coef_cov_offsets_coefzi(
+    n_samples: int,
+    nb_cov: int,
+    dim: int,
+    add_const: bool,
+    zero_inflated: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Get offsets, covariance coefficients with right shapes.
@@ -513,6 +517,8 @@ def _get_simulation_coef_cov_offsets(
         Dimension required of the data.
     add_const : bool, optional
         If True, will add a vector of ones in the exog.
+    zero_inflated : bool
+        If True, will return a zero_inflated coefficient.
 
     Returns
     -------
@@ -537,14 +543,23 @@ def _get_simulation_coef_cov_offsets(
         if add_const is True:
             exog = torch.cat((exog, torch.ones(n_samples, 1)), axis=1)
     if exog is None:
+        if zero_inflated is True:
+            msg = "Can not instantiate a zero inflate model without covariates."
+            msg += " Please give at least an intercept by setting add_const to True"
+            raise ValueError(msg)
         coef = None
+        coef_inflation = None
     else:
         coef = torch.randn(exog.shape[1], dim, device="cpu")
+        if zero_inflated is True:
+            coef_inflation = torch.randn(exog.shape[1], dim, device="cpu")
+        else:
+            coef_inflation = None
     offsets = torch.randint(
         low=0, high=2, size=(n_samples, dim), dtype=torch.float64, device="cpu"
     )
     torch.random.set_rng_state(prev_state)
-    return coef, exog, offsets
+    return coef, exog, offsets, coef_inflation
 
 
 class PlnParameters:
@@ -555,7 +570,7 @@ class PlnParameters:
         coef: Union[torch.Tensor, np.ndarray, pd.DataFrame],
         exog: Union[torch.Tensor, np.ndarray, pd.DataFrame],
         offsets: Union[torch.Tensor, np.ndarray, pd.DataFrame],
-        coef_inflation=None,
+        coef_inflation: Union[torch.Tensor, np.ndarray, pd.DataFrame, None] = None,
     ):
         """
         Instantiate all the needed parameters to sample from the PLN model.
@@ -570,9 +585,8 @@ class PlnParameters:
             Covariates, size (n, d) or None
         offsets : : Union[torch.Tensor, np.ndarray, pd.DataFrame](keyword-only)
             Offset, size (n, p)
-        _coef_inflation : : Union[torch.Tensor, np.ndarray, pd.DataFrame] or None, optional(keyword-only)
+        coef_inflation : Union[torch.Tensor, np.ndarray, pd.DataFrame, None], optional(keyword-only)
             Coefficient for zero-inflation model, size (d, p) or None. Default is None.
-
         """
         self._components = _format_data(components)
         self._coef = _format_data(coef)
@@ -713,6 +727,7 @@ def get_simulation_parameters(
     nb_cov: int = 1,
     rank: int = 5,
     add_const: bool = True,
+    zero_inflated: bool = False,
 ) -> PlnParameters:
     """
     Generate simulation parameters for a Poisson-lognormal model.
@@ -731,18 +746,26 @@ def get_simulation_parameters(
             The rank of the data components, by default 5.
         add_const : bool, optional(keyword-only)
             If True, will add a vector of ones in the exog.
+        zero_inflated : bool, optional(keyword-only)
+            If True, the model will be zero inflated.
+            Default is False.
 
     Returns
     -------
         PlnParameters
             The generated simulation parameters.
-
     """
-    coef, exog, offsets = _get_simulation_coef_cov_offsets(
-        n_samples, nb_cov, dim, add_const
+    coef, exog, offsets, coef_inflation = _get_simulation_coef_cov_offsets_coefzi(
+        n_samples, nb_cov, dim, add_const, zero_inflated
     )
     components = _get_simulation_components(dim, rank)
-    return PlnParameters(components=components, coef=coef, exog=exog, offsets=offsets)
+    return PlnParameters(
+        components=components,
+        coef=coef,
+        exog=exog,
+        offsets=offsets,
+        coef_inflation=coef_inflation,
+    )
 
 
 def get_simulated_count_data(
@@ -753,6 +776,7 @@ def get_simulated_count_data(
     nb_cov: int = 1,
     return_true_param: bool = False,
     add_const: bool = True,
+    zero_inflated: bool = False,
     seed: int = 0,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
@@ -772,19 +796,45 @@ def get_simulated_count_data(
         Number of exog, by default 1.
     return_true_param : bool, optional(keyword-only)
         Whether to return the true parameters of the model, by default False.
+    zero_inflated : bool, optional(keyword-only)
+        Whether to use a zero-inflated model or not.
+        Defaults to False.
     seed : int, optional(keyword-only)
         Seed value for random number generation, by default 0.
 
     Returns
     -------
-    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
-        Tuple containing endog, exog, and offsets.
+    if return_true_param is False:
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+            Tuple containing endog, exog, and offsets.
+    else:
+        if zero_inflated is True:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+                Tuple containing endog, exog, offsets, covariance, coef and coef_inflation.
+        else:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+                Tuple containing endog, exog, offsets, covariance and coef.
+
     """
     pln_param = get_simulation_parameters(
-        n_samples=n_samples, dim=dim, nb_cov=nb_cov, rank=rank, add_const=add_const
+        n_samples=n_samples,
+        dim=dim,
+        nb_cov=nb_cov,
+        rank=rank,
+        add_const=add_const,
+        zero_inflated=zero_inflated,
     )
     endog = sample_pln(pln_param, seed=seed, return_latent=False)
     if return_true_param is True:
+        if zero_inflated is True:
+            return (
+                endog,
+                pln_param.exog,
+                pln_param.offsets,
+                pln_param.covariance,
+                pln_param.coef,
+                pln_param.coef_inflation,
+            )
         return (
             endog,
             pln_param.exog,
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 81d732b2..01f9353b 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -99,10 +99,6 @@ class _model(ABC):
             Whether to take the log of offsets. Defaults to False.
         add_const: bool, optional(keyword-only)
             Whether to add a column of one in the exog. Defaults to True.
-        Raises
-        ------
-        ValueError
-            If the batch_size is greater than the number of samples, or not int.
         """
         (
             self._endog,
@@ -116,6 +112,7 @@ class _model(ABC):
         self._criterion_args = _CriterionArgs()
         if dict_initialization is not None:
             self._set_init_parameters(dict_initialization)
+        self._dirac = self._endog == 0
 
     @classmethod
     def from_formula(
@@ -253,7 +250,10 @@ class _model(ABC):
 
     def _handle_batch_size(self, batch_size):
         if batch_size is None:
-            batch_size = self.n_samples
+            if hasattr(self, "batch_size"):
+                batch_size = self.batch_size
+            else:
+                batch_size = self.n_samples
         if batch_size > self.n_samples:
             raise ValueError(
                 f"batch_size ({batch_size}) can not be greater than the number of samples ({self.n_samples})"
@@ -385,6 +385,10 @@ class _model(ABC):
         batch_size: int, optional(keyword-only)
             The batch size when optimizing the elbo. If None,
             batch gradient descent will be performed (i.e. batch_size = n_samples).
+        Raises
+        ------
+        ValueError
+            If the batch_size is greater than the number of samples, or is not an int.
         """
         self._print_beginning_message()
         self._beginning_time = time.time()
@@ -531,7 +535,7 @@ class _model(ABC):
 
     def sk_PCA(self, n_components=None):
         """
-        Perform PCA on the latent variables.
+        Perform the scikit-learn PCA on the latent variables.
 
         Parameters
         ----------
@@ -553,8 +557,9 @@ class _model(ABC):
             raise ValueError(
                 f"You ask more components ({n_components}) than variables ({self.dim})"
             )
+        latent_variables = self.transform()
         pca = PCA(n_components=n_components)
-        pca.fit(self.latent_variables.cpu())
+        pca.fit(latent_variables.cpu())
         return pca
 
     @property
@@ -595,9 +600,9 @@ class _model(ABC):
                 f"You ask more components ({n_components}) than variables ({self.dim})"
             )
         pca = self.sk_PCA(n_components=n_components)
-        proj_variables = pca.transform(self.latent_variables)
+        latent_variables = self.transform()
+        proj_variables = pca.transform(latent_variables)
         components = torch.from_numpy(pca.components_)
-
         labels = {
             str(i): f"PC{i+1}: {np.round(pca.explained_variance_ratio_*100, 1)[i]}%"
             for i in range(n_components)
@@ -655,7 +660,7 @@ class _model(ABC):
 
         n_components = 2
         pca = self.sk_PCA(n_components=n_components)
-        variables = self.latent_variables
+        variables = self.transform()
         proj_variables = pca.transform(variables)
         ## the package is not correctly printing the variance ratio
         figure, correlation_matrix = plot_pca_correlation_graph(
@@ -1808,51 +1813,6 @@ class Pln(_model):
                 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE)
             )
 
-    @property
-    @_add_doc(
-        _model,
-        example="""
-        >>> from pyPLNmodels import Pln, get_real_count_data
-        >>> endog, labels = get_real_count_data(return_labels = True)
-        >>> pln = Pln(endog,add_const = True)
-        >>> pln.fit()
-        >>> print(pln.latent_variables.shape)
-        """,
-    )
-    def latent_variables(self):
-        return self.latent_mean.detach()
-
-    @_add_doc(
-        _model,
-        example="""
-            >>> from pyPLNmodels import Pln, get_real_count_data
-            >>> endog, labels = get_real_count_data(return_labels = True)
-            >>> pln = Pln(endog,add_const = True)
-            >>> pln.fit()
-            >>> elbo = pln.compute_elbo()
-            >>> print("elbo", elbo)
-            >>> print("loglike/n", pln.loglike/pln.n_samples)
-            """,
-    )
-    def compute_elbo(self):
-        return profiled_elbo_pln(
-            self._endog,
-            self._exog,
-            self._offsets,
-            self._latent_mean,
-            self._latent_sqrt_var,
-        )
-
-    @_add_doc(_model)
-    def _compute_elbo_b(self):
-        return profiled_elbo_pln(
-            self._endog_b,
-            self._exog_b,
-            self._offsets_b,
-            self._latent_mean_b,
-            self._latent_sqrt_var_b,
-        )
-
     @_add_doc(_model)
     def _smart_init_model_parameters(self):
         pass
@@ -2057,7 +2017,7 @@ class PlnPCAcollection:
             endog, exog, offsets, offsets_formula, take_log_offsets, add_const
         )
         self._fitted = False
-        self._init_models(ranks, dict_of_dict_initialization)
+        self._init_models(ranks, dict_of_dict_initialization, add_const=add_const)
 
     @classmethod
     def from_formula(
@@ -2129,6 +2089,18 @@ class PlnPCAcollection:
         """
         return self[self.ranks[0]].exog
 
+    @property
+    def batch_size(self) -> int:
+        """
+        Property representing the batch_size.
+
+        Returns
+        -------
+        int
+            The batch size.
+        """
+        return self[self.ranks[0]].batch_size
+
     @property
     def endog(self) -> torch.Tensor:
         """
@@ -2203,6 +2175,19 @@ class PlnPCAcollection:
         for model in self.values():
             model.endog = endog
 
+    @batch_size.setter
+    def batch_size(self, batch_size: int):
+        """
+        Setter for the batch_size property.
+
+        Parameters
+        ----------
+        batch_size : int
+            The batch size.
+        """
+        for model in self.values():
+            model.batch_size = batch_size
+
     @coef.setter
     @_array2tensor
     def coef(self, coef: Union[torch.Tensor, np.ndarray, pd.DataFrame]):
@@ -2258,7 +2243,10 @@ class PlnPCAcollection:
             model.offsets = offsets
 
     def _init_models(
-        self, ranks: Iterable[int], dict_of_dict_initialization: Optional[dict]
+        self,
+        ranks: Iterable[int],
+        dict_of_dict_initialization: Optional[dict],
+        add_const: bool,
     ):
         """
         Method for initializing the models.
@@ -2282,6 +2270,7 @@ class PlnPCAcollection:
                         offsets=self._offsets,
                         rank=rank,
                         dict_initialization=dict_initialization,
+                        add_const=add_const,
                     )
                 else:
                     raise TypeError(
@@ -2389,6 +2378,10 @@ class PlnPCAcollection:
         batch_size: int, optional(keyword-only)
             The batch size when optimizing the elbo. If None,
             batch gradient descent will be performed (i.e. batch_size = n_samples).
+        Raises
+        ------
+        ValueError
+            If the batch_size is greater than the number of samples, or is not an int.
         """
         self._print_beginning_message()
         for i in range(len(self.values())):
@@ -3606,7 +3599,6 @@ class ZIPln(_model):
             self._coef_inflation = torch.randn(self.nb_cov, self.dim)
 
     def _random_init_latent_parameters(self):
-        self._dirac = self._endog == 0
         self._latent_mean = torch.randn(self.n_samples, self.dim)
         self._latent_sqrt_var = torch.randn(self.n_samples, self.dim)
         self._latent_prob = (
@@ -3621,6 +3613,17 @@ class ZIPln(_model):
     def _covariance(self):
         return self._components @ (self._components.T)
 
+    def _get_max_components(self):
+        """
+        Method for getting the maximum number of components.
+
+        Returns
+        -------
+        int
+            The maximum number of components.
+        """
+        return self.dim
+
     @property
     def components(self) -> torch.Tensor:
         """
@@ -3656,6 +3659,27 @@ class ZIPln(_model):
         """
         return self.latent_mean, self.latent_prob
 
+    def transform(self, return_latent_prob=False):
+        """
+        Method for transforming the endog. Can be seen as a normalization of the endog.
+
+        Parameters
+        ----------
+        return_latent_prob: bool, optional
+            Whether to also return the latent probability of zero inflation.
+        Returns
+        -------
+        The latent mean if `return_latent_prob` is False, and (latent_mean, latent_prob) otherwise.
+        """
+        if return_latent_prob is True:
+            return self.latent_variables
+        return self.latent_mean
+
+    def _endog_predictions(self):
+        return torch.exp(
+            self._offsets + self._latent_mean + 1 / 2 * self._latent_sqrt_var**2
+        ) * (1 - self._latent_prob)
+
     @property
     def coef_inflation(self):
         """
@@ -3730,7 +3754,7 @@ class ZIPln(_model):
                     self._latent_prob_b, torch.tensor([0]), out=self._latent_prob_b
                 )
                 self._latent_prob_b = torch.minimum(
-                    self._latent_prob, torch.tensor([1]), out=self._latent_prob_b
+                    self._latent_prob_b, torch.tensor([1]), out=self._latent_prob_b
                 )
                 self._latent_prob_b *= self._dirac_b
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 93e50ab5..22ea4307 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,5 @@
 import sys
 import glob
-from functools import singledispatch
 import pytest
 import torch
 from pytest_lazyfixture import lazy_fixture as lf
@@ -8,6 +7,7 @@ import pandas as pd
 
 from pyPLNmodels import load_model, load_plnpcacollection
 from pyPLNmodels.models import Pln, PlnPCA, PlnPCAcollection, ZIPln
+from pyPLNmodels import get_simulated_count_data
 
 
 sys.path.append("../")
@@ -206,6 +206,7 @@ dict_fixtures = add_list_of_fixture_to_dict(
     dict_fixtures, "sim_model_0cov_fitted", sim_model_0cov_fitted
 )
 
+
 sim_model_0cov_loaded = [
     "simulated_loaded_model_0cov_array",
     "simulated_loaded_model_0cov_formula",
@@ -285,12 +286,17 @@ sim_model_2cov_instance = [
     "simulated_model_2cov_array",
     "simulated_model_2cov_formula",
 ]
+sim_model_instance = sim_model_0cov_instance + sim_model_2cov_instance
+
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_model_instance", sim_model_instance
+)
 instances = sim_model_2cov_instance + instances
 
+
 dict_fixtures = add_list_of_fixture_to_dict(
     dict_fixtures, "sim_model_2cov_instance", sim_model_2cov_instance
 )
-
 sim_model_2cov_fitted = [
     "simulated_fitted_model_2cov_array",
     "simulated_fitted_model_2cov_formula",
diff --git a/tests/test_common.py b/tests/test_common.py
index 0aa81d54..6cba2cf6 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -8,8 +8,8 @@ from tests.utils import MSE, filter_models
 
 from tests.import_data import true_sim_0cov, true_sim_2cov, endog_real
 
-single_models = ["Pln", "PlnPCA", "ZIPln"]
 pln_and_plnpca = ["Pln", "PlnPCA"]
+single_models = ["Pln", "PlnPCA", "ZIPln"]
 
 
 @pytest.mark.parametrize("any_model", dict_fixtures["loaded_and_fitted_model"])
@@ -108,3 +108,17 @@ def test_fail_wrong_exog_prediction(model):
     X = torch.randn(model.n_samples, model.nb_cov + 1)
     with pytest.raises(Exception):
         model.predict(X)
+
+
+@pytest.mark.parametrize("model", dict_fixtures["sim_model_instance"])
+@filter_models(pln_and_plnpca)
+def test_batch(model):
+    model.fit(batch_size=20)
+    print(model)
+    model.show()
+    if model.nb_cov == 2:
+        true_coef = true_sim_2cov["beta"]
+        mse_coef = MSE(model.coef - true_coef)
+        assert mse_coef < 0.1
+    elif model.nb_cov == 0:
+        assert model.coef is None
diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py
index 1115e1ec..6a8ced3a 100644
--- a/tests/test_pln_full.py
+++ b/tests/test_pln_full.py
@@ -13,5 +13,5 @@ def test_number_of_iterations_pln_full(fitted_pln):
 
 @pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["Pln"])
-def test_latent_var_full(pln):
+def test_latent_variables(pln):
     assert pln.transform().shape == pln.endog.shape
diff --git a/tests/test_plnpcacollection.py b/tests/test_plnpcacollection.py
index 6634f2d2..19b49b18 100644
--- a/tests/test_plnpcacollection.py
+++ b/tests/test_plnpcacollection.py
@@ -6,16 +6,17 @@ import numpy as np
 
 from tests.conftest import dict_fixtures
 from tests.utils import MSE, filter_models
+from tests.import_data import true_sim_0cov, true_sim_2cov
 
 
-@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCAcollection"])
 def test_best_model(plnpca):
     best_model = plnpca.best_model()
     print(best_model)
 
 
-@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCAcollection"])
 def test_projected_variables(plnpca):
     best_model = plnpca.best_model()
@@ -23,21 +24,20 @@ def test_projected_variables(plnpca):
     assert plv.shape[0] == best_model.n_samples and plv.shape[1] == best_model.rank
 
 
-@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"])
-@filter_models(["PlnPCA"])
-def test_number_of_iterations_plnpca(fitted_pln):
-    nb_iterations = len(fitted_pln._elbos_list)
-    assert 100 < nb_iterations < 5000
+@pytest.mark.parametrize("plnpca", dict_fixtures["sim_model_instance"])
+@filter_models(["PlnPCAcollection"])
+def test_right_nbcov(plnpca):
+    assert plnpca.nb_cov == 0 or plnpca.nb_cov == 2
 
 
-@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCA"])
 def test_latent_var_pca(plnpca):
     assert plnpca.transform(project=False).shape == plnpca.endog.shape
     assert plnpca.transform().shape == (plnpca.n_samples, plnpca.rank)
 
 
-@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCAcollection"])
 def test_additional_methods_pca(plnpca):
     plnpca.show()
@@ -46,14 +46,14 @@ def test_additional_methods_pca(plnpca):
     plnpca.loglikes
 
 
-@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCAcollection"])
 def test_wrong_criterion(plnpca):
     with pytest.raises(ValueError):
         plnpca.best_model("AIK")
 
 
-@pytest.mark.parametrize("collection", dict_fixtures["loaded_and_fitted_pln"])
+@pytest.mark.parametrize("collection", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCAcollection"])
 def test_item(collection):
     print(collection[collection.ranks[0]])
@@ -62,3 +62,28 @@ def test_item(collection):
     assert collection.ranks[0] in collection
     assert collection.ranks[0] in list(collection.keys())
     collection.get(collection.ranks[0], None)
+
+
+@pytest.mark.parametrize("collection", dict_fixtures["sim_model_instance"])
+@filter_models(["PlnPCAcollection"])
+def test_batch(collection):
+    collection.fit(batch_size=20)
+    assert collection.nb_cov == 0 or collection.nb_cov == 2
+    if collection.nb_cov == 0:
+        true_covariance = true_sim_0cov["Sigma"]
+        for model in collection.values():
+            assert model.coef is None
+        true_coef = None
+    elif collection.nb_cov == 2:
+        true_covariance = true_sim_2cov["Sigma"]
+        true_coef = true_sim_2cov["beta"]
+    else:
+        raise ValueError(f"Not the right numbers of covariance({collection.nb_cov})")
+    for model in collection.values():
+        mse_covariance = MSE(model.covariance - true_covariance)
+        if true_coef is not None:
+            mse_coef = MSE(model.coef - true_coef)
+            assert mse_coef < 0.35
+        assert mse_covariance < 0.25
+    collection.fit()
+    assert collection.batch_size == collection.n_samples
diff --git a/tests/test_setters.py b/tests/test_setters.py
index eb7814d7..b3012548 100644
--- a/tests/test_setters.py
+++ b/tests/test_setters.py
@@ -5,148 +5,169 @@ import torch
 from tests.conftest import dict_fixtures
 from tests.utils import MSE, filter_models
 
-
-@pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
-def test_data_setter_with_torch(pln):
-    pln.endog = pln.endog
-    pln.exog = pln.exog
-    pln.offsets = pln.offsets
-    pln.fit()
-
-
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_parameters_setter_with_torch(pln):
-    pln.latent_mean = pln.latent_mean
-    pln.latent_sqrt_var = pln.latent_sqrt_var
-    if pln._NAME != "Pln":
-        pln.coef = pln.coef
-    if pln._NAME == "PlnPCA":
-        pln.components = pln.components
-    pln.fit()
-
-
-@pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
-def test_data_setter_with_numpy(pln):
-    np_endog = pln.endog.numpy()
-    if pln.exog is not None:
-        np_exog = pln.exog.numpy()
+single_models = ["Pln", "PlnPCA", "ZIPln"]
+
+
+@pytest.mark.parametrize("model", dict_fixtures["all_model"])
+def test_data_setter_with_torch(model):
+    model.endog = model.endog
+    model.exog = model.exog
+    model.offsets = model.offsets
+    model.fit()
+
+
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(single_models)
+def test_parameters_setter_with_torch(model):
+    model.latent_mean = model.latent_mean
+    model.latent_sqrt_var = model.latent_sqrt_var
+    if model._NAME != "Pln":
+        model.coef = model.coef
+    if model._NAME == "PlnPCA" or model._NAME == "ZIPln":
+        model.components = model.components
+    if model._NAME == "ZIPln":
+        model.coef_inflation = model.coef_inflation
+    model.fit()
+
+
+@pytest.mark.parametrize("model", dict_fixtures["all_model"])
+def test_data_setter_with_numpy(model):
+    np_endog = model.endog.numpy()
+    if model.exog is not None:
+        np_exog = model.exog.numpy()
     else:
         np_exog = None
-    np_offsets = pln.offsets.numpy()
-    pln.endog = np_endog
-    pln.exog = np_exog
-    pln.offsets = np_offsets
-    pln.fit()
-
-
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_parameters_setter_with_numpy(pln):
-    np_latent_mean = pln.latent_mean.numpy()
-    np_latent_sqrt_var = pln.latent_sqrt_var.numpy()
-    if pln.coef is not None:
-        np_coef = pln.coef.numpy()
+    np_offsets = model.offsets.numpy()
+    model.endog = np_endog
+    model.exog = np_exog
+    model.offsets = np_offsets
+    model.fit()
+
+
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(single_models)
+def test_parameters_setter_with_numpy(model):
+    np_latent_mean = model.latent_mean.numpy()
+    np_latent_sqrt_var = model.latent_sqrt_var.numpy()
+    if model.coef is not None:
+        np_coef = model.coef.numpy()
     else:
         np_coef = None
-    pln.latent_mean = np_latent_mean
-    pln.latent_sqrt_var = np_latent_sqrt_var
-    if pln._NAME != "Pln":
-        pln.coef = np_coef
-    if pln._NAME == "PlnPCA":
-        pln.components = pln.components.numpy()
-    pln.fit()
-
-
-@pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
-def test_data_setter_with_pandas(pln):
-    pd_endog = pd.DataFrame(pln.endog.numpy())
-    if pln.exog is not None:
-        pd_exog = pd.DataFrame(pln.exog.numpy())
+    model.latent_mean = np_latent_mean
+    model.latent_sqrt_var = np_latent_sqrt_var
+    if model._NAME != "Pln":
+        model.coef = np_coef
+    if model._NAME == "PlnPCA" or model._NAME == "ZIPln":
+        model.components = model.components.numpy()
+    if model._NAME == "ZIPln":
+        model.coef_inflation = model.coef_inflation.numpy()
+    model.fit()
+
+
+@pytest.mark.parametrize("model", dict_fixtures["all_model"])
+def test_batch_size_setter(model):
+    model.batch_size = 20
+    model.fit(nb_max_iteration=3)
+    assert model.batch_size == 20
+
+
+@pytest.mark.parametrize("model", dict_fixtures["all_model"])
+def test_fail_batch_size_setter(model):
+    with pytest.raises(ValueError):
+        model.batch_size = model.n_samples + 1
+
+
+@pytest.mark.parametrize("model", dict_fixtures["all_model"])
+def test_data_setter_with_pandas(model):
+    pd_endog = pd.DataFrame(model.endog.numpy())
+    if model.exog is not None:
+        pd_exog = pd.DataFrame(model.exog.numpy())
     else:
         pd_exog = None
-    pd_offsets = pd.DataFrame(pln.offsets.numpy())
-    pln.endog = pd_endog
-    pln.exog = pd_exog
-    pln.offsets = pd_offsets
-    pln.fit()
-
-
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_parameters_setter_with_pandas(pln):
-    pd_latent_mean = pd.DataFrame(pln.latent_mean.numpy())
-    pd_latent_sqrt_var = pd.DataFrame(pln.latent_sqrt_var.numpy())
-    if pln.coef is not None:
-        pd_coef = pd.DataFrame(pln.coef.numpy())
+    pd_offsets = pd.DataFrame(model.offsets.numpy())
+    model.endog = pd_endog
+    model.exog = pd_exog
+    model.offsets = pd_offsets
+    model.fit()
+
+
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(single_models)
+def test_parameters_setter_with_pandas(model):
+    pd_latent_mean = pd.DataFrame(model.latent_mean.numpy())
+    pd_latent_sqrt_var = pd.DataFrame(model.latent_sqrt_var.numpy())
+    if model.coef is not None:
+        pd_coef = pd.DataFrame(model.coef.numpy())
     else:
         pd_coef = None
-    pln.latent_mean = pd_latent_mean
-    pln.latent_sqrt_var = pd_latent_sqrt_var
-    if pln._NAME != "Pln":
-        pln.coef = pd_coef
-    if pln._NAME == "PlnPCA":
-        pln.components = pd.DataFrame(pln.components.numpy())
-    pln.fit()
-
-
-@pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
-def test_fail_data_setter_with_torch(pln):
+    model.latent_mean = pd_latent_mean
+    model.latent_sqrt_var = pd_latent_sqrt_var
+    if model._NAME != "Pln":
+        model.coef = pd_coef
+    if model._NAME == "PlnPCA":
+        model.components = pd.DataFrame(model.components.numpy())
+    if model._NAME == "ZIPln":
+        model.coef_inflation = pd.DataFrame(model.coef_inflation.numpy())
+    model.fit()
+
+
+@pytest.mark.parametrize("model", dict_fixtures["all_model"])
+def test_fail_data_setter_with_torch(model):
     with pytest.raises(ValueError):
-        pln.endog = pln.endog - 100
+        model.endog = -model.endog
 
-    n, p = pln.endog.shape
-    if pln.exog is None:
+    n, p = model.endog.shape
+    if model.exog is None:
         d = 0
     else:
-        d = pln.exog.shape[-1]
+        d = model.exog.shape[-1]
     with pytest.raises(ValueError):
-        pln.endog = torch.zeros(n + 1, p)
+        model.endog = torch.zeros(n + 1, p)
     with pytest.raises(ValueError):
-        pln.endog = torch.zeros(n, p + 1)
+        model.endog = torch.zeros(n, p + 1)
 
     with pytest.raises(ValueError):
-        pln.exog = torch.zeros(n + 1, d)
+        model.exog = torch.zeros(n + 1, d)
 
     with pytest.raises(ValueError):
-        pln.offsets = torch.zeros(n + 1, p)
+        model.offsets = torch.zeros(n + 1, p)
 
     with pytest.raises(ValueError):
-        pln.offsets = torch.zeros(n, p + 1)
+        model.offsets = torch.zeros(n, p + 1)
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_fail_parameters_setter_with_torch(pln):
-    n, dim_latent = pln.latent_mean.shape
-    dim = pln.endog.shape[1]
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(single_models)
+def test_fail_parameters_setter_with_torch(model):
+    n, dim_latent = model.latent_mean.shape
+    dim = model.endog.shape[1]
 
     with pytest.raises(ValueError):
-        pln.latent_mean = torch.zeros(n + 1, dim_latent)
+        model.latent_mean = torch.zeros(n + 1, dim_latent)
 
     with pytest.raises(ValueError):
-        pln.latent_mean = torch.zeros(n, dim_latent + 1)
+        model.latent_mean = torch.zeros(n, dim_latent + 1)
 
     with pytest.raises(ValueError):
-        pln.latent_sqrt_var = torch.zeros(n + 1, dim_latent)
+        model.latent_sqrt_var = torch.zeros(n + 1, dim_latent)
 
     with pytest.raises(ValueError):
-        pln.latent_sqrt_var = torch.zeros(n, dim_latent + 1)
+        model.latent_sqrt_var = torch.zeros(n, dim_latent + 1)
 
-    if pln._NAME == "PlnPCA":
+    if model._NAME == "PlnPCA":
         with pytest.raises(ValueError):
-            pln.components = torch.zeros(dim, dim_latent + 1)
+            model.components = torch.zeros(dim, dim_latent + 1)
 
         with pytest.raises(ValueError):
-            pln.components = torch.zeros(dim + 1, dim_latent)
+            model.components = torch.zeros(dim + 1, dim_latent)
 
-        if pln.exog is None:
+        if model.exog is None:
             d = 0
         else:
-            d = pln.exog.shape[-1]
-        if pln._NAME != "Pln":
+            d = model.exog.shape[-1]
+        if model._NAME != "Pln":
             with pytest.raises(ValueError):
-                pln.coef = torch.zeros(d + 1, dim)
+                model.coef = torch.zeros(d + 1, dim)
 
             with pytest.raises(ValueError):
-                pln.coef = torch.zeros(d, dim + 1)
+                model.coef = torch.zeros(d, dim + 1)
diff --git a/tests/test_viz.py b/tests/test_viz.py
index be24fcf1..d4f9a738 100644
--- a/tests/test_viz.py
+++ b/tests/test_viz.py
@@ -7,47 +7,49 @@ from tests.utils import MSE, filter_models
 
 from tests.import_data import true_sim_0cov, true_sim_2cov, labels_real
 
+single_models = ["Pln", "PlnPCA", "ZIPln"]
 
-@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"])
-def test_print(any_pln):
-    print(any_pln)
-
-
-@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_show_coef_transform_covariance_pcaprojected(any_pln):
-    any_pln.show()
-    any_pln._plotargs._show_loss()
-    any_pln._plotargs._show_stopping_criterion()
-    assert hasattr(any_pln, "coef")
-    assert callable(any_pln.transform)
-    assert hasattr(any_pln, "covariance")
-    assert callable(any_pln.sk_PCA)
-    assert any_pln.sk_PCA(n_components=None) is not None
+
+@pytest.mark.parametrize("any_model", dict_fixtures["loaded_and_fitted_model"])
+def test_print(any_model):
+    print(any_model)
+
+
+@pytest.mark.parametrize("any_model", dict_fixtures["fitted_model"])
+@filter_models(single_models)
+def test_show_coef_transform_covariance_pcaprojected(any_model):
+    any_model.show()
+    any_model._criterion_args._show_loss()
+    any_model._criterion_args._show_stopping_criterion()
+    assert hasattr(any_model, "coef")
+    assert callable(any_model.transform)
+    assert hasattr(any_model, "covariance")
+    assert callable(any_model.sk_PCA)
+    assert any_model.sk_PCA(n_components=None) is not None
     with pytest.raises(Exception):
-        any_pln.sk_PCA(n_components=any_pln.dim + 1)
+        any_model.sk_PCA(n_components=any_model.dim + 1)
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["fitted_pln"])
-@filter_models(["Pln"])
-def test_scatter_pca_matrix_pln(pln):
-    pln.scatter_pca_matrix(n_components=8)
+@pytest.mark.parametrize("model", dict_fixtures["fitted_model"])
+@filter_models(["Pln", "ZIPln"])
+def test_scatter_pca_matrix_pln(model):
+    model.scatter_pca_matrix(n_components=8)
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["fitted_pln"])
+@pytest.mark.parametrize("model", dict_fixtures["fitted_model"])
 @filter_models(["PlnPCA"])
-def test_scatter_pca_matrix_plnpca(pln):
-    pln.scatter_pca_matrix(n_components=2)
-    pln.scatter_pca_matrix()
+def test_scatter_pca_matrix_plnpca(model):
+    model.scatter_pca_matrix(n_components=2)
+    model.scatter_pca_matrix()
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_real_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_label_scatter_pca_matrix(pln):
-    pln.scatter_pca_matrix(n_components=4, color=labels_real)
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_real_model"])
+@filter_models(single_models)
+def test_label_scatter_pca_matrix(model):
+    model.scatter_pca_matrix(n_components=4, color=labels_real)
 
 
-@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCAcollection"])
 def test_viz_pcacol(plnpca):
     for model in plnpca.values():
@@ -64,38 +66,38 @@ def test_viz_pcacol(plnpca):
         plt.show()
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["real_fitted_pln_intercept_array"])
-@filter_models(["Pln", "PlnPCA"])
-def test_plot_pca_correlation_graph_with_names_only(pln):
-    pln.plot_pca_correlation_graph([f"var_{i}" for i in range(8)])
+@pytest.mark.parametrize("model", dict_fixtures["real_fitted_model_intercept_array"])
+@filter_models(single_models)
+def test_plot_pca_correlation_graph_with_names_only(model):
+    model.plot_pca_correlation_graph([f"var_{i}" for i in range(8)])
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_sim_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_fail_plot_pca_correlation_graph_without_names(pln):
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_sim_model"])
+@filter_models(single_models)
+def test_fail_plot_pca_correlation_graph_without_names(model):
     with pytest.raises(ValueError):
-        pln.plot_pca_correlation_graph([f"var_{i}" for i in range(8)])
+        model.plot_pca_correlation_graph([f"var_{i}" for i in range(8)])
     with pytest.raises(ValueError):
-        pln.plot_pca_correlation_graph([f"var_{i}" for i in range(6)], [1, 2, 3])
+        model.plot_pca_correlation_graph([f"var_{i}" for i in range(6)], [1, 2, 3])
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_sim_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_plot_pca_correlation_graph_without_names(pln):
-    pln.plot_pca_correlation_graph([f"var_{i}" for i in range(3)], [0, 1, 2])
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_sim_model"])
+@filter_models(single_models)
+def test_plot_pca_correlation_graph_without_names(model):
+    model.plot_pca_correlation_graph([f"var_{i}" for i in range(3)], [0, 1, 2])
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_expected_vs_true(pln):
-    pln.plot_expected_vs_true()
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(single_models)
+def test_expected_vs_true(model):
+    model.plot_expected_vs_true()
     fig, ax = plt.subplots()
-    pln.plot_expected_vs_true(ax=ax)
+    model.plot_expected_vs_true(ax=ax)
 
 
-@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_real_pln"])
-@filter_models(["Pln", "PlnPCA"])
-def test_expected_vs_true_labels(pln):
-    pln.plot_expected_vs_true(colors=labels_real)
+@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_real_model"])
+@filter_models(single_models)
+def test_expected_vs_true_labels(model):
+    model.plot_expected_vs_true(colors=labels_real)
     fig, ax = plt.subplots()
-    pln.plot_expected_vs_true(ax=ax, colors=labels_real)
+    model.plot_expected_vs_true(ax=ax, colors=labels_real)
diff --git a/tests/test_zi.py b/tests/test_zi.py
index 4ba5af04..2016accf 100644
--- a/tests/test_zi.py
+++ b/tests/test_zi.py
@@ -6,6 +6,9 @@ from tests.conftest import dict_fixtures
 from tests.utils import filter_models, MSE
 
 
+from pyPLNmodels import get_simulated_count_data
+
+
 @pytest.mark.parametrize("zi", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["ZIPln"])
 def test_properties(zi):
@@ -60,14 +63,67 @@ def test_no_exog_not_possible(model):
     assert model._coef_inflation.shape[0] == 1
 
 
-def test_find_right_covariance_and_coef():
-    pln_param = get_simulation_parameters(
-        n_samples=300, dim=50, nb_cov=2, rank=5, add_const=True
+def test_find_right_covariance_coef_and_infla():
+    pln_param = get_simulation_parameters(zero_inflated=True, n_samples=1000)
+    # pln_param._coef += 5
+    endog = sample_pln(pln_param, seed=0, return_latent=False)
+    exog = pln_param.exog
+    offsets = pln_param.offsets
+    covariance = pln_param.covariance
+    coef = pln_param.coef
+    coef_inflation = pln_param.coef_inflation
+    endog, exog, offsets, covariance, coef, coef_inflation = get_simulated_count_data(
+        zero_inflated=True, return_true_param=True, n_samples=1000
     )
-    pln_param._coef += 5
+    zi = ZIPln(endog, exog=exog, offsets=offsets, use_closed_form_prob=False)
+    zi.fit()
+    mse_covariance = MSE(zi.covariance - covariance)
+    mse_coef = MSE(zi.coef - coef)
+    mse_coef_infla = MSE(zi.coef_inflation - coef_inflation)
+    assert mse_coef < 3
+    assert mse_coef_infla < 3
+    assert mse_covariance < 1
+
+
+@pytest.mark.parametrize("zi", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(["ZIPln"])
+def test_latent_variables(zi):
+    z, w = zi.latent_variables
+    assert z.shape == zi.endog.shape
+    assert w.shape == zi.endog.shape
+
+
+@pytest.mark.parametrize("zi", dict_fixtures["loaded_and_fitted_model"])
+@filter_models(["ZIPln"])
+def test_transform(zi):
+    z = zi.transform()
+    assert z.shape == zi.endog.shape
+    z, w = zi.transform(return_latent_prob=True)
+    assert z.shape == w.shape == zi.endog.shape
+
+
+@pytest.mark.parametrize("model", dict_fixtures["sim_model_instance"])
+@filter_models(["ZIPln"])
+def test_batch(model):
+    pln_param = get_simulation_parameters(zero_inflated=True, n_samples=1000)
+    # pln_param._coef += 5
     endog = sample_pln(pln_param, seed=0, return_latent=False)
-    zi = ZIPln(endog, exog=pln_param.exog, offsets=pln_param.offsets)
+    exog = pln_param.exog
+    offsets = pln_param.offsets
+    covariance = pln_param.covariance
+    coef = pln_param.coef
+    coef_inflation = pln_param.coef_inflation
+    endog, exog, offsets, covariance, coef, coef_inflation = get_simulated_count_data(
+        zero_inflated=True, return_true_param=True, n_samples=1000
+    )
+    zi = ZIPln(endog, exog=exog, offsets=offsets, use_closed_form_prob=False)
+    zi.fit(batch_size=20)
+    mse_covariance = MSE(zi.covariance - covariance)
+    mse_coef = MSE(zi.coef - coef)
+    mse_coef_infla = MSE(zi.coef_inflation - coef_inflation)
+    assert mse_coef < 3
+    assert mse_coef_infla < 3
+    assert mse_covariance < 1
+    zi.show()
+    print(zi)
     zi.fit()
-    mse_covariance = MSE(zi.covariance - pln_param.covariance)
-    mse_coef = MSE(zi.coef)
-    assert mse_covariance < 0.5
-- 
GitLab
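
The patch above threads a `zero_inflated` option through the simulation helpers and
adds a `transform` method to ZIPln. A minimal sketch of how the new API is exercised,
mirroring the calls used in tests/test_zi.py (all names are taken from the diff above):

    # Minimal sketch: simulate zero-inflated counts together with the true
    # parameters, then fit a ZIPln model on them (as in tests/test_zi.py).
    from pyPLNmodels import get_simulated_count_data
    from pyPLNmodels.models import ZIPln

    endog, exog, offsets, covariance, coef, coef_inflation = get_simulated_count_data(
        zero_inflated=True, return_true_param=True, n_samples=1000
    )
    zi = ZIPln(endog, exog=exog, offsets=offsets, use_closed_form_prob=False)
    zi.fit(batch_size=20)
    print(zi.coef_inflation.shape)  # one inflation coefficient per covariate and per column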


From 244e1fdde7c7aadeab6b5b9313540972cc99f7cf Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 17 Oct 2023 19:53:42 +0200
Subject: [PATCH 44/68] add a file in the gitignore and blacked one file ??!

---
 .gitignore                     | 1 +
 pyPLNmodels/_initialization.py | 4 +---
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index c00f1395..08cb3cfe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,4 @@ tests/test_models*
 tests/test_load*
 tests/test_readme*
 Getting_started.py
+new_model.py
diff --git a/pyPLNmodels/_initialization.py b/pyPLNmodels/_initialization.py
index 410283bf..fe649fe0 100644
--- a/pyPLNmodels/_initialization.py
+++ b/pyPLNmodels/_initialization.py
@@ -43,9 +43,7 @@ def _init_covariance(endog: torch.Tensor, exog: torch.Tensor) -> torch.Tensor:
     return sigma_hat
 
 
-def _init_components(
-    endog: torch.Tensor, rank: int
-) -> torch.Tensor:
+def _init_components(endog: torch.Tensor, rank: int) -> torch.Tensor:
     """
     Initialization for components for the Pln model. Get a first guess for covariance
     that is easier to estimate and then takes the rank largest eigenvectors to get components.
-- 
GitLab


From c8a9739fe488965bcf395b47544ad129ad6f9bd3 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Tue, 17 Oct 2023 22:44:07 +0200
Subject: [PATCH 45/68] fixed some tests.

---
 pyPLNmodels/models.py          | 6 +++---
 tests/test_plnpcacollection.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 01f9353b..8351056c 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -3252,7 +3252,7 @@ class PlnPCA(_model):
         Parameters
         ----------
         project : bool, optional
-            Whether to project the latent variables, by default True.
+            Whether to project the latent variables, by default False.
         """,
         returns="""
         torch.Tensor
@@ -3269,7 +3269,7 @@ class PlnPCA(_model):
             >>> print(transformed_endog_high_dim.shape)
             """,
     )
-    def transform(self, project: bool = True) -> torch.Tensor:
+    def transform(self, project: bool = False) -> torch.Tensor:
         if project is True:
             return self.projected_latent_variables
         return self.latent_variables
@@ -3297,7 +3297,7 @@ class PlnPCA(_model):
             >>> pca.fit()
             >>> elbo = pca.compute_elbo()
             >>> print("elbo", elbo)
-            >>> print("loglike/n", pln.loglike/pln.n_samples)
+            >>> print("loglike/n", pca.loglike/pca.n_samples)
             """,
     )
     def compute_elbo(self) -> torch.Tensor:
diff --git a/tests/test_plnpcacollection.py b/tests/test_plnpcacollection.py
index 19b49b18..77016b73 100644
--- a/tests/test_plnpcacollection.py
+++ b/tests/test_plnpcacollection.py
@@ -33,8 +33,8 @@ def test_right_nbcov(plnpca):
 @pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCA"])
 def test_latent_var_pca(plnpca):
-    assert plnpca.transform(project=False).shape == plnpca.endog.shape
-    assert plnpca.transform().shape == (plnpca.n_samples, plnpca.rank)
+    assert plnpca.transform().shape == plnpca.endog.shape
+    assert plnpca.transform(project=True).shape == (plnpca.n_samples, plnpca.rank)
 
 
 @pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
-- 
GitLab
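
With this patch, `PlnPCA.transform()` returns the latent variables in the original
dimension by default, and `transform(project=True)` returns their projection on the
`rank` principal components. A minimal sketch of the new behaviour (assuming
`get_real_count_data()` returns the counts only and that PlnPCA accepts a `rank`
keyword, as in the collection constructor above):

    # Minimal sketch of the new transform default (shapes follow test_latent_var_pca).
    from pyPLNmodels import get_real_count_data
    from pyPLNmodels.models import PlnPCA

    endog = get_real_count_data()
    pca = PlnPCA(endog, rank=5)
    pca.fit()
    print(pca.transform().shape)              # (n_samples, dim): latent variables
    print(pca.transform(project=True).shape)  # (n_samples, rank): projected latent variables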


From 22b274525ac992a5a1ddebc5e59d62dcd49eb40e Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Wed, 18 Oct 2023 14:27:29 +0200
Subject: [PATCH 46/68] fixed some tests for the zi.

---
 pyPLNmodels/elbos.py           |  3 +--
 pyPLNmodels/models.py          | 37 +++++++++++-----------------------
 tests/conftest.py              |  1 -
 tests/test_pln_full.py         |  4 ++--
 tests/test_plnpcacollection.py |  2 +-
 tests/test_setters.py          |  2 +-
 6 files changed, 17 insertions(+), 32 deletions(-)

diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index 5a56bc3d..73e77028 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -219,8 +219,7 @@ def elbo_zi_pln(
     """
     covariance = components @ (components.T)
     if torch.norm(latent_prob * dirac - latent_prob) > 0.00000001:
-        print("Bug")
-        raise RuntimeError("rho error")
+        raise RuntimeError("latent_prob error")
     n_samples, dim = endog.shape
     s_rond_s = torch.multiply(latent_sqrt_var, latent_sqrt_var)
     o_plus_m = offsets + latent_mean
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 8351056c..d692a9e5 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -242,11 +242,8 @@ class _model(ABC):
                 _plot_ellipse(x[i], y[i], cov=covariances[i], ax=ax)
         return ax
 
-    def _update_parameters(self):
-        """
-        Update the parameters with a gradient step and project if necessary.
-        """
-        self.optim.step()
+    def _project_parameters(self):
+        pass
 
     def _handle_batch_size(self, batch_size):
         if batch_size is None:
@@ -486,8 +483,8 @@ class _model(ABC):
                 raise ValueError("test")
             loss.backward()
             elbo += loss.item()
-            self._update_parameters()
-            self._update_closed_forms()
+            self.optim.step()
+        self._project_parameters()
         return elbo / self.nb_batches
 
     def _extract_batch(self, batch):
@@ -736,12 +733,6 @@ class _model(ABC):
         self._criterion_args.update_criterion(-loss, current_running_time)
         return self._criterion_args.criterion
 
-    def _update_closed_forms(self):
-        """
-        Update closed-form expressions.
-        """
-        pass
-
     def display_covariance(self, ax=None, savefig=False, name_file=""):
         """
         Display the covariance matrix.
@@ -3740,8 +3731,7 @@ class ZIPln(_model):
             )
         self._latent_sqrt_var = latent_sqrt_var
 
-    def _update_parameters(self):
-        super()._update_parameters()
+    def _project_parameters(self):
         self._project_latent_prob()
 
     def _project_latent_prob(self):
@@ -3750,13 +3740,13 @@ class ZIPln(_model):
         """
         if self._use_closed_form_prob is False:
             with torch.no_grad():
-                self._latent_prob_b = torch.maximum(
-                    self._latent_prob_b, torch.tensor([0]), out=self._latent_prob_b
+                torch.maximum(
+                    self._latent_prob, torch.tensor([0]), out=self._latent_prob
                 )
-                self._latent_prob_b = torch.minimum(
-                    self._latent_prob_b, torch.tensor([1]), out=self._latent_prob_b
+                torch.minimum(
+                    self._latent_prob, torch.tensor([1]), out=self._latent_prob
                 )
-                self._latent_prob_b *= self._dirac_b
+                self._latent_prob *= self._dirac
 
     @property
     def covariance(self) -> torch.Tensor:
@@ -3878,16 +3868,13 @@ class ZIPln(_model):
             self._latent_sqrt_var,
             self._components,
         ]
-        if self._use_closed_form_prob:
+        if self._use_closed_form_prob is False:
             list_parameters.append(self._latent_prob)
         if self._exog is not None:
             list_parameters.append(self._coef)
             list_parameters.append(self._coef_inflation)
         return list_parameters
 
-    def _update_closed_forms(self):
-        pass
-
     @property
     @_add_doc(_model)
     def model_parameters(self) -> Dict[str, torch.Tensor]:
@@ -3938,7 +3925,7 @@ class ZIPln(_model):
             "latent_sqrt_var": self.latent_sqrt_var,
             "latent_mean": self.latent_mean,
         }
-        if self._use_closed_form_prob is True:
+        if self._use_closed_form_prob is False:
             latent_param["latent_prob"] = self.latent_prob
         return latent_param
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 22ea4307..b77a5d67 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,7 +17,6 @@ sys.path.append("../")
 #     for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True)
 # ]
 
-
 from tests.import_data import (
     data_sim_0cov,
     data_sim_2cov,
diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py
index 6a8ced3a..e5959b0e 100644
--- a/tests/test_pln_full.py
+++ b/tests/test_pln_full.py
@@ -7,8 +7,8 @@ from tests.utils import filter_models
 @pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_model"])
 @filter_models(["Pln"])
 def test_number_of_iterations_pln_full(fitted_pln):
-    nb_iterations = len(fitted_pln.elbos_list)
-    assert 20 < nb_iterations < 1000
+    nb_iterations = len(fitted_pln._elbos_list)
+    assert 20 < nb_iterations < 2000
 
 
 @pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_model"])
diff --git a/tests/test_plnpcacollection.py b/tests/test_plnpcacollection.py
index 77016b73..0d982d60 100644
--- a/tests/test_plnpcacollection.py
+++ b/tests/test_plnpcacollection.py
@@ -86,4 +86,4 @@ def test_batch(collection):
             assert mse_coef < 0.35
         assert mse_covariance < 0.25
     collection.fit()
-    assert collection.batch_size == collection.n_samples
+    assert collection.batch_size == 20
diff --git a/tests/test_setters.py b/tests/test_setters.py
index b3012548..f230d858 100644
--- a/tests/test_setters.py
+++ b/tests/test_setters.py
@@ -8,7 +8,7 @@ from tests.utils import MSE, filter_models
 single_models = ["Pln", "PlnPCA", "ZIPln"]
 
 
-@pytest.mark.parametrize("model", dict_fixtures["all_model"])
+@pytest.mark.parametrize("model", dict_fixtures["loaded_model"])
 def test_data_setter_with_torch(model):
     model.endog = model.endog
     model.exog = model.exog
-- 
GitLab
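
After this refactor the optimizer step is applied inside the minibatch loop and the
projection (`_project_parameters`) is applied once per pass over the data. A sketch of
one epoch with a hypothetical free-standing helper (the attribute names are the ones
used in the diff above):

    def train_one_epoch(model):
        # Hypothetical helper: one gradient step per minibatch, then a single projection.
        total_loss = 0.0
        for batch in model._get_batch(shuffle=False):
            model._extract_batch(batch)
            model.optim.zero_grad()
            loss = -model._compute_elbo_b()  # negative ELBO of the current minibatch
            loss.backward()
            total_loss += loss.item()
            model.optim.step()
        model._project_parameters()  # e.g. ZIPln clamps latent_prob into [0, 1] here
        return total_loss / model.nb_batches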


From 7f49f341e25b436984ef6d7133b6f647c686a8f5 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Wed, 18 Oct 2023 15:19:51 +0200
Subject: [PATCH 47/68] change gitignore

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 08cb3cfe..b2d69b27 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,6 +150,7 @@ test.py
 
 ## directories that outputs when running the tests
 tests/Pln*
+tests/ZIPln*
 slides/
 index.html
 
-- 
GitLab


From 4432bdfbc77cb687257acd34bd94e5b7f968eefc Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Wed, 18 Oct 2023 19:32:30 +0200
Subject: [PATCH 48/68] fixed some bugs

---
 pyPLNmodels/models.py | 31 ++++++++++++++++++++++---------
 tests/conftest.py     | 10 +++++-----
 2 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index d692a9e5..76f76fea 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -14,6 +14,7 @@ import plotly.express as px
 from mlxtend.plotting import plot_pca_correlation_graph
 import matplotlib
 from scipy import stats
+from statsmodels.discrete.count_model import ZeroInflatedPoisson
 
 from ._closed_forms import (
     _closed_formula_coef,
@@ -430,11 +431,11 @@ class _model(ABC):
         indices = np.arange(self.n_samples)
         if shuffle:
             np.random.shuffle(indices)
-
         for i in range(self._nb_full_batch):
-            yield self._return_batch(
+            batch = self._return_batch(
                 indices, i * self._batch_size, (i + 1) * self._batch_size
             )
+            yield batch
         # Last batch
         if self._last_batch_size != 0:
             yield self._return_batch(indices, -self._last_batch_size, self.n_samples)
@@ -475,7 +476,7 @@ class _model(ABC):
             The loss value.
         """
         elbo = 0
-        for batch in self._get_batch(shuffle=True):
+        for batch in self._get_batch(shuffle=False):
             self._extract_batch(batch)
             self.optim.zero_grad()
             loss = -self._compute_elbo_b()
@@ -1005,6 +1006,13 @@ class _model(ABC):
         os.makedirs(path, exist_ok=True)
         for key, value in self._dict_parameters.items():
             filename = f"{path}/{key}.csv"
+            if key == "latent_prob":
+                if torch.max(value) > 1 or torch.min(value) < 0:
+                    if (
+                        torch.norm(self.dirac * self.latent_prob - self.latent_prob)
+                        > 0.0001
+                    ):
+                        raise Exception("latent_prob is out of [0, 1] and non-zero where endog is non-zero")
             if isinstance(value, torch.Tensor):
                 pd.DataFrame(np.array(value.cpu().detach())).to_csv(
                     filename, header=None, index=None
@@ -3465,6 +3473,9 @@ class ZIPln(_model):
         to_take = torch.tensor(indices[beginning:end]).to(DEVICE)
         batch = pln_batch + (torch.index_select(self._dirac, 0, to_take),)
         if self._use_closed_form_prob is False:
+            to_return = torch.index_select(self._latent_prob, 0, to_take)
+            print("max latent_prbo", torch.max(self._latent_prob))
+            print("max to return", torch.max(to_return))
             return batch + (torch.index_select(self._latent_prob, 0, to_take),)
         return batch
 
@@ -3587,6 +3598,12 @@ class ZIPln(_model):
             self._components = _init_components(self._endog, self.dim)
 
         if not hasattr(self, "_coef_inflation"):
+            # print('shape', self.exog.shape[1])
+            # for j in range(self.exog.shape[1]):
+            #     Y_j = self._endog[:,j].numpy()
+            #     offsets_j = self.offsets[:,j].numpy()
+            #     zip_training_results = ZeroInflatedPoisson(endog=Y_j,exog = self.exog.numpy(), exog_infl = self.exog.numpy(), inflation='logit', offsets = offsets_j).fit()
+            #     print('params', zip_training_results.params)
             self._coef_inflation = torch.randn(self.nb_cov, self.dim)
 
     def _random_init_latent_parameters(self):
@@ -3740,12 +3757,8 @@ class ZIPln(_model):
         """
         if self._use_closed_form_prob is False:
             with torch.no_grad():
-                torch.maximum(
-                    self._latent_prob, torch.tensor([0]), out=self._latent_prob
-                )
-                torch.minimum(
-                    self._latent_prob, torch.tensor([1]), out=self._latent_prob
-                )
+                self._latent_prob = torch.maximum(self._latent_prob, torch.tensor([0]))
+                self._latent_prob = torch.minimum(self._latent_prob, torch.tensor([1]))
                 self._latent_prob *= self._dirac
 
     @property
diff --git a/tests/conftest.py b/tests/conftest.py
index b77a5d67..d89a919a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,4 @@
 import sys
-import glob
 import pytest
 import torch
 from pytest_lazyfixture import lazy_fixture as lf
@@ -12,10 +11,6 @@ from pyPLNmodels import get_simulated_count_data
 
 sys.path.append("../")
 
-# pytest_plugins = [
-#     fixture_file.replace("/", ".").replace(".py", "")
-#     for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True)
-# ]
 
 from tests.import_data import (
     data_sim_0cov,
@@ -42,6 +37,11 @@ def add_fixture_to_dict(my_dict, string_fixture):
     return my_dict
 
 
+# zi = ZIPln(endog_sim_2cov, exog = exog_sim_2cov)
+# zi.fit()
+# print(zi)
+
+
 def add_list_of_fixture_to_dict(
     my_dict, name_of_list_of_fixtures, list_of_string_fixtures
 ):
-- 
GitLab
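
The generator modified above yields the full minibatches first and then one last,
smaller batch containing the leftover samples. A self-contained sketch of that
indexing logic (a hypothetical helper, not the package's method):

    import numpy as np

    def iter_batch_indices(n_samples, batch_size, shuffle=False):
        # Yield indices of the full minibatches, then one final smaller batch.
        indices = np.arange(n_samples)
        if shuffle:
            np.random.shuffle(indices)
        nb_full_batch = n_samples // batch_size
        last_batch_size = n_samples % batch_size
        for i in range(nb_full_batch):
            yield indices[i * batch_size : (i + 1) * batch_size]
        if last_batch_size != 0:
            yield indices[-last_batch_size:]

    # With n_samples=50 and batch_size=20 this yields batches of sizes 20, 20 and 10.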


From 8f2d824d7ac833f5a0020a7272acbcba1ae3c48c Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 08:21:21 +0200
Subject: [PATCH 49/68] add init of coef_inflation, though it is not used yet

---
 pyPLNmodels/models.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 76f76fea..49c590d1 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -14,7 +14,6 @@ import plotly.express as px
 from mlxtend.plotting import plot_pca_correlation_graph
 import matplotlib
 from scipy import stats
-from statsmodels.discrete.count_model import ZeroInflatedPoisson
 
 from ._closed_forms import (
     _closed_formula_coef,
@@ -3474,8 +3473,6 @@ class ZIPln(_model):
         batch = pln_batch + (torch.index_select(self._dirac, 0, to_take),)
         if self._use_closed_form_prob is False:
             to_return = torch.index_select(self._latent_prob, 0, to_take)
-            print("max latent_prbo", torch.max(self._latent_prob))
-            print("max to return", torch.max(to_return))
             return batch + (torch.index_select(self._latent_prob, 0, to_take),)
         return batch
 
@@ -3598,13 +3595,14 @@ class ZIPln(_model):
             self._components = _init_components(self._endog, self.dim)
 
         if not hasattr(self, "_coef_inflation"):
-            # print('shape', self.exog.shape[1])
+            self._coef_inflation = torch.randn(self.nb_cov, self.dim)
             # for j in range(self.exog.shape[1]):
             #     Y_j = self._endog[:,j].numpy()
             #     offsets_j = self.offsets[:,j].numpy()
-            #     zip_training_results = ZeroInflatedPoisson(endog=Y_j,exog = self.exog.numpy(), exog_infl = self.exog.numpy(), inflation='logit', offsets = offsets_j).fit()
-            #     print('params', zip_training_results.params)
-            self._coef_inflation = torch.randn(self.nb_cov, self.dim)
+            #     exog = self.exog[:,j].unsqueeze(1).numpy()
+            #     undzi = ZeroInflatedPoisson(endog=Y_j,exog = exog, exog_infl = exog, inflation='logit', offset = offsets_j)
+            #     zip_training_results = undzi.fit()
+            #     self._coef_inflation[:,j] = zip_training_results.params[1]
 
     def _random_init_latent_parameters(self):
         self._latent_mean = torch.randn(self.n_samples, self.dim)
-- 
GitLab
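
The commented-out block above sketches a warm start of `coef_inflation` from
per-column zero-inflated Poisson fits. A cleaned-up, hedged version of that idea
(a hypothetical helper, not wired into the package; the ordering of the statsmodels
result parameters should be double-checked before relying on it):

    import torch
    from statsmodels.discrete.count_model import ZeroInflatedPoisson

    def init_coef_inflation(endog, exog, offsets):
        # Fit one zero-inflated Poisson regression per column of endog and use the
        # logit (inflation) coefficients to initialize coef_inflation.
        nb_cov, dim = exog.shape[1], endog.shape[1]
        coef_inflation = torch.zeros(nb_cov, dim)
        for j in range(dim):
            zip_model = ZeroInflatedPoisson(
                endog=endog[:, j].numpy(),
                exog=exog.numpy(),
                exog_infl=exog.numpy(),
                inflation="logit",
                offset=offsets[:, j].numpy(),
            )
            results = zip_model.fit(disp=0)
            # statsmodels reports the inflation ("inflate_*") coefficients first.
            coef_inflation[:, j] = torch.tensor(results.params[:nb_cov], dtype=torch.float32)
        return coef_inflation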


From 2ef50c63781fa6cd87d9187e872257d1c55f475a Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 09:17:14 +0200
Subject: [PATCH 50/68] add a checker for the getting started file.

---
 .gitignore                                    |  7 ++---
 ...e_getting_started_and_docstrings_tests.py} | 26 +++++++++++++++----
 2 files changed, 25 insertions(+), 8 deletions(-)
 rename tests/{create_readme_and_docstrings_tests.py => create_readme_getting_started_and_docstrings_tests.py} (75%)

diff --git a/.gitignore b/.gitignore
index b2d69b27..85c56ae8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -157,8 +157,9 @@ index.html
 paper/*
 
 
-tests/test_models*
-tests/test_load*
-tests/test_readme*
+tests/docstrings_examples/*
+tests/getting_started/*
+tests/readme_examples/*
+tests/test_getting_started.py
 Getting_started.py
 new_model.py
diff --git a/tests/create_readme_and_docstrings_tests.py b/tests/create_readme_getting_started_and_docstrings_tests.py
similarity index 75%
rename from tests/create_readme_and_docstrings_tests.py
rename to tests/create_readme_getting_started_and_docstrings_tests.py
index 63aecf9d..113bf841 100644
--- a/tests/create_readme_and_docstrings_tests.py
+++ b/tests/create_readme_getting_started_and_docstrings_tests.py
@@ -4,6 +4,7 @@ import os
 
 dir_docstrings = "docstrings_examples"
 dir_readme = "readme_examples"
+dir_getting_started = "getting_started"
 
 
 def get_lines(path_to_file, filename, filetype=".py"):
@@ -47,11 +48,11 @@ def get_example_readme(lines):
     return [example]
 
 
-def write_examples(examples, filename):
+def write_file(examples, filename, string_definer, dir):
     for i in range(len(examples)):
         example = examples[i]
         nb_example = str(i + 1)
-        example_filename = f"test_{filename}_example_{nb_example}.py"
+        example_filename = f"{dir}/test_{filename}_{string_definer}_{nb_example}.py"
         try:
             os.remove(example_filename)
         except FileNotFoundError:
@@ -64,19 +65,34 @@ def write_examples(examples, filename):
 def filename_to_docstring_example_file(filename, dirname):
     lines = get_lines("../pyPLNmodels/", filename)
     examples = get_examples_docstring(lines)
-    write_examples(examples, filename)
+    write_file(examples, filename, "example", dir=dirname)
 
 
 def filename_to_readme_example_file():
     lines = get_lines("../", "README", filetype=".md")
     examples = get_example_readme(lines)
-    write_examples(examples, "readme")
+    write_file(examples, "readme", "example", dir=dir_readme)
+
+
+lines_getting_started = get_lines("./", "test_getting_started")
+new_lines = []
+for line in lines_getting_started:
+    if len(line) > 20:
+        if line[0:11] != "get_ipython":
+            new_lines.append(line)
+    else:
+        new_lines.append(line)
 
 
 os.makedirs(dir_readme, exist_ok=True)
+os.makedirs(dir_docstrings, exist_ok=True)
+os.makedirs(dir_getting_started, exist_ok=True)
+
+write_file([new_lines], "getting_started", "", dir_getting_started)
+
 filename_to_readme_example_file()
 
-os.makedirs("docstrings_examples", exist_ok=True)
+
 filename_to_docstring_example_file("_utils", dir_docstrings)
 filename_to_docstring_example_file("models", dir_docstrings)
 filename_to_docstring_example_file("elbos", dir_docstrings)
-- 
GitLab
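
The script above filters the nbconvert output because get_ipython().system(...) calls only work inside IPython. The same filter, with the original 20-character threshold, could be written as the small helper sketched below (the helper name is illustrative).

# Sketch of the get_ipython filter used above, as a reusable helper.
def strip_ipython_lines(lines, min_len=20):
    """Keep short lines as-is and drop long lines that call get_ipython()."""
    kept = []
    for line in lines:
        if len(line) > min_len and line.startswith("get_ipython"):
            continue
        kept.append(line)
    return kept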


From d6d539c3825a91c00965e5f312df58cdb580d96b Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 09:21:57 +0200
Subject: [PATCH 51/68] change the cd commands: go into each directory to check
 the readme examples, the docstring examples, etc.

---
 .gitlab-ci.yml | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index c95c78f1..bf2bf333 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -20,8 +20,16 @@ tests:
     pip install '.[tests]'
   script:
     - pip install .
+    - jupyter nbconvert Getting_started.ipynb --to python --output tests/test_getting_started
     - cd tests
-    - python create_readme_and_docstrings_tests.py
+    - python create_readme_getting_started_and_docstrings_tests.py
+    - cd readme_examples
+    - pytest .
+    - cd ../readme_examples
+    - pytest .
+    - cd ../getting_started
+    - pytest .
+    - cd ..
     - pytest .
   only:
     - main
-- 
GitLab
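
The new CI step converts Getting_started.ipynb to a plain Python file with the jupyter nbconvert CLI before the test-generation script runs. The same conversion can be done from Python; a minimal sketch using nbconvert's exporter API is below (the paths are copied from the CI line, everything else is an assumption).

# Hedged sketch: programmatic equivalent of the nbconvert CLI step above.
from nbconvert import PythonExporter

body, _resources = PythonExporter().from_filename("Getting_started.ipynb")
with open("tests/test_getting_started.py", "w", encoding="utf-8") as handle:
    handle.write(body)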


From 768decb51a600b16b0475ff3fda137792c159901 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 10:44:14 +0200
Subject: [PATCH 52/68] add my own image.

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index bf2bf333..108e940b 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -15,7 +15,7 @@ black:
 
 tests:
   stage: checks
-  image: "registry.forgemia.inra.fr/jbleger/docker-image-pandas-torch-sphinx:master"
+  image: "registry.forgemia.inra.fr/bbatardiere/jbleger:main"
   before_script:
     pip install '.[tests]'
   script:
-- 
GitLab


From 8d90591bb7220293af4732e068506275479e413d Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 10:46:02 +0200
Subject: [PATCH 53/68] fix the name of the image.

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 108e940b..9c0bda6f 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -15,7 +15,7 @@ black:
 
 tests:
   stage: checks
-  image: "registry.forgemia.inra.fr/bbatardiere/jbleger:main"
+  image: "registry.forgemia.inra.fr/bbatardiere/docker-image-pandas-torch-sphinx-jupyter:main"
   before_script:
     pip install '.[tests]'
   script:
-- 
GitLab


From 076746e2434bb646e3428d35873c773b286bf2c3 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 10:50:14 +0200
Subject: [PATCH 54/68] retry the CI

---
 .gitignore | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 85c56ae8..c5443b2c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,6 +160,6 @@ paper/*
 tests/docstrings_examples/*
 tests/getting_started/*
 tests/readme_examples/*
-tests/test_getting_started.py
+# tests/test_getting_started.py
 Getting_started.py
 new_model.py
-- 
GitLab


From 92fc2ef1215d76238d4882b28be978de8019ffb2 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 10:52:38 +0200
Subject: [PATCH 55/68] change the image name.

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 9c0bda6f..bba6aff6 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -15,7 +15,7 @@ black:
 
 tests:
   stage: checks
-  image: "registry.forgemia.inra.fr/bbatardiere/docker-image-pandas-torch-sphinx-jupyter:main"
+  image: "registry.forgemia.inra.fr/bbatardiere/docker-image-pandas-torch-sphinx-jupyter"
   before_script:
     pip install '.[tests]'
   script:
-- 
GitLab


From 92563e1861d1b35d0bb4e70f81ddaae9f922bce2 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 13:51:53 +0200
Subject: [PATCH 56/68] change the GitLab CI and the way the examples are run,
 since pytest was not actually testing them.

---
 .gitlab-ci.yml                                |   8 +-
 .../run_readme_docstrings_getting_started.sh  |  15 ++
 tests/test_getting_started.py                 | 131 ++++++++++++++----
 3 files changed, 119 insertions(+), 35 deletions(-)
 create mode 100755 tests/run_readme_docstrings_getting_started.sh

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index bba6aff6..ba97ffea 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -23,13 +23,7 @@ tests:
     - jupyter nbconvert Getting_started.ipynb --to python --output tests/test_getting_started
     - cd tests
     - python create_readme_getting_started_and_docstrings_tests.py
-    - cd readme_examples
-    - pytest .
-    - cd ../readme_examples
-    - pytest .
-    - cd ../getting_started
-    - pytest .
-    - cd ..
+    - ./run_readme_docstrings_getting_started.sh
     - pytest .
   only:
     - main
diff --git a/tests/run_readme_docstrings_getting_started.sh b/tests/run_readme_docstrings_getting_started.sh
new file mode 100755
index 00000000..59489033
--- /dev/null
+++ b/tests/run_readme_docstrings_getting_started.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+for file in docstrings_examples/*
+do
+    python $file
+done
+
+for file in readme_examples/*
+do
+    python $file
+done
+
+for file in getting_started/*
+do
+    python $file
+done
diff --git a/tests/test_getting_started.py b/tests/test_getting_started.py
index 69299741..605132a0 100644
--- a/tests/test_getting_started.py
+++ b/tests/test_getting_started.py
@@ -1,49 +1,63 @@
 #!/usr/bin/env python
 # coding: utf-8
 
-# get_ipython().system('pip install pyPLNmodels')
+# In[1]:
+
+
+get_ipython().system('pip install pyPLNmodels')
 
 
 # ## pyPLNmodels
 
 # We assume the data comes from a PLN model:  $ \text{counts} \sim  \mathcal P(\exp(\text{Z}))$, where $Z$ are some unknown latent variables.
-#
-#
-# The goal of the package is to retrieve the latent variables $Z$ given the counts. To do so, one can instantiate a Pln or PlnPCA model, fit it and then extract the latent variables.
+# 
+# 
+# The goal of the package is to retrieve the latent variables $Z$ given the counts. To do so, one can instantiate a Pln or PlnPCA model, fit it and then extract the latent variables.  
 
 # ### Import the needed functions
 
-from pyPLNmodels import (
-    get_real_count_data,
-    get_simulated_count_data,
-    load_model,
-    Pln,
-    PlnPCA,
-    PlnPCAcollection,
-)
+# In[2]:
+
+
+from pyPLNmodels import get_real_count_data, get_simulated_count_data, load_model, Pln, PlnPCA, PlnPCAcollection
 import matplotlib.pyplot as plt
 
 
 # ### Load the data
 
-counts, labels = get_real_count_data(return_labels=True)  # np.ndarray
+# In[3]:
+
+
+counts, labels  = get_real_count_data(return_labels=True) # np.ndarray
 
 
 # ### PLN model
 
-pln = Pln(counts, add_const=True)
+# In[4]:
+
+
+pln = Pln(counts, add_const = True)
 pln.fit()
 
 
+# In[5]:
+
+
 print(pln)
 
 
 # #### Once fitted, we can extract multiple variables:
 
+# In[6]:
+
+
 gaussian = pln.latent_variables
 print(gaussian.shape)
 
 
+# In[7]:
+
+
 model_param = pln.model_parameters
 print(model_param["coef"].shape)
 print(model_param["covariance"].shape)
@@ -51,37 +65,61 @@ print(model_param["covariance"].shape)
 
 # ### PlnPCA model
 
-pca = PlnPCA(counts, add_const=True, rank=5)
+# In[8]:
+
+
+pca = PlnPCA(counts, add_const = True, rank = 5)
 pca.fit()
 
 
+# In[9]:
+
+
 print(pca)
 
 
+# In[10]:
+
+
 print(pca.latent_variables.shape)
 
 
+# In[11]:
+
+
 print(pca.model_parameters["components"].shape)
 print(pca.model_parameters["coef"].shape)
 
 
 # ### One can save the model in order to load it back after:
 
+# In[13]:
+
+
 pca.save()
 dict_init = load_model("PlnPCA_nbcov_1_dim_200_rank_5")
-loaded_pca = PlnPCA(counts, add_const=True, dict_initialization=dict_init)
+loaded_pca = PlnPCA(counts, add_const = True, dict_initialization=  dict_init)
 print(loaded_pca)
 
 
 # ### One can fit multiple PCA and choose the best rank with BIC or AIC criterion
 
-pca_col = PlnPCAcollection(counts, add_const=True, ranks=[5, 15, 25, 40, 50])
+# In[14]:
+
+
+pca_col = PlnPCAcollection(counts, add_const = True, ranks = [5,15,25,40,50])
 pca_col.fit()
 
 
+# In[15]:
+
+
 pca_col.show()
 
 
+# In[16]:
+
+
 print(pca_col)
 
 
@@ -89,53 +127,90 @@ print(pca_col)
 
 # #### AIC best model
 
-print(pca_col.best_model(criterion="AIC"))
+# In[17]:
+
+
+print(pca_col.best_model(criterion = "AIC"))
 
 
 # #### BIC best model
 
-print(pca_col.best_model(criterion="BIC"))
+# In[18]:
+
+
+print(pca_col.best_model(criterion = "BIC"))
 
 
 # #### Visualization of the individuals (sites) with PCA on the latent variables.
 
+# In[19]:
+
+
 pln.viz(colors=labels)
 plt.show()
 
 
+# In[20]:
+
+
 best_pca = pca_col.best_model()
-best_pca.viz(colors=labels)
+best_pca.viz(colors = labels)
 plt.show()
 
 
-# ### What would give a PCA on the log normalize data ?
+# ### What would give a PCA on the log normalize data ? 
+
+# In[21]:
+
 
 from sklearn.decomposition import PCA
 import numpy as np
 import seaborn as sns
 
 
-sk_pca = PCA(n_components=2)
+# In[22]:
+
+
+sk_pca = PCA(n_components = 2)
 pca_log_counts = sk_pca.fit_transform(np.log(counts + (counts == 0)))
-sns.scatterplot(x=pca_log_counts[:, 0], y=pca_log_counts[:, 1], hue=labels)
+sns.scatterplot(x = pca_log_counts[:,0], y = pca_log_counts[:,1], hue = labels)
 
 
 # ### Visualization of the variables
 
-pln.plot_pca_correlation_graph(["var_1", "var_2"], indices_of_variables=[0, 1])
+# In[23]:
+
+
+pln.plot_pca_correlation_graph(["var_1","var_2"], indices_of_variables = [0,1])
 plt.show()
 
 
-best_pca.plot_pca_correlation_graph(["var_1", "var_2"], indices_of_variables=[0, 1])
+# In[24]:
+
+
+best_pca.plot_pca_correlation_graph(["var_1","var_2"], indices_of_variables = [0,1])
 plt.show()
 
 
 # ### Visualization of each components of the PCA
-#
+# 
 
-pln.scatter_pca_matrix(color=labels, n_components=5)
+# In[25]:
+
+
+pln.scatter_pca_matrix(color = labels, n_components = 5)
 plt.show()
 
 
-best_pca.scatter_pca_matrix(color=labels, n_components=6)
+# In[26]:
+
+
+best_pca.scatter_pca_matrix(color = labels, n_components = 6)
 plt.show()
+
+
+# In[ ]:
+
+
+
+
-- 
GitLab
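
run_readme_docstrings_getting_started.sh above simply runs every generated example file with python. A pure-Python sketch of the same loop is shown below; the directory names are the ones created by the generation script, and unlike the shell version this sketch stops at the first failing example, which is a deliberate difference.

# Hedged Python sketch of run_readme_docstrings_getting_started.sh.
import glob
import subprocess
import sys

for directory in ("docstrings_examples", "readme_examples", "getting_started"):
    for path in sorted(glob.glob(f"{directory}/*.py")):
        print(f"running {path}")
        completed = subprocess.run([sys.executable, path])
        if completed.returncode != 0:
            sys.exit(completed.returncode)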


From 5ab9fc33eda8ed4d06b40dd933681eaafdf917db Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 13:54:16 +0200
Subject: [PATCH 57/68] remove useless file since it is going to be removed

---
 tests/test_getting_started.py | 216 ----------------------------------
 1 file changed, 216 deletions(-)
 delete mode 100644 tests/test_getting_started.py

diff --git a/tests/test_getting_started.py b/tests/test_getting_started.py
deleted file mode 100644
index 605132a0..00000000
--- a/tests/test_getting_started.py
+++ /dev/null
@@ -1,216 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# In[1]:
-
-
-get_ipython().system('pip install pyPLNmodels')
-
-
-# ## pyPLNmodels
-
-# We assume the data comes from a PLN model:  $ \text{counts} \sim  \mathcal P(\exp(\text{Z}))$, where $Z$ are some unknown latent variables.
-# 
-# 
-# The goal of the package is to retrieve the latent variables $Z$ given the counts. To do so, one can instantiate a Pln or PlnPCA model, fit it and then extract the latent variables.  
-
-# ### Import the needed functions
-
-# In[2]:
-
-
-from pyPLNmodels import get_real_count_data, get_simulated_count_data, load_model, Pln, PlnPCA, PlnPCAcollection
-import matplotlib.pyplot as plt
-
-
-# ### Load the data
-
-# In[3]:
-
-
-counts, labels  = get_real_count_data(return_labels=True) # np.ndarray
-
-
-# ### PLN model
-
-# In[4]:
-
-
-pln = Pln(counts, add_const = True)
-pln.fit()
-
-
-# In[5]:
-
-
-print(pln)
-
-
-# #### Once fitted, we can extract multiple variables:
-
-# In[6]:
-
-
-gaussian = pln.latent_variables
-print(gaussian.shape)
-
-
-# In[7]:
-
-
-model_param = pln.model_parameters
-print(model_param["coef"].shape)
-print(model_param["covariance"].shape)
-
-
-# ### PlnPCA model
-
-# In[8]:
-
-
-pca = PlnPCA(counts, add_const = True, rank = 5)
-pca.fit()
-
-
-# In[9]:
-
-
-print(pca)
-
-
-# In[10]:
-
-
-print(pca.latent_variables.shape)
-
-
-# In[11]:
-
-
-print(pca.model_parameters["components"].shape)
-print(pca.model_parameters["coef"].shape)
-
-
-# ### One can save the model in order to load it back after:
-
-# In[13]:
-
-
-pca.save()
-dict_init = load_model("PlnPCA_nbcov_1_dim_200_rank_5")
-loaded_pca = PlnPCA(counts, add_const = True, dict_initialization=  dict_init)
-print(loaded_pca)
-
-
-# ### One can fit multiple PCA and choose the best rank with BIC or AIC criterion
-
-# In[14]:
-
-
-pca_col = PlnPCAcollection(counts, add_const = True, ranks = [5,15,25,40,50])
-pca_col.fit()
-
-
-# In[15]:
-
-
-pca_col.show()
-
-
-# In[16]:
-
-
-print(pca_col)
-
-
-# ### One can extract the best model found (according to AIC or BIC criterion).
-
-# #### AIC best model
-
-# In[17]:
-
-
-print(pca_col.best_model(criterion = "AIC"))
-
-
-# #### BIC best model
-
-# In[18]:
-
-
-print(pca_col.best_model(criterion = "BIC"))
-
-
-# #### Visualization of the individuals (sites) with PCA on the latent variables.
-
-# In[19]:
-
-
-pln.viz(colors=labels)
-plt.show()
-
-
-# In[20]:
-
-
-best_pca = pca_col.best_model()
-best_pca.viz(colors = labels)
-plt.show()
-
-
-# ### What would give a PCA on the log normalize data ? 
-
-# In[21]:
-
-
-from sklearn.decomposition import PCA
-import numpy as np
-import seaborn as sns
-
-
-# In[22]:
-
-
-sk_pca = PCA(n_components = 2)
-pca_log_counts = sk_pca.fit_transform(np.log(counts + (counts == 0)))
-sns.scatterplot(x = pca_log_counts[:,0], y = pca_log_counts[:,1], hue = labels)
-
-
-# ### Visualization of the variables
-
-# In[23]:
-
-
-pln.plot_pca_correlation_graph(["var_1","var_2"], indices_of_variables = [0,1])
-plt.show()
-
-
-# In[24]:
-
-
-best_pca.plot_pca_correlation_graph(["var_1","var_2"], indices_of_variables = [0,1])
-plt.show()
-
-
-# ### Visualization of each components of the PCA
-# 
-
-# In[25]:
-
-
-pln.scatter_pca_matrix(color = labels, n_components = 5)
-plt.show()
-
-
-# In[26]:
-
-
-best_pca.scatter_pca_matrix(color = labels, n_components = 6)
-plt.show()
-
-
-# In[ ]:
-
-
-
-
-- 
GitLab


From eac1b642012ac2cc3f2417594f9595f18f9b4b47 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 14:04:15 +0200
Subject: [PATCH 58/68] remove a file in the CI

---
 .gitlab-ci.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index ba97ffea..61e51176 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -23,6 +23,7 @@ tests:
     - jupyter nbconvert Getting_started.ipynb --to python --output tests/test_getting_started
     - cd tests
     - python create_readme_getting_started_and_docstrings_tests.py
+    - rm test_getting_started.py
     - ./run_readme_docstrings_getting_started.sh
     - pytest .
   only:
-- 
GitLab


From abaa5d83c52fadd6480f6df4351aee8edf399b79 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 14:23:38 +0200
Subject: [PATCH 59/68] GPU support for the from_formula path.

---
 pyPLNmodels/_utils.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 1082259e..05d13b00 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -929,6 +929,10 @@ def _extract_data_from_formula(
         A tuple containing the extracted endog, exog, and offsets.
 
     """
+    # dmatrices can not deal with GPU matrices
+    for key,matrix in data.items():
+        if isinstance(matrix, torch.Tensor):
+            data[key] = matrix.cpu()
     dmatrix = dmatrices(formula, data=data)
     endog = dmatrix[0]
     exog = dmatrix[1]
-- 
GitLab
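
The fix above exists because patsy's dmatrices cannot consume CUDA tensors, so every torch tensor in the formula data is moved to the CPU first. The same guard, isolated as a helper, is sketched below (the function name is illustrative, not part of the patch).

# Hedged sketch of the CPU guard added above for patsy's dmatrices.
import torch
from patsy import dmatrices

def dmatrices_cpu(formula, data):
    """Move torch tensors to CPU before handing the data dict to patsy."""
    data = {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in data.items()
    }
    return dmatrices(formula, data=data)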


From 97a3e544d988c14833615a3b4bb474e1a61718b2 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 15:40:41 +0200
Subject: [PATCH 60/68] add GPU support.

---
 pyPLNmodels/models.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 49c590d1..7dfe71a3 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -231,7 +231,7 @@ class _model(ABC):
         if self._get_max_components() < 2:
             raise RuntimeError("Can't perform visualization for dim < 2.")
         pca = self.sk_PCA(n_components=2)
-        proj_variables = pca.transform(self.latent_variables.detach().cpu())
+        proj_variables = pca.transform(self.latent_variables)
         x = proj_variables[:, 0]
         y = proj_variables[:, 1]
         sns.scatterplot(x=x, y=y, hue=colors, ax=ax)
@@ -1286,7 +1286,7 @@ class _model(ABC):
             raise RuntimeError("Please fit the model before.")
         if ax is None:
             ax = plt.gca()
-        predictions = self._endog_predictions().ravel().detach()
+        predictions = self._endog_predictions().ravel().cpu().detach()
         if colors is not None:
             colors = np.repeat(np.array(colors), repeats=self.dim).ravel()
         sns.scatterplot(x=self.endog.ravel(), y=predictions, hue=colors, ax=ax)
@@ -3284,7 +3284,7 @@ class PlnPCA(_model):
         """,
     )
     def latent_variables(self) -> torch.Tensor:
-        return torch.matmul(self._latent_mean, self._components.T).detach()
+        return torch.matmul(self.latent_mean, self.components.T)
 
     @_add_doc(
         _model,
-- 
GitLab
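
This patch and the following ones apply the same rule over and over: tensors that live on the GPU, or that still carry gradients, must be detached and moved to the CPU before they are handed to numpy-based libraries such as seaborn or scikit-learn. A generic helper expressing that rule is sketched below; it illustrates the convention and is not code from the patch.

# Hedged sketch: make a tensor safe for numpy / seaborn / scikit-learn.
import numpy as np
import torch

def to_numpy(value):
    """Detach, move to CPU and convert to a NumPy array when needed."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)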


From 55e215d1591c3d9e88231782835422c228dd6d6c Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 19 Oct 2023 15:41:16 +0200
Subject: [PATCH 61/68] remove the shell script call, since otherwise the
 example tests are run twice.

---
 .gitlab-ci.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 61e51176..e45adf60 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -24,7 +24,6 @@ tests:
     - cd tests
     - python create_readme_getting_started_and_docstrings_tests.py
     - rm test_getting_started.py
-    - ./run_readme_docstrings_getting_started.sh
     - pytest .
   only:
     - main
-- 
GitLab


From 917a67eea915ab9f67463fb3321846982f06552c Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 20 Oct 2023 12:36:07 +0200
Subject: [PATCH 62/68] fix GPU support.

---
 pyPLNmodels/models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 7dfe71a3..ce920b97 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -236,7 +236,7 @@ class _model(ABC):
         y = proj_variables[:, 1]
         sns.scatterplot(x=x, y=y, hue=colors, ax=ax)
         if show_cov is True:
-            sk_components = torch.from_numpy(pca.components_)
+            sk_components = torch.from_numpy(pca.components_).to(DEVICE)
             covariances = self._get_pca_low_dim_covariances(sk_components).detach()
             for i in range(covariances.shape[0]):
                 _plot_ellipse(x[i], y[i], cov=covariances[i], ax=ax)
@@ -3008,7 +3008,7 @@ class PlnPCA(_model):
         else:
             XB = 0
         return torch.exp(
-            self._offsets + XB + self.latent_variables + 1 / 2 * covariance_a_posteriori
+            self._offsets + XB + self.latent_variables.to(DEVICE) + 1 / 2 * covariance_a_posteriori
         )
 
     @latent_mean.setter
-- 
GitLab


From 5273192e2add0ea47865b77d12b133cc4c47d41b Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Fri, 20 Oct 2023 12:48:10 +0200
Subject: [PATCH 63/68] fix one more GPU support bug

---
 pyPLNmodels/models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index ce920b97..ac2af805 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -232,8 +232,8 @@ class _model(ABC):
             raise RuntimeError("Can't perform visualization for dim < 2.")
         pca = self.sk_PCA(n_components=2)
         proj_variables = pca.transform(self.latent_variables)
-        x = proj_variables[:, 0]
-        y = proj_variables[:, 1]
+        x = proj_variables[:, 0].cpu()
+        y = proj_variables[:, 1].cpu()
         sns.scatterplot(x=x, y=y, hue=colors, ax=ax)
         if show_cov is True:
             sk_components = torch.from_numpy(pca.components_).to(DEVICE)
-- 
GitLab


From e2c19d64bfe30005a3373fc51dc6487ecfd0db99 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Sat, 21 Oct 2023 11:51:49 +0200
Subject: [PATCH 64/68] add the doc for ZIPln

---
 docs/source/index.rst |  1 +
 docs/source/zipln.rst | 10 ++++++++++
 2 files changed, 11 insertions(+)
 create mode 100644 docs/source/zipln.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 98f3e0a6..da418320 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -16,6 +16,7 @@ API documentation
    ./plnpcacollection.rst
    ./plnpca.rst
    ./pln.rst
+   ./zipln.rst
 
 .. toctree::
    :maxdepth: 1
diff --git a/docs/source/zipln.rst b/docs/source/zipln.rst
new file mode 100644
index 00000000..ae0e1e81
--- /dev/null
+++ b/docs/source/zipln.rst
@@ -0,0 +1,10 @@
+
+ZIPln
+===
+
+.. autoclass:: pyPLNmodels.ZIPln
+   :members:
+   :inherited-members:
+   :special-members: __init__
+   :undoc-members:
+   :show-inheritance:
-- 
GitLab


From 6fa0871612a28be9f6627f379c7d3e26eadf9290 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Sat, 21 Oct 2023 11:58:57 +0200
Subject: [PATCH 65/68] add GPU support

---
 pyPLNmodels/models.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index ac2af805..1249f14b 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -232,12 +232,12 @@ class _model(ABC):
             raise RuntimeError("Can't perform visualization for dim < 2.")
         pca = self.sk_PCA(n_components=2)
         proj_variables = pca.transform(self.latent_variables)
-        x = proj_variables[:, 0].cpu()
-        y = proj_variables[:, 1].cpu()
+        x = proj_variables[:, 0]
+        y = proj_variables[:, 1]
         sns.scatterplot(x=x, y=y, hue=colors, ax=ax)
         if show_cov is True:
             sk_components = torch.from_numpy(pca.components_).to(DEVICE)
-            covariances = self._get_pca_low_dim_covariances(sk_components).detach()
+            covariances = self._get_pca_low_dim_covariances(sk_components).cpu().detach()
             for i in range(covariances.shape[0]):
                 _plot_ellipse(x[i], y[i], cov=covariances[i], ax=ax)
         return ax
-- 
GitLab
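
Patches 62 to 65 move the .cpu() calls around in the viz path: the sklearn PCA projection already returns NumPy arrays on the CPU, while the low-dimensional covariances are computed with torch on DEVICE and must come back to the CPU before plotting. A self-contained illustration of that two-way traffic, with a random stand-in for the latent variables, is sketched below.

# Hedged, self-contained illustration of the device handling in the viz path.
import seaborn as sns
import torch
from sklearn.decomposition import PCA

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
latent = torch.randn(100, 10, device=DEVICE)  # stand-in for latent_variables
pca = PCA(n_components=2)
proj = pca.fit_transform(latent.detach().cpu().numpy())  # sklearn needs CPU NumPy
sns.scatterplot(x=proj[:, 0], y=proj[:, 1])
# torch-side computations (e.g. the low-dimensional covariances) go back to DEVICE:
sk_components = torch.from_numpy(pca.components_).to(DEVICE)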


From 33baf5979716ee9b6af100ea3f585f95cb0f6aab Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Sat, 21 Oct 2023 12:13:29 +0200
Subject: [PATCH 66/68] fix GPU support.

---
 pyPLNmodels/_utils.py |  3 +--
 pyPLNmodels/models.py | 31 ++++++++++---------------------
 2 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 05d13b00..b8fa8001 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -199,9 +199,8 @@ def _log_stirling(integer: torch.Tensor) -> torch.Tensor:
         integer_ / math.exp(1)
     )
 
-
 def _trunc_log(tens: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
-    integer = torch.min(torch.max(tens, torch.tensor([eps])), torch.tensor([1 - eps]))
+    integer = torch.min(torch.max(tens, torch.tensor([eps]).to(DEVICE)), torch.tensor([1 - eps]).to(DEVICE))
     return torch.log(integer)
 
 
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 1249f14b..b591a7a8 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -968,7 +968,7 @@ class _model(ABC):
             raise ValueError(
                 f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_mean.shape}"
             )
-        self._latent_mean = latent_mean
+        self._latent_mean = latent_mean.to(DEVICE)
 
     def _cpu_attribute_or_none(self, attribute_name):
         """
@@ -1821,17 +1821,6 @@ class Pln(_model):
         pass
         # no model parameters since we are doing a profiled ELBO
 
-    @_add_doc(_model)
-    def _smart_init_latent_parameters(self):
-        self._random_init_latent_sqrt_var()
-        if not hasattr(self, "_latent_mean"):
-            self._latent_mean = torch.log(self._endog + (self._endog == 0))
-
-    @_add_doc(_model)
-    def _random_init_latent_parameters(self):
-        self._random_init_latent_sqrt_var()
-        if not hasattr(self, "_latent_mean"):
-            self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE)
 
     @property
     @_add_doc(_model)
@@ -3583,9 +3572,9 @@ class ZIPln(_model):
         return "with full covariance model and zero-inflation."
 
     def _random_init_model_parameters(self):
-        self._coef_inflation = torch.randn(self.nb_cov, self.dim)
-        self._coef = torch.randn(self.nb_cov, self.dim)
-        self._components = torch.randn(self.dim, self.dim)
+        self._coef_inflation = torch.randn(self.nb_cov, self.dim).to(DEVICE)
+        self._coef = torch.randn(self.nb_cov, self.dim).to(DEVICE)
+        self._components = torch.randn(self.dim, self.dim).to(DEVICE)
 
     # should change the good initialization for _coef_inflation
     def _smart_init_model_parameters(self):
@@ -3595,7 +3584,7 @@ class ZIPln(_model):
             self._components = _init_components(self._endog, self.dim)
 
         if not hasattr(self, "_coef_inflation"):
-            self._coef_inflation = torch.randn(self.nb_cov, self.dim)
+            self._coef_inflation = torch.randn(self.nb_cov, self.dim).to(DEVICE)
             # for j in range(self.exog.shape[1]):
             #     Y_j = self._endog[:,j].numpy()
             #     offsets_j = self.offsets[:,j].numpy()
@@ -3605,12 +3594,12 @@ class ZIPln(_model):
             #     self._coef_inflation[:,j] = zip_training_results.params[1]
 
     def _random_init_latent_parameters(self):
-        self._latent_mean = torch.randn(self.n_samples, self.dim)
-        self._latent_sqrt_var = torch.randn(self.n_samples, self.dim)
+        self._latent_mean = torch.randn(self.n_samples, self.dim).to(DEVICE)
+        self._latent_sqrt_var = torch.randn(self.n_samples, self.dim).to(DEVICE)
         self._latent_prob = (
             torch.empty(self.n_samples, self.dim).uniform_(0, 1).to(DEVICE)
             * self._dirac
-        ).double()
+        ).double().to(DEVICE)
 
     def _smart_init_latent_parameters(self):
         self._random_init_latent_parameters()
@@ -3755,8 +3744,8 @@ class ZIPln(_model):
         """
         if self._use_closed_form_prob is False:
             with torch.no_grad():
-                self._latent_prob = torch.maximum(self._latent_prob, torch.tensor([0]))
-                self._latent_prob = torch.minimum(self._latent_prob, torch.tensor([1]))
+                self._latent_prob = torch.maximum(self._latent_prob, torch.tensor([0]).to(DEVICE))
+                self._latent_prob = torch.minimum(self._latent_prob, torch.tensor([1]).to(DEVICE))
                 self._latent_prob *= self._dirac
 
     @property
-- 
GitLab
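
Two recurring fixes in the patch above come from the fact that torch.tensor([eps]) is created on the CPU by default, so the constants used in _trunc_log and in the projection of _latent_prob onto [0, 1] must be sent to DEVICE explicitly. torch.clamp sidesteps the constant tensors altogether; a hedged alternative is sketched below (it changes the implementation and is not what the patch does).

# Hedged alternative to the DEVICE-aware constant tensors above, using torch.clamp.
import torch

def trunc_log(tens: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
    """Log of the input clamped to [eps, 1 - eps]; works on any device."""
    return torch.log(torch.clamp(tens, min=eps, max=1 - eps))

def project_latent_prob(latent_prob: torch.Tensor, dirac: torch.Tensor) -> torch.Tensor:
    """Clamp the probabilities to [0, 1] and zero them outside the dirac support."""
    return torch.clamp(latent_prob, min=0.0, max=1.0) * dirac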


From 426f53d5f85b54782c774842bb16166b28c3977a Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 23 Oct 2023 15:15:43 +0200
Subject: [PATCH 67/68] pass all the tests on GPU, going to merge on main.

---
 pyPLNmodels/models.py          |  2 +-
 tests/import_data.py           |  5 +++++
 tests/test_common.py           | 10 +++++-----
 tests/test_plnpcacollection.py |  5 ++---
 tests/test_zi.py               | 12 ++++++------
 5 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index b591a7a8..624bcf2a 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -3197,7 +3197,7 @@ class PlnPCA(_model):
         """
         Orthogonal components of the model.
         """
-        return torch.linalg.qr(self._components, "reduced")[0]
+        return torch.linalg.qr(self._components, "reduced")[0].cpu()
 
     @property
     def components(self) -> torch.Tensor:
diff --git a/tests/import_data.py b/tests/import_data.py
index 9ef5ef7e..9942db40 100644
--- a/tests/import_data.py
+++ b/tests/import_data.py
@@ -1,10 +1,15 @@
 import os
+import torch
 
 from pyPLNmodels import (
     get_simulated_count_data,
     get_real_count_data,
 )
 
+if torch.cuda.is_available():
+    DEVICE = "cuda:0"
+else:
+    DEVICE = "cpu"
 
 (
     endog_sim_0cov,
diff --git a/tests/test_common.py b/tests/test_common.py
index 6cba2cf6..df49f39c 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -46,14 +46,14 @@ def test_verbose(any_instance_model):
 @filter_models(pln_and_plnpca)
 def test_find_right_covariance(simulated_fitted_any_model):
     if simulated_fitted_any_model.nb_cov == 0:
-        true_covariance = true_sim_0cov["Sigma"]
+        true_covariance = true_sim_0cov["Sigma"].cpu()
     elif simulated_fitted_any_model.nb_cov == 2:
-        true_covariance = true_sim_2cov["Sigma"]
+        true_covariance = true_sim_2cov["Sigma"].cpu()
     else:
         raise ValueError(
             f"Not the right numbers of covariance({simulated_fitted_any_model.nb_cov})"
         )
-    mse_covariance = MSE(simulated_fitted_any_model.covariance - true_covariance)
+    mse_covariance = MSE(simulated_fitted_any_model.covariance.cpu() - true_covariance.cpu())
     assert mse_covariance < 0.05
 
 
@@ -75,7 +75,7 @@ def test_right_covariance_shape(real_fitted_and_loaded_model):
 def test_find_right_coef(simulated_fitted_any_model):
     if simulated_fitted_any_model.nb_cov == 2:
         true_coef = true_sim_2cov["beta"]
-        mse_coef = MSE(simulated_fitted_any_model.coef - true_coef)
+        mse_coef = MSE(simulated_fitted_any_model.coef.cpu() - true_coef.cpu())
         assert mse_coef < 0.1
     elif simulated_fitted_any_model.nb_cov == 0:
         assert simulated_fitted_any_model.coef is None
@@ -118,7 +118,7 @@ def test_batch(model):
     model.show()
     if model.nb_cov == 2:
         true_coef = true_sim_2cov["beta"]
-        mse_coef = MSE(model.coef - true_coef)
+        mse_coef = MSE(model.coef.cpu() - true_coef.cpu())
         assert mse_coef < 0.1
     elif model.nb_cov == 0:
         assert model.coef is None
diff --git a/tests/test_plnpcacollection.py b/tests/test_plnpcacollection.py
index 0d982d60..2c1db5a4 100644
--- a/tests/test_plnpcacollection.py
+++ b/tests/test_plnpcacollection.py
@@ -8,7 +8,6 @@ from tests.conftest import dict_fixtures
 from tests.utils import MSE, filter_models
 from tests.import_data import true_sim_0cov, true_sim_2cov
 
-
 @pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCAcollection"])
 def test_best_model(plnpca):
@@ -80,9 +79,9 @@ def test_batch(collection):
     else:
         raise ValueError(f"Not the right numbers of covariance({collection.nb_cov})")
     for model in collection.values():
-        mse_covariance = MSE(model.covariance - true_covariance)
+        mse_covariance = MSE(model.covariance.cpu() - true_covariance.cpu())
         if true_coef is not None:
-            mse_coef = MSE(model.coef - true_coef)
+            mse_coef = MSE(model.coef.cpu() - true_coef.cpu())
             assert mse_coef < 0.35
         assert mse_covariance < 0.25
     collection.fit()
diff --git a/tests/test_zi.py b/tests/test_zi.py
index 2016accf..acfaa5bd 100644
--- a/tests/test_zi.py
+++ b/tests/test_zi.py
@@ -77,9 +77,9 @@ def test_find_right_covariance_coef_and_infla():
     )
     zi = ZIPln(endog, exog=exog, offsets=offsets, use_closed_form_prob=False)
     zi.fit()
-    mse_covariance = MSE(zi.covariance - covariance)
-    mse_coef = MSE(zi.coef - coef)
-    mse_coef_infla = MSE(zi.coef_inflation - coef_inflation)
+    mse_covariance = MSE(zi.covariance.cpu() - covariance.cpu())
+    mse_coef = MSE(zi.coef.cpu() - coef.cpu())
+    mse_coef_infla = MSE(zi.coef_inflation.cpu() - coef_inflation.cpu())
     assert mse_coef < 3
     assert mse_coef_infla < 3
     assert mse_covariance < 1
@@ -118,9 +118,9 @@ def test_batch(model):
     )
     zi = ZIPln(endog, exog=exog, offsets=offsets, use_closed_form_prob=False)
     zi.fit(batch_size=20)
-    mse_covariance = MSE(zi.covariance - covariance)
-    mse_coef = MSE(zi.coef - coef)
-    mse_coef_infla = MSE(zi.coef_inflation - coef_inflation)
+    mse_covariance = MSE(zi.covariance.cpu() - covariance.cpu())
+    mse_coef = MSE(zi.coef.cpu() - coef.cpu())
+    mse_coef_infla = MSE(zi.coef_inflation.cpu() - coef_inflation.cpu())
     assert mse_coef < 3
     assert mse_coef_infla < 3
     assert mse_covariance < 1
-- 
GitLab
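
Patch 67 adds .cpu() calls throughout the tests so that fitted GPU parameters can be compared with the CPU ground-truth tensors. An alternative would be to make the comparison helper device-agnostic once and for all; a sketch is below, under the assumption that MSE in tests/utils.py is simply the mean of the squared entries.

# Hedged sketch of a device-agnostic MSE helper for the tests.
import torch

def mse(t: torch.Tensor, other: torch.Tensor) -> float:
    """Mean squared difference, moving both tensors to the CPU first."""
    diff = t.detach().cpu() - other.detach().cpu()
    return torch.mean(diff**2).item()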


From 71a5640bce62d54e158b5ca4754e9bd076196098 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Mon, 23 Oct 2023 15:24:05 +0200
Subject: [PATCH 68/68] apply black formatting

---
 pyPLNmodels/_utils.py          |  8 ++++++--
 pyPLNmodels/models.py          | 28 ++++++++++++++++++++--------
 tests/test_common.py           |  4 +++-
 tests/test_plnpcacollection.py |  1 +
 4 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index b8fa8001..1cb9d2cd 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -199,8 +199,12 @@ def _log_stirling(integer: torch.Tensor) -> torch.Tensor:
         integer_ / math.exp(1)
     )
 
+
 def _trunc_log(tens: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
-    integer = torch.min(torch.max(tens, torch.tensor([eps]).to(DEVICE)), torch.tensor([1 - eps]).to(DEVICE))
+    integer = torch.min(
+        torch.max(tens, torch.tensor([eps]).to(DEVICE)),
+        torch.tensor([1 - eps]).to(DEVICE),
+    )
     return torch.log(integer)
 
 
@@ -929,7 +933,7 @@ def _extract_data_from_formula(
 
     """
     # dmatrices can not deal with GPU matrices
-    for key,matrix in data.items():
+    for key, matrix in data.items():
         if isinstance(matrix, torch.Tensor):
             data[key] = matrix.cpu()
     dmatrix = dmatrices(formula, data=data)
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 624bcf2a..d5854adf 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -237,7 +237,9 @@ class _model(ABC):
         sns.scatterplot(x=x, y=y, hue=colors, ax=ax)
         if show_cov is True:
             sk_components = torch.from_numpy(pca.components_).to(DEVICE)
-            covariances = self._get_pca_low_dim_covariances(sk_components).cpu().detach()
+            covariances = (
+                self._get_pca_low_dim_covariances(sk_components).cpu().detach()
+            )
             for i in range(covariances.shape[0]):
                 _plot_ellipse(x[i], y[i], cov=covariances[i], ax=ax)
         return ax
@@ -1821,7 +1823,6 @@ class Pln(_model):
         pass
         # no model parameters since we are doing a profiled ELBO
 
-
     @property
     @_add_doc(_model)
     def _list_of_parameters_needing_gradient(self):
@@ -2997,7 +2998,10 @@ class PlnPCA(_model):
         else:
             XB = 0
         return torch.exp(
-            self._offsets + XB + self.latent_variables.to(DEVICE) + 1 / 2 * covariance_a_posteriori
+            self._offsets
+            + XB
+            + self.latent_variables.to(DEVICE)
+            + 1 / 2 * covariance_a_posteriori
         )
 
     @latent_mean.setter
@@ -3597,9 +3601,13 @@ class ZIPln(_model):
         self._latent_mean = torch.randn(self.n_samples, self.dim).to(DEVICE)
         self._latent_sqrt_var = torch.randn(self.n_samples, self.dim).to(DEVICE)
         self._latent_prob = (
-            torch.empty(self.n_samples, self.dim).uniform_(0, 1).to(DEVICE)
-            * self._dirac
-        ).double().to(DEVICE)
+            (
+                torch.empty(self.n_samples, self.dim).uniform_(0, 1).to(DEVICE)
+                * self._dirac
+            )
+            .double()
+            .to(DEVICE)
+        )
 
     def _smart_init_latent_parameters(self):
         self._random_init_latent_parameters()
@@ -3744,8 +3752,12 @@ class ZIPln(_model):
         """
         if self._use_closed_form_prob is False:
             with torch.no_grad():
-                self._latent_prob = torch.maximum(self._latent_prob, torch.tensor([0]).to(DEVICE))
-                self._latent_prob = torch.minimum(self._latent_prob, torch.tensor([1]).to(DEVICE))
+                self._latent_prob = torch.maximum(
+                    self._latent_prob, torch.tensor([0]).to(DEVICE)
+                )
+                self._latent_prob = torch.minimum(
+                    self._latent_prob, torch.tensor([1]).to(DEVICE)
+                )
                 self._latent_prob *= self._dirac
 
     @property
diff --git a/tests/test_common.py b/tests/test_common.py
index df49f39c..bd5ca62c 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -53,7 +53,9 @@ def test_find_right_covariance(simulated_fitted_any_model):
         raise ValueError(
             f"Not the right numbers of covariance({simulated_fitted_any_model.nb_cov})"
         )
-    mse_covariance = MSE(simulated_fitted_any_model.covariance.cpu() - true_covariance.cpu())
+    mse_covariance = MSE(
+        simulated_fitted_any_model.covariance.cpu() - true_covariance.cpu()
+    )
     assert mse_covariance < 0.05
 
 
diff --git a/tests/test_plnpcacollection.py b/tests/test_plnpcacollection.py
index 2c1db5a4..761afabc 100644
--- a/tests/test_plnpcacollection.py
+++ b/tests/test_plnpcacollection.py
@@ -8,6 +8,7 @@ from tests.conftest import dict_fixtures
 from tests.utils import MSE, filter_models
 from tests.import_data import true_sim_0cov, true_sim_2cov
 
+
 @pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"])
 @filter_models(["PlnPCAcollection"])
 def test_best_model(plnpca):
-- 
GitLab