import numpy as np
import pandas as pd
import muon as mu
import anndata as ad
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions.normal import Normal
import time, os, pickle
from typing import Optional, Sequence, Union, Tuple, Dict, List
from datetime import timedelta
from sklearn.metrics import balanced_accuracy_score
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import OneHotEncoder
from scipy.stats import chi2
from scipy.sparse import csr_matrix
[docs]
def create_structure(
    structure: List[int],
    layer_order: list
) -> nn.Sequential:
    """
    Build a feed-forward network from a list of layer sizes and an operation order.

    Parameters
    ----------
    structure : list of int
        Number of neurons in each layer, including input and output dimensions,
        e.g. ``[input_dim, hidden1_dim, ..., output_dim]``. One block of
        operations is emitted for every consecutive pair
        ``(structure[i], structure[i + 1])``; fewer than two entries yields an
        empty ``nn.Sequential``.
    layer_order : list
        Operations applied, in order, inside every block. Each element is one of

        - ``'linear'``: ``nn.Linear(neurons_in, neurons_out)``.
        - ``'batch_norm'``: ``nn.BatchNorm1d(neurons_out)``.
        - ``'layer_norm'``: ``nn.LayerNorm(neurons_out)``.
        - ``('act', activation)`` or ``('act', activation, [min_clip, max_clip])``:
          ``activation`` is an ``nn.Module`` instance or the string ``'PReLU'``.
          Unbounded activations should be clipped for numerical stability,
          e.g. ``('act', torch.nn.ReLU(), [0, 6])``.
        - ``('dropout', rate)``: ``nn.Dropout(rate)`` with ``rate`` in [0, 1],
          e.g. ``('dropout', 0.1)``.

        Unrecognized operation names are skipped. The activation and the
        dropout entries are optional.

    Returns
    -------
    nn.Sequential
        The layers of all blocks, in order.

    Notes
    -----
    An activation passed as a module instance is reused in every block, which is
    fine for stateless activations. ``'PReLU'`` builds a fresh ``nn.PReLU`` for
    every block, with ``num_parameters`` equal to that block's output size.
    """
    layer_operations = [op if isinstance(op, str) else op[0] for op in layer_order]

    dropout = None
    if 'dropout' in layer_operations:
        dropout = layer_order[layer_operations.index('dropout')][1]

    # Keep the activation *specification* separate from the module created per
    # block. Otherwise a 'PReLU' spec would be overwritten by the first block's
    # module and reused, with the wrong width, in every later block.
    act_spec = None
    clip_act = None
    if 'act' in layer_operations:
        act_entry = layer_order[layer_operations.index('act')]
        act_spec = act_entry[1]
        if len(act_entry) == 3:
            clip_act = act_entry[2]

    layers = []
    for neurons_in, neurons_out in zip(structure, structure[1:]):
        for operation in layer_operations:
            if operation == 'linear':
                layers.append(nn.Linear(neurons_in, neurons_out))
            elif operation == 'act':
                if isinstance(act_spec, str) and act_spec == 'PReLU':
                    act = nn.PReLU(num_parameters=neurons_out)
                else:
                    act = act_spec
                # None (no third entry) and False both mean "do not clip".
                if clip_act is not None and clip_act is not False:
                    layers.append(make_act_bounded(act, min=clip_act[0], max=clip_act[1]))
                else:
                    layers.append(act)
            elif operation == 'dropout':
                layers.append(nn.Dropout(dropout))
            elif operation == 'layer_norm':
                layers.append(nn.LayerNorm(neurons_out))
            elif operation == 'batch_norm':
                layers.append(nn.BatchNorm1d(neurons_out))
    return nn.Sequential(*layers)
[docs]
class make_act_bounded(nn.Module):
    """
    Apply an activation and clamp its output to ``[min, max]``.

    Parameters
    ----------
    act : nn.Module
        Activation module applied first.
    min : float
        Lower clamp bound.
    max : float
        Upper clamp bound.
    """

    def __init__(
        self,
        act: nn.Module,
        min: float,
        max: float
    ):
        super().__init__()
        self.act = act
        # The parameter names shadow builtins, but they are part of the keyword
        # API (create_structure calls make_act_bounded(act, min=..., max=...)).
        self.min, self.max = min, max

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Return ``clamp(act(x), min, max)``."""
        activated = self.act(x)
        return activated.clamp(min=self.min, max=self.max)
[docs]
class Encoder_outer(nn.Module):
    """
    Outer encoder: concatenates data and labels, then applies a feed-forward network.

    scSpecies builds separate context and target instances of this module.

    Parameters
    ----------
    param_dict : dict
        Dictionary with keys:

        - ``'data_dim'`` (int): dimensionality of the input data.
        - ``'label_dim'`` (int): dimensionality of the label input.
        - ``'dims_enc_outer'`` (list of int): layer sizes after the concatenation.
        - ``'layer_order'`` (list): see ``create_structure``.

    Attributes
    ----------
    model : nn.Sequential
        Network built by ``create_structure`` over
        ``[data_dim + label_dim] + dims_enc_outer``.
    """

    def __init__(
        self,
        param_dict: dict
    ):
        super().__init__()
        in_dim = param_dict['data_dim'] + param_dict['label_dim']
        ops = param_dict['layer_order'].copy()
        self.model = create_structure(
            structure=[in_dim] + param_dict['dims_enc_outer'],
            layer_order=ops,
        )

    def forward(
        self,
        data: torch.Tensor,
        label_inp: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode ``concat(data, label_inp)`` along the last dimension.

        Parameters
        ----------
        data : torch.Tensor
            Input data of shape (batch_size, data_dim).
        label_inp : torch.Tensor
            Label input of shape (batch_size, label_dim).

        Returns
        -------
        torch.Tensor
            Encoded representation of shape (batch_size, dims_enc_outer[-1]).
        """
        joined = torch.cat((data, label_inp), dim=-1)
        return self.model(joined)
[docs]
class Encoder_inner(nn.Module):
    """
    Inner encoder producing Gaussian latent parameters and reparameterized samples.

    With ``alignment == 'inter'`` scSpecies uses one instance of this module for
    both the context and the target model.

    Parameters
    ----------
    device : str
        Device for the sampling distribution ('cpu', 'mps' or 'cuda').
    param_dict : dict
        Dictionary with keys:

        - ``'dims_enc_outer'`` (list of int): outer encoder sizes. The last entry
          is this module's input size.
        - ``'dims_enc_inner'`` (list of int): hidden layer sizes.
        - ``'lat_dim'`` (int): latent dimensionality.
        - ``'layer_order'`` (list): see ``create_structure``.

    Attributes
    ----------
    model : nn.Sequential
        Feed-forward network for the intermediate representation.
    mu : nn.Linear
        Maps to the latent mean.
    log_sig : nn.Linear
        Maps to the latent log-standard deviation.
    sampling_dist : Normal
        Standard normal over ``lat_dim``, allocated on ``device`` at construction.
        It is neither a parameter nor a buffer, so ``Module.to()`` does not move it.
    """

    def __init__(
        self,
        device: str,
        param_dict: dict
    ):
        super().__init__()
        structure = [param_dict['dims_enc_outer'][-1]] + param_dict['dims_enc_inner']
        self.model = create_structure(
            structure=structure,
            layer_order=param_dict['layer_order'].copy(),
        )
        lat_dim = param_dict['lat_dim']
        hidden_dim = structure[-1]
        self.mu = nn.Linear(hidden_dim, lat_dim)
        self.log_sig = nn.Linear(hidden_dim, lat_dim)
        target_device = torch.device(device)
        self.sampling_dist = Normal(
            torch.zeros(torch.Size([lat_dim]), device=target_device),
            torch.ones(torch.Size([lat_dim]), device=target_device),
        )

    def encode(
        self,
        inter: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the latent mean and log-std from the intermediate representation.

        Parameters
        ----------
        inter : torch.Tensor
            Intermediate features of shape (batch_size, dims_enc_outer[-1]).

        Returns
        -------
        mu : torch.Tensor
            Latent means of shape (batch_size, lat_dim).
        log_sig : torch.Tensor
            Latent log-standard deviations of shape (batch_size, lat_dim).
        """
        hidden = self.model(inter)
        return self.mu(hidden), self.log_sig(hidden)

    def forward(
        self,
        inter: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample the latent variable and compute the KL divergence to N(0, I).

        Parameters
        ----------
        inter : torch.Tensor
            Intermediate features from the outer encoder.

        Returns
        -------
        z : torch.Tensor
            Reparameterized sample ``mu + exp(log_sig) * eps`` of shape
            (batch_size, lat_dim).
        kl_div : torch.Tensor
            Scalar: the closed-form KL, summed over latent dimensions and
            averaged over the batch.
        """
        mu, log_sig = self.encode(inter)
        n_rows = log_sig.size(dim=0)
        eps = self.sampling_dist.sample(torch.Size([n_rows]))
        per_dim = mu.square() + torch.exp(2.0 * log_sig) - 1.0 - 2.0 * log_sig
        kl_div = torch.mean(0.5 * torch.sum(per_dim, dim=1))
        z = mu + log_sig.exp() * eps
        return z, kl_div
[docs]
class Library_encoder(nn.Module):
    """
    Encoder for the library-size factor, modelled as a 1-D log-normal variable.

    The network maps ``concat(data, label_inp)`` to the mean and log-std of
    ``log(l)``. The factor ``l`` is then sampled with the reparameterization trick.

    Parameters
    ----------
    device : str
        Device for the standard-normal sampling distribution.
    param_dict : dict
        Dictionary with keys:

        - ``'data_dim'``, ``'label_dim'`` (int): input dimensions.
        - ``'dims_l_enc'`` (list of int): hidden layer sizes.
        - ``'lib_mu_add'`` (float): offset added to the predicted log-mean.
        - ``'layer_order'`` (list): see ``create_structure``.

    Attributes
    ----------
    model : nn.Sequential
        Feed-forward network on the concatenated input.
    mu : nn.Linear
        Maps to the log-mean of the library factor (before ``mu_add``).
    log_sig : nn.Linear
        Maps to the log-std of the library factor's log.
    sampling_dist : Normal
        1-D standard normal allocated on ``device`` at construction.
    mu_add : float
        Offset added to the encoded mean.
    """

    def __init__(
        self,
        device,
        param_dict,
    ):
        super(Library_encoder, self).__init__()
        structure = [param_dict['data_dim'] + param_dict['label_dim']] + param_dict['dims_l_enc']
        self.model = create_structure(structure=structure,
                                      layer_order=param_dict['layer_order'],
                                      )
        self.mu_add = param_dict['lib_mu_add']
        self.mu = nn.Linear(structure[-1], 1)
        self.log_sig = nn.Linear(structure[-1], 1)
        self.sampling_dist = Normal(
            torch.zeros(torch.Size([1]), device=torch.device(device)),
            torch.ones(torch.Size([1]), device=torch.device(device)))

    def encode(
        self,
        data: torch.Tensor,
        label_inp: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the library log-mean and log-std from the inputs.

        Parameters
        ----------
        data : torch.Tensor
            Data tensor of shape (batch_size, data_dim).
        label_inp : torch.Tensor
            Label tensor of shape (batch_size, label_dim).

        Returns
        -------
        mu : torch.Tensor
            Log-mean shifted by ``mu_add``, shape (batch_size, 1).
        log_sig : torch.Tensor
            Log-std, shape (batch_size, 1).
        """
        x = torch.cat((data, label_inp), dim=-1)
        x = self.model(x)
        mu = self.mu(x)
        log_sig = self.log_sig(x)
        return mu + self.mu_add, log_sig

    def forward(
        self,
        data: torch.Tensor,
        label_inp: torch.Tensor,
        prior_mu: Optional[torch.Tensor] = None,
        prior_sig: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Sample the library factor and optionally compute its KL divergence to a prior.

        Parameters
        ----------
        data : torch.Tensor
            Input data.
        label_inp : torch.Tensor
            Input labels.
        prior_mu, prior_sig : torch.Tensor, optional
            Mean and std of a Gaussian prior on ``log(l)``. They must be given
            together. They are broadcast against the squeezed (batch_size,)
            posterior parameters. A (batch_size, 1) prior would therefore
            broadcast to (batch_size, batch_size); presumably (batch_size,) is
            intended — TODO confirm against callers.

        Returns
        -------
        l : torch.Tensor
            Sampled library factor of shape (batch_size, 1). Returned alone when
            no prior is given.
        (l, kl_div) : tuple of torch.Tensor
            Returned when a prior is given. ``kl_div`` is the batch-mean of
            KL(N(mu, sig^2) || N(prior_mu, prior_sig^2)).

        Raises
        ------
        ValueError
            If exactly one of ``prior_mu`` / ``prior_sig`` is provided.
        """
        if (prior_mu is None) != (prior_sig is None):
            raise ValueError("prior_mu and prior_sig must be provided together.")

        mu, log_sig = self.encode(data, label_inp)
        eps = self.sampling_dist.sample(torch.Size([log_sig.size(dim=0)]))
        l = torch.exp(mu + log_sig.exp() * eps)
        if prior_mu is None:
            return l

        # squeeze(-1) (not squeeze()) keeps the batch axis even for batch_size == 1.
        mu_q = mu.squeeze(-1)
        log_sig_q = log_sig.squeeze(-1)
        prior_var = prior_sig.square()
        # Closed-form Gaussian KL; the clamp guards the 1 / (2 sig_p^2) term
        # against a vanishing prior std.
        kl_div = torch.mean(
            prior_sig.log() - log_sig_q
            + (1 / torch.clamp(2.0 * prior_var, min=1e-7))
            * ((mu_q - prior_mu) ** 2 + torch.exp(2.0 * log_sig_q) - prior_var)
        )
        return l, kl_div
[docs]
class Decoder(nn.Module):
    """
    Decoder mapping latent and label inputs to parameters of the count distribution.

    Parameters
    ----------
    param_dict : dict
        Dictionary with keys:

        - ``'lat_dim'``, ``'label_dim'``, ``'data_dim'`` (int): latent, label and
          data dimensions.
        - ``'dims_dec'`` (list of int): hidden layer sizes.
        - ``'layer_order'`` (list): see ``create_structure``.
        - ``'data_distr'`` (str): ``'nb'`` or ``'zinb'``.
        - ``'dispersion'`` (str): ``'dataset'``, ``'batch'`` or ``'cell'``.
        - ``'homologous_genes'`` (list of int): column indices of the homologous genes.

    Attributes
    ----------
    model : nn.Sequential
        Feed-forward network on ``concat(z, label_inp)``.
    rho_pre : nn.Linear
        Logits of the relative expression.
    log_alpha : nn.Parameter or nn.Linear
        Log-dispersion. It is one vector (``'dataset'``), one row per label
        column selected by ``argmax(label_inp)`` (``'batch'``), or predicted per
        cell (``'cell'``).
    pi_nlogit : nn.Linear
        Only when ``data_distr == 'zinb'``: negative logit of the zero-inflation
        probability (see ``calc_nlog_likelihood``).

    Raises
    ------
    ValueError
        If ``data_distr`` or ``dispersion`` is not one of the accepted values.
    """

    def __init__(
        self,
        param_dict: dict
    ):
        super(Decoder, self).__init__()
        structure = [param_dict['lat_dim'] + param_dict['label_dim']] + param_dict['dims_dec']
        self.data_distr = param_dict['data_distr']
        self.dispersion = param_dict['dispersion']
        # Explicit integer dtype: np.array([]) would be float64, and float arrays
        # are rejected as tensor indices in `decode`.
        self.homologous_genes = np.asarray(param_dict['homologous_genes'], dtype=np.int64)
        self.non_hom_genes = np.setdiff1d(np.arange(param_dict['data_dim']), self.homologous_genes)
        # Permutation that undoes the [homologous, non-homologous] concatenation in `decode`.
        self.gene_ind = np.argsort(np.concatenate((self.homologous_genes, self.non_hom_genes)))
        self.data_dim = param_dict['data_dim']
        if self.data_distr not in ['zinb', 'nb']:
            raise ValueError(f"data_distr must be one of 'zinb', 'nb'; got {self.data_distr!r}.")
        if self.dispersion not in ['dataset', 'batch', 'cell']:
            raise ValueError(f"dispersion must be one of 'dataset', 'batch', 'cell'; got {self.dispersion!r}.")
        self.model = create_structure(structure=structure,
                                      layer_order=param_dict['layer_order'],
                                      )
        self.rho_pre = nn.Linear(structure[-1], self.data_dim)
        if self.dispersion == "dataset":
            self.log_alpha = torch.nn.parameter.Parameter(data=torch.randn(self.data_dim) * 0.1, requires_grad=True)
        elif self.dispersion == "batch":
            self.log_alpha = torch.nn.parameter.Parameter(data=torch.randn((param_dict['label_dim'], self.data_dim)) * 0.1, requires_grad=True)
        elif self.dispersion == "cell":
            self.log_alpha = nn.Linear(structure[-1], self.data_dim)
        if self.data_distr == 'zinb':
            self.pi_nlogit = nn.Linear(structure[-1], self.data_dim)

    def calc_nlog_likelihood(
        self,
        dec_outp: List[torch.Tensor],
        library: torch.Tensor,
        x: torch.Tensor,
        eps: float = 1e-7
    ) -> torch.Tensor:
        """
        Compute the negative log-likelihood under the NB or ZINB model.

        Parameters
        ----------
        dec_outp : list of torch.Tensor
            ``[alpha, rho]`` (NB) or ``[alpha, rho, pi_nlogit]`` (ZINB), as
            returned by ``decode``.
        library : torch.Tensor
            Library size factor. The NB mean is ``rho * library``.
        x : torch.Tensor
            Observed counts.
        eps : float
            Numerical-stability constant. Counts below ``eps`` are treated as
            zeros in the ZINB branch.

        Returns
        -------
        torch.Tensor
            Negative log-likelihood summed over the last (gene) dimension.
        """
        if self.data_distr == 'nb':
            alpha, rho = dec_outp
            alpha = torch.clamp(alpha, min=eps)
            rho = torch.clamp(rho, min=1e-8, max=1 - eps)
            mu = rho * library
            p = torch.clamp(mu / (mu + alpha), min=eps, max=1 - eps)
            log_likelihood = x * torch.log(p) + alpha * torch.log(1.0 - p) - torch.lgamma(alpha) - torch.lgamma(1.0 + x) + torch.lgamma(x + alpha)
        elif self.data_distr == 'zinb':
            alpha, rho, pi_nlogit = dec_outp
            alpha = torch.clamp(alpha, min=eps)
            rho = torch.clamp(rho, min=1e-8, max=1 - eps)
            mu = rho * library
            log_alpha_mu = torch.log(alpha + mu)
            # Zero counts mix the zero-inflation mass with the NB zero probability;
            # non-zero counts use only the NB component scaled by 1 - pi.
            log_likelihood = torch.where(x < eps,
                                         F.softplus(pi_nlogit + alpha * (torch.log(alpha) - log_alpha_mu)) - F.softplus(pi_nlogit),
                                         - F.softplus(pi_nlogit) + pi_nlogit
                                         + alpha * (torch.log(alpha) - log_alpha_mu) + x * (torch.log(mu) - log_alpha_mu)
                                         + torch.lgamma(x + alpha) - torch.lgamma(alpha) - torch.lgamma(1.0 + x))
        return - torch.sum(log_likelihood, dim=-1)

    def decode(
        self,
        z: torch.Tensor,
        label_inp: torch.Tensor
    ) -> List[torch.Tensor]:
        """
        Decode latent and label inputs to distribution parameters.

        The relative expression ``rho`` is a softmax within the homologous genes
        and a separate softmax within the non-homologous genes. Each block is
        scaled by its share of ``data_dim``, so every row of ``rho`` sums to 1.

        Parameters
        ----------
        z : torch.Tensor
            Latent tensor of shape (batch_size, lat_dim).
        label_inp : torch.Tensor
            Label tensor of shape (batch_size, label_dim).

        Returns
        -------
        list of torch.Tensor
            ``[alpha, rho]`` or ``[alpha, rho, pi_nlogit]``.
        """
        x = torch.cat((z, label_inp), dim=-1)
        x = self.model(x)
        if self.dispersion == "dataset":
            alpha = self.log_alpha.exp()
        elif self.dispersion == "batch":
            alpha = self.log_alpha[torch.argmax(label_inp, dim=-1)].exp()
        elif self.dispersion == "cell":
            alpha = self.log_alpha(x).exp()
        rho_pre = self.rho_pre(x)
        rho_pre_hom = F.softmax(rho_pre[:, self.homologous_genes], dim=-1) * len(self.homologous_genes) / self.data_dim
        rho_pre_nonhom = F.softmax(rho_pre[:, self.non_hom_genes], dim=-1) * len(self.non_hom_genes) / self.data_dim
        rho = torch.cat((rho_pre_hom, rho_pre_nonhom), dim=-1)[:, self.gene_ind]
        outputs = [alpha, rho]
        if self.data_distr == 'zinb':
            pi_nlogit = self.pi_nlogit(x)
            outputs.append(pi_nlogit)
        return outputs

    def decode_homologous(
        self,
        z: torch.Tensor,
        label_inp: torch.Tensor
    ) -> torch.Tensor:
        """
        Decode latent and label inputs into expression of the homologous genes only.

        Used to assess and compare log2-fold changes between species.

        Parameters
        ----------
        z : torch.Tensor
            Latent representation.
        label_inp : torch.Tensor
            Label input.

        Returns
        -------
        torch.Tensor
            Softmax over the homologous genes. For ZINB it is multiplied by the
            per-gene non-zero-inflation probability ``sigmoid(pi_nlogit)``.
        """
        # A single forward pass serves both rho and pi, so both come from the
        # same dropout mask when the network is in train mode.
        hidden = self.model(torch.cat((z, label_inp), dim=-1))
        rho_hom = F.softmax(self.rho_pre(hidden)[:, self.homologous_genes], dim=-1)
        if self.data_distr == 'zinb':
            # In calc_nlog_likelihood, pi_nlogit is the negative logit of the
            # zero-inflation probability, so sigmoid(pi_nlogit) equals 1 - pi.
            pi_hom = torch.sigmoid(self.pi_nlogit(hidden)[:, self.homologous_genes])
            rho_hom = rho_hom * pi_hom
        return rho_hom

    def forward(
        self,
        z: torch.Tensor,
        label_inp: torch.Tensor,
        library: torch.Tensor,
        x: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the batch-mean negative log-likelihood.

        Parameters
        ----------
        z : torch.Tensor
            Latent representations.
        label_inp : torch.Tensor
            Labels.
        library : torch.Tensor
            Library size factors.
        x : torch.Tensor
            Observed counts.

        Returns
        -------
        torch.Tensor
            Mean of the per-sample negative log-likelihood.
        """
        outputs = self.decode(z, label_inp)
        n_log_likeli = self.calc_nlog_likelihood(outputs, library, x).mean()
        return n_log_likeli
[docs]
class scSpecies():
"""
The scSpecies cross-species architecture alignment framework built on scVI.
This class implements end-to-end preprocessing, variational encoding, decoding, and alignment
for a “context” dataset (e.g., mouse) and a “target” dataset (e.g., human). It supports:
- Training scVI models on context and target (latent or intermediate alignment).
- Library size encoding and negative-binomial / zero-inflated NB likelihoods.
- Establishing a direct correspondece between traget can context cell via a likelihood-based similarity measure
- Latent-space nearest-neighbor label transfer based on the similarity measure.
- Log-fold-change computation of homologous genes.
Parameters
----------
device : str
PyTorch device identifier ('cpu', 'mps' or 'cuda').
mdata : mu.MuData
Multi-modal container holding context and target AnnData objects, set up by the `create_mdata` class.
directory : str
Base path for saving model parameters, data, and figures.
random_seed : int, default=369963
Seed for NumPy and PyTorch RNGs.
context_key : str, default='mouse'
Key in `mdata.mod` for the context dataset.
target_key : str, default='human'
Key in `mdata.mod` for the target dataset.
context_optimizer, target_optimizer : torch.optim.Optimizer classes
Optimizer constructors for context and target models.
context_hidden_dims_enc_outer, target_hidden_dims_enc_outer : list[int]
Hidden layer sizes for the outer encoders.
hidden_dims_enc_inner : list[int]
Hidden layer sizes for the inner encoder.
context_hidden_dims_l_enc, target_hidden_dims_l_enc : list[int]
Hidden layer sizes for the library encoder.
context_hidden_dims_dec, target_hidden_dims_dec : list[int]
Hidden layer sizes for the decoder.
context_layer_order, target_layer_order : list
Layer specification lists for `create_structure`.
b_s : int, default=128
Batch size for training and inference.
context_data_distr, target_data_distr : {'nb', 'zinb'}
Observation models for counts.
lat_dim : int, default=10
Dimensionality of the latent space.
context_dispersion, target_dispersion : {'dataset', 'batch', 'cell'}
Dispersion parameterization strategy.
alignment : {'inter', 'latent'}
Alignment mode between context and target. Either at the outer encoder output space or at the latent space.
k_neigh : int, default=25
Number of neighbors candidates for alignment from the data-level NNS.
top_percent : float, default=20
Percentile cutoff for selecting top-agreement neighbors.
context_beta_*, target_beta_*, eta_* : floats and ints
Schedules for KL and alignment weight ramps.
use_lib_enc : bool, default=True
Whether to include a library-size encoder.
"""
[docs]
def __init__(
    self,
    device: str,
    mdata: mu.MuData,
    directory: str,
    random_seed: int = 369963,
    context_key: str = 'mouse',
    target_key: str = 'human',
    context_optimizer: torch.optim.Optimizer = torch.optim.Adam,
    target_optimizer: torch.optim.Optimizer = torch.optim.Adam,
    context_hidden_dims_enc_outer: List[int] = [300],
    target_hidden_dims_enc_outer: List[int] = [300],
    hidden_dims_enc_inner: List[int] = [200],
    context_hidden_dims_l_enc: List[int] = [200],
    target_hidden_dims_l_enc: List[int] = [200],
    context_hidden_dims_dec: List[int] = [200, 300],
    target_hidden_dims_dec: List[int] = [200, 300],
    lat_dim: int = 10,
    context_layer_order: list = ['linear', 'layer_norm', ('act', nn.ReLU()), ('dropout', 0.1)],
    target_layer_order: list = ['linear', 'layer_norm', ('act', nn.ReLU()), ('dropout', 0.1)],
    use_lib_enc: bool = True,
    b_s: int = 128,
    context_data_distr: str = 'zinb',
    target_data_distr: str = 'zinb',
    context_dispersion: str = 'batch',
    target_dispersion: str = 'batch',
    alignment: str = 'inter',
    k_neigh: int = 25,
    top_percent: float = 20,
    context_beta_start: float = 0.1,
    context_beta_max: float = 1,
    context_beta_epochs_raise: int = 10,
    target_beta_start: float = 0.1,
    target_beta_max: float = 1,
    target_beta_epochs_raise: int = 10,
    eta_start: float = 10,
    eta_max: float = 25,
    eta_epochs_raise: int = 10,
):
    """
    Seed the RNGs, build the configuration dictionaries, create the output
    directories and instantiate both scVI models.

    The list-valued arguments are copied before they are stored in
    ``self.context_config`` / ``self.target_config``. The shared mutable
    defaults in the signature, and caller-owned lists, are therefore never
    aliased by the stored configuration.

    Raises
    ------
    ValueError
        If ``alignment`` is not ``'inter'`` or ``'latent'``, or if the context
        and target outer encoders have different output dimensions.
    """
    # `alignment` selects how initialize() builds target_encoder_inner; only
    # these two values are handled there.
    if alignment not in ('inter', 'latent'):
        raise ValueError(f"alignment must be 'inter' or 'latent', got {alignment!r}.")
    self.context_likeli_hist_dict = []
    self.target_likeli_hist_dict = []
    self.mdata = mdata
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    self.rng = np.random.default_rng(random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Indices (into each modality) of the genes whose context name matches a
    # translated target name; np.intersect1d returns them in matching order.
    _, hom_ind_context, hom_ind_target = np.intersect1d(mdata.mod[context_key].var_names, mdata.mod[target_key].var['var_names_transl'], return_indices=True)
    self.config_dict = {
        'random_seed': random_seed,
        'device': device,
        'inter_dim': hidden_dims_enc_inner[0],
        'lat_dim': lat_dim,
        'b_s': b_s,
        'alignment': alignment,
        'use_lib_enc': use_lib_enc
    }
    self.context_config = {
        'context_key': context_key,
        'optimizer': context_optimizer,
        'homologous_genes': list(hom_ind_context),
        'data_dim': self.mdata.mod[context_key].n_vars,
        'label_dim': np.shape(self.mdata.mod[context_key].obsm['batch_label_enc'])[1],
        'lib_mu_add': round(np.mean(self.mdata.mod[context_key].obs['library_log_mean']), 5),
        'dims_enc_outer': list(context_hidden_dims_enc_outer),
        'dims_enc_inner': list(hidden_dims_enc_inner),
        'dims_l_enc': list(context_hidden_dims_l_enc),
        'lat_dim': lat_dim,
        'dims_dec': list(context_hidden_dims_dec),
        'layer_order': list(context_layer_order),
        'data_distr': context_data_distr,
        'dispersion': context_dispersion,
        'beta_start': context_beta_start,
        'beta_max': context_beta_max,
        'beta_epochs_raise': context_beta_epochs_raise,
        'beta': context_beta_start,
    }
    self.target_config = {
        'target_key': target_key,
        'optimizer': target_optimizer,
        'homologous_genes': list(hom_ind_target),
        'data_dim': self.mdata.mod[target_key].n_vars,
        'label_dim': np.shape(self.mdata.mod[target_key].obsm['batch_label_enc'])[1],
        'lib_mu_add': round(np.mean(self.mdata.mod[target_key].obs['library_log_mean']), 5),
        'dims_enc_outer': list(target_hidden_dims_enc_outer),
        'dims_enc_inner': list(hidden_dims_enc_inner),
        'dims_l_enc': list(target_hidden_dims_l_enc),
        'lat_dim': lat_dim,
        'dims_dec': list(target_hidden_dims_dec),
        'layer_order': list(target_layer_order),
        'data_distr': target_data_distr,
        'dispersion': target_dispersion,
        'beta_start': target_beta_start,
        'beta_max': target_beta_max,
        'beta_epochs_raise': target_beta_epochs_raise,
        'beta': target_beta_start,
        'k_neigh': k_neigh,
        'top_percent': top_percent,
        'eta_start': eta_start,
        'eta_max': eta_max,
        'eta_epochs_raise': eta_epochs_raise,
        'eta': eta_start,
    }
    # Both outer encoders feed the same inner-encoder input width.
    if self.context_config['dims_enc_outer'][-1] != self.target_config['dims_enc_outer'][-1]:
        raise ValueError("Context and target dims_enc_outer must have the same output dimensions.")
    self.create_directory(directory)
    self.initialize()
[docs]
def get_batch(
    self,
    array: Union[torch.Tensor, Sequence],
    step: int,
    *,
    perm: Optional[Sequence[int]] = None,
    batch_size: Optional[int] = None
) -> Union[torch.Tensor, Sequence]:
    """
    Return minibatch number ``step`` of ``array``.

    Parameters
    ----------
    array : Tensor or sequence
        Data to batch (features, labels, indices, ...).
    step : int
        Zero-based batch index.
    perm : sequence of int, optional
        Shuffled row order. Rows ``perm[step*bs:(step+1)*bs]`` are taken. If
        None, a contiguous slice is used.
    batch_size : int, optional
        Rows per batch. Defaults to ``self.config_dict['b_s']``.

    Returns
    -------
    Tensor or sequence
        The selected rows. Tensors are moved to ``self.config_dict['device']``.
    """
    size = self.config_dict['b_s'] if batch_size is None else batch_size
    lo = step * size
    hi = lo + size
    if perm is None:
        selector = slice(lo, hi)
    else:
        selector = perm[lo:hi]
    chunk = array[selector]
    if not isinstance(chunk, torch.Tensor):
        return chunk
    return chunk.to(self.config_dict.get('device'))
[docs]
def initialize(
    self,
    initialize: str = 'both'
):
    """
    Instantiate or re-instantiate context and/or target modules and their optimizers.

    Parameters
    ----------
    initialize : {'context', 'context_decoder', 'target', 'both'}, default='both'
        Which sub-model(s) to (re)build:

        - ``'context'`` / ``'both'``: builds the context decoder, the inner and
          outer encoders and, if ``use_lib_enc``, the context library encoder.
          The new context optimizer covers all of their parameters.
        - ``'context_decoder'``: builds only a new context decoder. The new
          context optimizer covers the decoder parameters only, and the existing
          context encoders are kept.
        - ``'target'`` / ``'both'``: builds the target outer encoder, the
          decoder, the optional library encoder and the inner encoder. With
          ``alignment == 'inter'`` the target inner encoder *is* the context
          inner encoder, which must already exist. Its parameters are then not
          added to the target optimizer.

        All (re)built modules are put in eval mode.

    Raises
    ------
    ValueError
        If ``initialize`` is not one of the values above, or if a target model
        is requested while ``config_dict['alignment']`` is neither ``'latent'``
        nor ``'inter'``.
    """
    valid_modes = ('context', 'context_decoder', 'target', 'both')
    if initialize not in valid_modes:
        raise ValueError(f"initialize must be one of {valid_modes}, got {initialize!r}.")
    if initialize in ('target', 'both') and self.config_dict['alignment'] not in ('latent', 'inter'):
        raise ValueError(
            f"config_dict['alignment'] must be 'latent' or 'inter', got {self.config_dict['alignment']!r}."
        )
    # Keep the module construction order below stable: it determines the random
    # initial weights. The order of `model_params` determines how the optimizer
    # indexes its parameters in a saved state_dict.
    if initialize in ['context', 'context_decoder', 'both']:
        print('Initializing context scVI model.')
        self.context_decoder = Decoder(param_dict=self.context_config).to(self.config_dict['device'])
        model_params = list(self.context_decoder.parameters())
        if initialize != 'context_decoder':
            self.context_encoder_inner = Encoder_inner(device=self.config_dict['device'], param_dict=self.context_config).to(self.config_dict['device'])
            self.context_encoder_outer = Encoder_outer(param_dict=self.context_config).to(self.config_dict['device'])
            model_params += list(self.context_encoder_outer.parameters()) + list(self.context_encoder_inner.parameters())
            # NOTE(review): nesting reconstructed from unindented source; confirm
            # that the library encoder is kept (not rebuilt) in 'context_decoder' mode.
            if self.config_dict['use_lib_enc']:
                self.context_lib_encoder = Library_encoder(device=self.config_dict['device'], param_dict=self.context_config).to(self.config_dict['device'])
                model_params += list(self.context_lib_encoder.parameters())
                self.context_lib_encoder.__name__ = 'context_lib_encoder'
                self.context_lib_encoder.eval()
        self.context_optimizer = self.context_config['optimizer'](model_params)
        self.context_encoder_inner.eval()
        self.context_encoder_outer.eval()
        self.context_decoder.eval()
        # __name__ is used by save()/load() to build the weight file names.
        self.context_encoder_inner.__name__ = 'context_encoder_inner'
        self.context_encoder_outer.__name__ = 'context_encoder_outer'
        self.context_decoder.__name__ = 'context_decoder'
        self.context_optimizer.__name__ = 'context_optimizer'
    if initialize in ['target', 'both']:
        print('Initializing target scVI model.')
        self.target_encoder_outer = Encoder_outer(param_dict=self.target_config).to(self.config_dict['device'])
        self.target_decoder = Decoder(param_dict=self.target_config).to(self.config_dict['device'])
        model_params = list(self.target_encoder_outer.parameters()) + list(self.target_decoder.parameters())
        if self.config_dict['use_lib_enc']:
            self.target_lib_encoder = Library_encoder(device=self.config_dict['device'], param_dict=self.target_config).to(self.config_dict['device'])
            self.target_lib_encoder.eval()
            self.target_lib_encoder.__name__ = 'target_lib_encoder'
            model_params += list(self.target_lib_encoder.parameters())
        if self.config_dict['alignment'] == 'latent':
            # Separate inner encoder, built from context_config (same inner dims).
            self.target_encoder_inner = Encoder_inner(device=self.config_dict['device'], param_dict=self.context_config).to(self.config_dict['device'])
            model_params += list(self.target_encoder_inner.parameters())
        elif self.config_dict['alignment'] == 'inter':
            # Shared with the context model; not trained by the target optimizer.
            self.target_encoder_inner = self.context_encoder_inner
        self.target_optimizer = self.target_config['optimizer'](model_params)
        self.target_encoder_inner.eval()
        self.target_encoder_outer.eval()
        self.target_decoder.eval()
        self.target_encoder_inner.__name__ = 'target_encoder_inner'
        self.target_encoder_outer.__name__ = 'target_encoder_outer'
        self.target_decoder.__name__ = 'target_decoder'
        self.target_optimizer.__name__ = 'target_optimizer'
[docs]
def create_directory(
    self,
    directory: str
):
    """
    Ensure ``directory`` and its ``params``, ``data`` and ``figures`` subfolders exist.

    Sets ``self.prm_dir``, ``self.dat_dir`` and ``self.fig_dir``. A message is
    printed for every directory that had to be created.

    Parameters
    ----------
    directory : str
        Base output directory.
    """
    def _ensure(path):
        # Only create (and report) directories that are missing.
        if os.path.exists(path):
            return
        os.makedirs(path)
        print(f"Created directory '{path}'.")

    _ensure(directory)
    self.prm_dir, self.dat_dir, self.fig_dir = (
        os.path.join(directory, sub) for sub in ('params', 'data', 'figures')
    )
    for sub_dir in (self.prm_dir, self.dat_dir, self.fig_dir):
        _ensure(sub_dir)
[docs]
def pkl(self, model_name: str, save_key: str):
    """Return ``<prm_dir>/<model_name>_<save_key>.pkl``."""
    file_name = f"{model_name}_{save_key}.pkl"
    return os.path.join(self.prm_dir, file_name)
[docs]
def pth(self, model_name: str, save_key: str):
    """Return ``<prm_dir>/<model_name>_<save_key>.pth`` (module weights)."""
    file_name = f"{model_name}_{save_key}.pth"
    return os.path.join(self.prm_dir, file_name)
[docs]
def opt(self, model_name: str, save_key: str):
    """Return ``<prm_dir>/<model_name>_<save_key>.opt`` (optimizer state)."""
    file_name = f"{model_name}_{save_key}.opt"
    return os.path.join(self.prm_dir, file_name)
[docs]
def hmu(self, model_name: str, save_key: str):
    """Return ``<dat_dir>/<model_name>_<save_key>.h5mu``."""
    file_name = f"{model_name}_{save_key}.h5mu"
    return os.path.join(self.dat_dir, file_name)
[docs]
def save(
    self,
    models: str = 'both',
    save_key: str = ''
):
    """
    Write the configuration, optimizer states and module weights to ``self.prm_dir``.

    ``config_dict.pkl`` is always written and does not carry ``save_key``. The
    per-model config (``*_config_<key>.pkl``), optimizer state
    (``*_optimizer_<key>.opt``) and module weights (``<module.__name__>_<key>.pth``)
    are written for the selected side(s).

    Parameters
    ----------
    models : {'context', 'target', 'both'}
        Which sub-models to save.
    save_key : str
        Suffix appended to the file names.
    """
    def _dump_pickle(obj, path):
        with open(path, 'wb') as handle:
            pickle.dump(obj, handle)
        print(f"Saved {path}")

    _dump_pickle(self.config_dict, os.path.join(self.prm_dir, "config_dict.pkl"))

    modules = []
    if models in ('context', 'both'):
        modules.extend((self.context_encoder_inner, self.context_encoder_outer, self.context_decoder))
        if self.config_dict['use_lib_enc']:
            modules.append(self.context_lib_encoder)
        _dump_pickle(self.context_config, self.pkl('context_config', save_key))
        opt_path = self.opt('context_optimizer', save_key)
        torch.save(self.context_optimizer.state_dict(), opt_path)
        print(f"Saved {opt_path}")
    if models in ('target', 'both'):
        modules.extend((self.target_encoder_inner, self.target_encoder_outer, self.target_decoder))
        if self.config_dict['use_lib_enc']:
            modules.append(self.target_lib_encoder)
        _dump_pickle(self.target_config, self.pkl('target_config', save_key))
        opt_path = self.opt('target_optimizer', save_key)
        torch.save(self.target_optimizer.state_dict(), opt_path)
        print(f"Saved {opt_path}")

    for module in modules:
        weights_path = self.pth(module.__name__, save_key)
        torch.save(module.state_dict(), weights_path)
        print(f'Saved {weights_path}.')
[docs]
def save_mdata(
    self,
    save_key: str
):
    """
    Write ``self.mdata`` to ``<dat_dir>/mdata_<save_key>.h5mu``.

    Parameters
    ----------
    save_key : str
        Suffix for the data file name.
    """
    # `hmu` already prefixes self.dat_dir. Passing self.dat_dir as the stem (as
    # before) embedded the directory path in the file name as well.
    path = self.hmu('mdata', save_key)
    self.mdata.write(path)
    print(f'Saved {path}')
[docs]
def load(
    self,
    models: str = 'both',
    save_key: str = ''
):
    """
    Load previously saved configs, optimizer states and module weights.

    ``config_dict.pkl`` replaces ``self.config_dict``. The current session's
    ``'device'`` entry is kept, so a checkpoint written on one device can be
    loaded on another. Modules and optimizers are not rebuilt: saved states are
    loaded into the existing instances, so their architecture must match the
    checkpoint.

    The ``.pkl`` files are read with ``pickle``. Only load checkpoints from
    trusted sources.

    Parameters
    ----------
    models : {'context', 'target', 'both'}
        Which sub-models to load.
    save_key : str
        Suffix used when the files were saved.
    """
    path = os.path.join(self.prm_dir, "config_dict.pkl")
    with open(path, 'rb') as f:
        loaded_config = pickle.load(f)
    # The device belongs to the running session, not to the checkpoint. Keep it
    # so the map_location arguments below target hardware that is available.
    loaded_config['device'] = self.config_dict['device']
    self.config_dict = loaded_config
    print(f"Loaded {path}")
    if models in ('context', 'both'):
        path = self.pkl('context_config', save_key)
        with open(path, 'rb') as f:
            self.context_config = pickle.load(f)
        print(f"Loaded {path}")
    if models in ('target', 'both'):
        path = self.pkl('target_config', save_key)
        with open(path, 'rb') as f:
            self.target_config = pickle.load(f)
        print(f"Loaded {path}")
    map_location = torch.device(self.config_dict['device'])
    if models in ('context', 'both'):
        path = self.opt('context_optimizer', save_key)
        state = torch.load(path, map_location=map_location)
        self.context_optimizer.load_state_dict(state)
        print(f"Loaded {path}")
    if models in ('target', 'both'):
        path = self.opt('target_optimizer', save_key)
        state = torch.load(path, map_location=map_location)
        self.target_optimizer.load_state_dict(state)
        print(f"Loaded {path}")
    model_list = []
    if models in ('context', 'both'):
        model_list += [self.context_encoder_outer, self.context_decoder, self.context_encoder_inner]
        if self.config_dict['use_lib_enc']:
            model_list += [self.context_lib_encoder]
    if models in ('target', 'both'):
        model_list += [self.target_encoder_outer, self.target_decoder, self.target_encoder_inner]
        if self.config_dict['use_lib_enc']:
            model_list += [self.target_lib_encoder]
    for model in model_list:
        path = self.pth(model.__name__, save_key)
        model.load_state_dict(torch.load(path, map_location=map_location))
        print(f"Loaded {path}")
    # Re-tie the shared inner encoder for 'inter' alignment.
    if self.config_dict['alignment'] == 'inter':
        self.target_encoder_inner = self.context_encoder_inner
[docs]
@staticmethod
def most_frequent(arr: np.ndarray) -> np.ndarray:
    """
    Return the most common value in ``arr``; helper for label transfer.

    Ties go to the smallest value, because ``np.unique`` returns sorted values
    and ``argmax`` picks the first maximum.

    Parameters
    ----------
    arr : array-like
        Values to count (flattened by ``np.unique``).

    Returns
    -------
    element
        The value occurring most often.
    """
    unique_vals, occurrences = np.unique(arr, return_counts=True)
    winner = occurrences.argmax()
    return unique_vals[winner]
[docs]
def transfer_labels_cell(
    self,
    target_ind: int,
    context_obs_transfer: Union[List[str], str],
) -> pd.DataFrame:
    """
    Rank every context cell by its similarity score to a single target cell.

    Computes the likelihood-based score of `similarity_metric` between the target cell at
    position `target_ind` in `self.mdata.mod[target_key]` and all context cells, and
    returns the requested context annotations sorted by ascending score.

    Parameters
    ----------
    target_ind : int
        Positional index of the target cell.
    context_obs_transfer : str or list of str
        Column(s) of the context `.obs` to include in the output (e.g. 'cell_type').

    Returns
    -------
    DataFrame
        One row per context cell (indexed by the context obs names) holding the requested
        columns, 'index' (positional index of the context cell) and 'similarity_score',
        sorted by ascending 'similarity_score'.
    """
    if isinstance(context_obs_transfer, str):
        context_obs_transfer = [context_obs_transfer]
    context_adata = self.mdata.mod[self.context_config['context_key']]
    # One target cell against all context cells: shapes (1,) and (1, n_context).
    context_inds = np.arange(context_adata.n_obs)[np.newaxis, :]
    similarities = self.similarity_metric(np.array([target_ind], dtype=int), context_inds)[0]
    # Sort once; the 1D order is reused for the labels, the indices and the scores.
    order = np.argsort(similarities)
    # Select the columns before copying (avoids copying all of .obs and chained-assignment warnings).
    df_neighbor = context_adata[order].obs[context_obs_transfer].copy()
    df_neighbor['index'] = order
    df_neighbor['similarity_score'] = similarities[order]
    return df_neighbor
[docs]
def similarity_metric(
    self,
    target_ind: np.ndarray,
    context_ind: np.ndarray,
    b_s: Optional[int] = None,
    b_sc: Optional[int] = None,
    display = True,
) -> np.ndarray:
    """
    Compute negative log-likelihood based similarity scores between target cells and candidate context cells.

    For target cell i and candidate context cell j the score is the negative log-likelihood
    (via `self.target_decoder.calc_nlog_likelihood`) of cell i's counts when decoded from
    context cell j's latent mean `z_mu`, minus the same quantity when decoded from cell i's
    own `z_mu`. Both terms use cell i's batch encoding and library size exp(`l_mu`).

    Parameters
    ----------
    target_ind : 1D array of int, shape (n_target,)
        Positional indices of target cells in `self.mdata.mod[target_key]`.
    context_ind : 2D array of int, shape (n_target, k)
        Row i lists the k candidate context cells (positional indices in
        `self.mdata.mod[context_key]`) to score against target cell i.
    b_s : int, optional
        Number of target cells per batch; defaults to `self.config_dict['b_s']`.
    b_sc : int, optional
        Number of candidate columns per chunk; defaults to max(1, int(128*25/b_s)).
    display : bool
        If True, prints progress (at most every 0.5 s).

    Returns
    -------
    similarities : ndarray, shape (n_target, k)
        Score of every target cell against each of its candidates.
    """
    device = self.config_dict['device']
    if b_s is None:
        b_s = self.config_dict['b_s']
    if b_sc is None:
        # Aim for ~128*25 decoder evaluations per chunk, but never a chunk size of 0
        # (b_s > 3200 would otherwise divide by zero below).
        b_sc = max(1, int(128*25/b_s))
    k_neigh = np.shape(context_ind)[1]
    steps = int(np.ceil(np.shape(target_ind)[0]/b_s))
    steps_c_ind = int(np.ceil(k_neigh/b_sc))
    target_adata = self.mdata.mod[self.target_config['target_key']]
    context_z_all = self.mdata.mod[self.context_config['context_key']].obsm['z_mu']
    similarities = []
    with torch.no_grad():
        tic = time.time()
        for step in range(steps):
            if display and time.time() - tic > 0.5:
                tic = time.time()
                print('\rCalculate similarity metric. Step {}/{}.'.format(str(step), str(steps)), end='', flush=True)
            target_ind_batch = self.get_batch(target_ind, step, batch_size=b_s)
            target_adata_batch = target_adata[target_ind_batch]
            target_x_batch = torch.from_numpy(target_adata_batch.X.toarray()).to(device)
            target_s_batch = torch.from_numpy(target_adata_batch.obsm['batch_label_enc']).to(device)
            target_l_batch = torch.from_numpy(target_adata_batch.obsm['l_mu']).exp().to(device)
            n_batch = target_x_batch.size(0)
            # Candidate rows of this target batch; invariant across the chunk loop.
            context_rows = self.get_batch(context_ind, step, batch_size=b_s)
            sim_batch_c = []
            for step_c in range(steps_c_ind):
                context_ind_batch = context_rows[:, step_c*b_sc:(step_c+1)*b_sc]
                k_chunk = np.shape(context_ind_batch)[-1]
                # Pair each target cell with each of its k_chunk candidates (row-major order).
                s_interl = torch.repeat_interleave(target_s_batch, repeats=k_chunk, dim=0)
                l_interl = torch.repeat_interleave(target_l_batch, repeats=k_chunk, dim=0)
                x_interl = torch.repeat_interleave(target_x_batch, repeats=k_chunk, dim=0)
                # Flatten to 1D. Unlike squeeze, this keeps a single pair 1D, so the z_mu
                # lookup always yields a 2D (n_batch*k_chunk, latent_dim) array.
                context_ind_flat = np.reshape(context_ind_batch, -1)
                context_z_batch = torch.from_numpy(context_z_all[context_ind_flat]).to(device)
                outp_neighbors = self.target_decoder.decode(context_z_batch, s_interl)
                outp = self.target_decoder.calc_nlog_likelihood(outp_neighbors, l_interl, x_interl).reshape(n_batch, k_chunk).cpu().numpy()
                sim_batch_c.append(outp)
            sim_batch_c = np.concatenate(sim_batch_c, axis=-1)
            # Subtract the target cell's own reconstruction term as a per-cell baseline.
            target_z_batch = torch.from_numpy(target_adata_batch.obsm['z_mu']).to(device)
            outp_target = self.target_decoder.decode(target_z_batch, target_s_batch)
            sim_batch_c -= self.target_decoder.calc_nlog_likelihood(outp_target, target_l_batch, target_x_batch).unsqueeze(-1).cpu().numpy()
            similarities.append(sim_batch_c)
    similarities = np.concatenate(similarities)
    return similarities
[docs]
def ret_pred_df(
    self,
    pred_key: str,
    target_label_key: str,
    context_label_key: str
) -> Tuple[pd.DataFrame, float]:
    """
    Compute a normalized confusion matrix (%) and balanced accuracy for label transfer.

    This evaluates how well the predicted context-derived labels match the true labels
    on the target dataset.

    Parameters
    ----------
    pred_key : str
        Key in `self.mdata.mod[target_key].obs` under which predicted labels are stored.
    target_label_key : str
        Key in `self.mdata.mod[target_key].obs` for the ground-truth labels.
    context_label_key : str
        Key in `self.mdata.mod[context_key].obs` for the reference context labels.

    Returns
    -------
    df : pd.DataFrame
        Confusion matrix (in percent, each row sums to 100) with
        - index: sorted labels of `target_label_key`,
        - columns: sorted labels of `context_label_key`,
        - values: percentage of cells with true label = row and predicted label = column.
    bas : float
        Balanced accuracy score computed only over the subset of cells whose true labels
        also appear in the context set.

    Raises
    ------
    KeyError
        If a predicted label does not occur among the context labels.
    """
    target_obs = self.mdata.mod[self.target_config['target_key']].obs
    predicted_labels = target_obs[pred_key].to_numpy()
    target_labels = target_obs[target_label_key].to_numpy()
    context_labels = self.mdata.mod[self.context_config['context_key']].obs[context_label_key].to_numpy()
    unique_target_labels = np.unique(target_labels)
    unique_context_labels = np.unique(context_labels)
    # Only cell types present in both datasets can be predicted correctly.
    joint_labels = np.intersect1d(context_labels, target_labels)
    joint_mask = np.isin(target_labels, joint_labels)
    bas = balanced_accuracy_score(target_labels[joint_mask], predicted_labels[joint_mask])
    # Vectorized confusion counts: rows = true target label, columns = predicted label.
    row_idx = pd.Index(unique_target_labels).get_indexer(target_labels)
    col_idx = pd.Index(unique_context_labels).get_indexer(predicted_labels)
    if (col_idx < 0).any():
        missing = pd.unique(predicted_labels[col_idx < 0])
        raise KeyError(f"Predicted labels not found among context labels: {list(missing)}")
    counts = np.zeros((unique_target_labels.size, unique_context_labels.size), dtype=int)
    np.add.at(counts, (row_idx, col_idx), 1)
    df = pd.DataFrame(counts, index=unique_target_labels, columns=unique_context_labels)
    df = df.div(df.sum(axis=1), axis=0) * 100
    return df, bas
[docs]
def similarity_metric_on_latent_space(
    self,
    precompute_neighbors: bool = True,
    n_neighbors: int = 250,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute similarity scores for the whole context and target dataset pairs.

    Either for a precomputed set of neighbors, found by a Euclidean nearest-neighbor
    search on the latent means `z_mu` (to speed up computation), or for all context
    cells. Scoring all pairs should only be done for small datasets. In that case the
    user is asked for confirmation, and declining falls back to the neighbor search.

    Parameters
    ----------
    precompute_neighbors : bool, default=True
        If True, precomputes a set of Euclidean neighbors on the aligned latent space.
    n_neighbors : int, default=250
        Number of latent-space neighbors per target cell; capped at the number of
        context cells.

    Returns
    -------
    similarities : ndarray of shape (target.n_obs, n_neighbors) or (target.n_obs, context.n_obs)
    context_ind : ndarray of the same shape
        Positional indices of the context cells each score refers to.
    """
    n_obs_context = self.mdata.mod[self.context_config['context_key']].n_obs
    n_obs_target = self.mdata.mod[self.target_config['target_key']].n_obs
    if not precompute_neighbors:
        n_evals = n_obs_context * n_obs_target
        prompt = (
            f"Warning, not precomputing neighbors on the latent space will lead to "
            f"a total of {n_evals} decoder evaluations. Proceed? [y/N] "
        )
        resp = input(prompt)
        if resp.strip().lower() not in ('y', 'yes'):
            print("Set precompute_neighbors to true.")
            precompute_neighbors = True
        else:
            # Every target cell is scored against every context cell: each row is 0..n_context-1.
            context_ind = np.tile(np.arange(n_obs_context), (n_obs_target, 1))
    if precompute_neighbors:
        # kneighbors cannot return more neighbors than there are context cells.
        n_neighbors = min(n_neighbors, n_obs_context)
        print(f'Pre-computing latent space NNS with {n_neighbors} neighbors using the euclidean distance.')
        neigh = NearestNeighbors(n_neighbors=n_neighbors, metric='euclidean')
        neigh.fit(self.mdata.mod[self.context_config['context_key']].obsm['z_mu'])
        _, context_ind = neigh.kneighbors(self.mdata.mod[self.target_config['target_key']].obsm['z_mu'])
    target_ind = np.arange(n_obs_target)
    similarities = self.similarity_metric(target_ind, context_ind, b_s=None, b_sc=None)
    return similarities, context_ind
[docs]
def transfer_labels_data(
    self,
    context_obs_transfer: Union[List[str], str],
    top_neigh: int = 25,
    write_sim: bool = False
):
    """
    Assign context-derived labels via similarity scores to each target cell by majority vote among its top candidates.

    Scores are obtained from `similarity_metric_on_latent_space`. For each observation key
    in `context_obs_transfer`, the `top_neigh` candidates with the lowest scores are taken,
    the most frequent label among them is chosen, and the result is written into
    `self.mdata.mod[target_key].obs['pred_sim_<obs_key>']`.
    When target cell annotation is unknown, the inferred values of the last entry in
    `context_obs_transfer` will serve as a replacement for target cell annotation in
    downstream analyses.

    Parameters
    ----------
    context_obs_transfer : List of str or str
        One or more keys in `self.mdata.mod[context_key].obs` whose values to transfer.
    top_neigh : int, default=25
        Number of best-scoring candidates to consider for the majority vote.
    write_sim : bool, default=False
        If True, also stores the raw similarity scores and candidate indices in
        `self.mdata.mod[target_key].obsm['similarities']` and
        `['similarities_ind']`.
    """
    target_key = self.target_config['target_key']
    target_adata = self.mdata.mod[target_key]
    context_obs = self.mdata.mod[self.context_config['context_key']].obs
    similarities, context_ind = self.similarity_metric_on_latent_space()
    if isinstance(context_obs_transfer, str):
        context_obs_transfer = [context_obs_transfer]
    if write_sim:
        target_adata.obsm['similarities'] = similarities
        target_adata.obsm['similarities_ind'] = context_ind
    # Context indices of each target cell's `top_neigh` lowest-scoring candidates.
    # This is the same for every transferred key, so compute it once.
    order = np.argsort(similarities, axis=1)[:, :top_neigh]
    top_context_ind = np.take_along_axis(context_ind, order, axis=1)
    for obs_key in context_obs_transfer:
        context_labels = context_obs[obs_key].to_numpy()
        pred_labels = np.array([self.most_frequent(context_labels[row]) for row in top_context_ind])
        target_adata.obs['pred_sim_'+obs_key] = pred_labels
    if target_adata.uns['metadata']['cell_key'] == 'unknown':
        target_adata.uns['metadata']['cell_key_transferred'] = context_obs_transfer[-1]
        print(f'Set {context_obs_transfer[-1]} as target cell key for downstream analyses.')
[docs]
@staticmethod
def average_slices(
    array: np.ndarray,
    slice_sizes: Sequence[int]
) -> np.ndarray:
    """
    Average consecutive row blocks of a 2D array.

    The rows of `array` are split into back-to-back blocks whose lengths are given by
    `slice_sizes`; each block is reduced to its column-wise mean.
    Helper for `compute_logfold_change`.

    Parameters
    ----------
    array : ndarray, shape (sum(slice_sizes), n_features)
        Row-concatenated data.
    slice_sizes : sequence of int
        Length of each consecutive block; expected to sum to `array.shape[0]`.

    Returns
    -------
    stacked_means : ndarray, shape (len(slice_sizes), n_features)
        One row of column means per block.
    """
    sizes = np.asarray(slice_sizes)
    stops = np.cumsum(sizes)
    starts = stops - sizes
    slice_means = [np.mean(array[first:last], axis=0) for first, last in zip(starts, stops)]
    return np.stack(slice_means)
[docs]
@staticmethod
def filter_outliers(
    data: np.ndarray,
    confidence_level: float = 0.9
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Identify inlier and outlier rows based on the Mahalanobis distance.

    Computes the Mahalanobis distance of each row in `data` from the multivariate mean,
    uses a chi-squared cutoff (degrees of freedom = n_features) at the given
    `confidence_level`, and returns boolean masks.
    Helper for `compute_logfold_change`.

    The covariance is inverted with a pseudo-inverse, which matches the inverse for a
    well-conditioned covariance but does not fail when it is singular (e.g. fewer rows
    than features, or a constant feature). With fewer than two rows no covariance can be
    estimated, and all rows are treated as inliers.

    Parameters
    ----------
    data : ndarray, shape (n_samples, n_features)
        Input points in feature space.
    confidence_level : float, default=0.9
        Threshold percentile for declaring a point an inlier.

    Returns
    -------
    inlier_mask : ndarray of bool, shape (n_samples,)
        True for rows whose Mahalanobis distance is below the threshold.
    outlier_mask : ndarray of bool, shape (n_samples,)
        Complement of `inlier_mask`.
    """
    data = np.asarray(data)
    n_samples = data.shape[0]
    if n_samples < 2:
        inlier_mask = np.ones(n_samples, dtype=bool)
        return inlier_mask, ~inlier_mask
    mean = np.mean(data, axis=0)
    data_centered = data - mean
    cov_matrix = np.dot(data_centered.T, data_centered) / (n_samples - 1)
    cov_inv = np.linalg.pinv(cov_matrix)
    sq_dist = np.sum(np.dot(data_centered, cov_inv) * data_centered, axis=1)
    # Clip tiny negative values caused by floating-point round-off before the sqrt.
    m_dist = np.sqrt(np.clip(sq_dist, 0, None))
    threshold = np.sqrt(chi2.ppf(confidence_level, mean.shape[0]))
    inlier_mask = m_dist < threshold
    return inlier_mask, ~inlier_mask
[docs]
def generate_homologous_samples(
    self,
    samples: int = 5000,
    target_cell_key = None,
    b_s: int = 128,
    confidence_level: float = 0.9
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    """
    Decode homologous normalized expression profiles for context and target species by Monte Carlo sampling.

    For every cell type present in both datasets:
    1. Target cells whose latent means are outliers (see `filter_outliers`) are removed.
    2. Latent codes are drawn as z_mu + z_sig * eps with standard-normal eps for the
       remaining target cells.
    3. The codes are decoded with `decode_homologous` of both the context and the
       target decoder.
    4. Each cell's profile is averaged over the one-hot batch codes of the batches that
       hold more than three cells of that cell type in the respective dataset.

    Parameters
    ----------
    samples : int, default=5000
        Maximum number of decoded samples to return per cell type.
    target_cell_key : str or None
        Column name in `.obs` with cell type labels for the target dataset. Defaults to the
        target `cell_key` in `.uns['metadata']`; if that is 'unknown', the transferred key
        `cell_key_transferred` is used.
    b_s : int, default=128
        Batch size for decoding iterations.
    confidence_level : float, default=0.9
        Quantile threshold used in `filter_outliers` to remove extreme latent embeddings.

    Returns
    -------
    target_rho_dict : dict of str to ndarray of shape (<= samples, genes)
        Decoded normalized expression (`rho`) for shared cell types in the target species.
    context_rho_dict : dict of str to ndarray of shape (<= samples, genes)
        Decoded normalized expression (`rho`) for shared cell types in the context species.
        Cell types with no target cell left after outlier filtering are omitted.

    Raises
    ------
    ValueError
        If target cell labels are unknown and no transferred labels exist.
    """
    # Evaluation mode for all networks; the library encoders are optional (`use_lib_enc`)
    # and may be missing, so absent modules are skipped.
    for module_name in ('context_decoder', 'target_decoder',
                        'context_encoder_inner', 'target_encoder_inner',
                        'context_encoder_outer', 'target_encoder_outer',
                        'context_lib_encoder', 'target_lib_encoder'):
        module = getattr(self, module_name, None)
        if module is not None:
            module.eval()
    device = self.config_dict['device']
    context_key = self.context_config['context_key']
    target_key = self.target_config['target_key']
    context_adata_full = self.mdata.mod[context_key]
    target_adata_full = self.mdata.mod[target_key]
    context_cell_key = context_adata_full.uns['metadata']['cell_key']
    if target_cell_key is None:
        target_cell_key = target_adata_full.uns['metadata']['cell_key']
    context_batch_key = context_adata_full.uns['metadata']['batch_key']
    target_batch_key = target_adata_full.uns['metadata']['batch_key']
    if target_cell_key == 'unknown':
        if 'cell_key_transferred' in target_adata_full.uns['metadata'].keys():
            target_cell_key = target_adata_full.uns['metadata']['cell_key_transferred']
            print(f'Use inferred context labels in {target_cell_key} for similarity calculation.')
        else:
            raise ValueError(f"Target cell labels must be known or transferred from the context dataset via `label_transfer`.")

    def batch_encodings(adata, cell_key, batch_key, cell_types):
        # Map each cell type to the one-hot codes of the batches that contain more than
        # three cells of that type; 'unknown' maps to the codes of all batches.
        batch_labels = adata.obs[batch_key].to_numpy().reshape(-1, 1)
        encoder = OneHotEncoder()
        encoder.fit(batch_labels)
        encodings = {}
        for c in cell_types:
            counts = adata[adata.obs[cell_key] == c].obs[batch_key].value_counts()
            frequent = counts[counts > 3].index.to_numpy()
            encodings[c] = encoder.transform(frequent.reshape(-1, 1)).toarray().astype(np.float32)
        encodings['unknown'] = encoder.transform(np.unique(batch_labels).reshape(-1, 1)).toarray().astype(np.float32)
        return encodings

    context_cell_types = np.unique(context_adata_full.obs[context_cell_key].to_numpy())
    target_cell_labels = target_adata_full.obs[target_cell_key].to_numpy()
    target_cell_types = np.unique(target_cell_labels)
    target_cell_index = {c: np.where(target_cell_labels == c)[0] for c in target_cell_types}
    context_batches = batch_encodings(context_adata_full, context_cell_key, context_batch_key, context_cell_types)
    target_batches = batch_encodings(target_adata_full, target_cell_key, target_batch_key, target_cell_types)
    joint_cell_types = np.intersect1d(context_cell_types, target_cell_types)
    context_rho_dict = {}
    target_rho_dict = {}
    for cell_type in joint_cell_types:
        adata_target = target_adata_full[target_cell_index[cell_type]]
        inliers = self.filter_outliers(adata_target.obsm['z_mu'], confidence_level=confidence_level)[0]
        adata_target = adata_target[inliers]
        if adata_target.n_obs == 0:
            # Nothing to sample from; avoids a division by zero below.
            print(f'Skipping {cell_type}: no target cells left after outlier filtering.')
            continue
        steps = int(np.ceil(adata_target.n_obs/b_s))
        iterations = int(np.ceil(samples/adata_target.n_obs))
        context_rho_list = []
        target_rho_list = []
        with torch.no_grad():
            for _ in range(iterations):
                for step in range(steps):
                    batch_adata = adata_target[step*b_s:(step+1)*b_s]
                    # The batch holds target cells, so labels come from the target cell key;
                    # they are shared cell-type names present in both encoding dictionaries.
                    batch_cell_types = batch_adata.obs[target_cell_key].to_numpy()
                    context_labels = torch.from_numpy(np.concatenate([context_batches[c] for c in batch_cell_types])).to(device)
                    target_labels = torch.from_numpy(np.concatenate([target_batches[c] for c in batch_cell_types])).to(device)
                    context_ind_batch = np.array([np.shape(context_batches[c])[0] for c in batch_cell_types])
                    target_ind_batch = np.array([np.shape(target_batches[c])[0] for c in batch_cell_types])
                    z_mu = batch_adata.obsm['z_mu']
                    z_sig = batch_adata.obsm['z_sig']
                    # Reparameterized Gaussian draw (assumes z_sig holds standard deviations).
                    z = np.float32(z_mu + z_sig * np.random.randn(*np.shape(z_sig)))
                    # Repeat each latent code once per batch code; averaged back per cell below.
                    context_z = torch.from_numpy(np.concatenate([np.tile(z[j], (i, 1)) for j, i in enumerate(context_ind_batch)])).to(device)
                    target_z = torch.from_numpy(np.concatenate([np.tile(z[j], (i, 1)) for j, i in enumerate(target_ind_batch)])).to(device)
                    context_rho = self.context_decoder.decode_homologous(context_z, context_labels).cpu().numpy()
                    target_rho = self.target_decoder.decode_homologous(target_z, target_labels).cpu().numpy()
                    context_rho_list.append(self.average_slices(context_rho, context_ind_batch))
                    target_rho_list.append(self.average_slices(target_rho, target_ind_batch))
        target_rho_dict[cell_type] = np.concatenate(target_rho_list)[:samples]
        context_rho_dict[cell_type] = np.concatenate(context_rho_list)[:samples]
    return target_rho_dict, context_rho_dict
[docs]
def compute_logfold_change(
    self,
    eval_cell_types: Optional[Sequence[str]] = None,
    eps: float = 1e-6,
    lfc_delta: float = 1,
    samples: int = 50000,
    target_cell_key = None,
    b_s: int = 128,
    confidence_level: float = 0.9
) -> Dict[str, pd.DataFrame]:
    """
    Monte Carlo estimation of per-gene Log2-fold-changes and associated probabilities.

    For each specified cell type (or the intersection of context/target types):
    - Outliers are removed from both datasets (see `filter_outliers`).
    - Latent codes z_mu + z_sig * eps (standard-normal eps) are drawn for the target
      cells and decoded on the homologous genes with both decoders.
    - Per draw, the target vs. context Log2 ratio of the normalized expression `rho` is
      computed and aggregated into:
      - Median Log2-fold-change (on normalized decoder space),
      - Probability(abs(Log2Fc) > lfc_delta),
      - Median gene expression on normalized decoder space (`rho`) and NB mean space
        (`rho` times a sampled library size).
    - The context library size is averaged over the target cell's nearest context cells
      (cosine distance on `z_mu`).
    - A null distribution compares against a random gene permutation of the context
      profiles.

    Parameters
    ----------
    eval_cell_types : sequence of str, optional
        Cell types to include; defaults to the intersection of context and target types.
    eps : float, default=1e-6
        Small constant added before log to prevent small gene expression patterns from returning large LFC values.
    lfc_delta : float, default=1
        Threshold for computing the probability of large fold-changes.
    samples : int, default=50000
        Approximate total number of Monte Carlo draws per cell type (at least one per cell).
    target_cell_key : str or None
        Column name in `.obs` with cell type labels for the target dataset. Defaults to the
        target `cell_key` in `.uns['metadata']`; if that is 'unknown', the transferred key
        `cell_key_transferred` is used.
    b_s : int, default=128
        Batch size for sampling iterations.
    confidence_level : float, default=0.9
        Outlier filtering threshold for latent space.

    Returns
    -------
    lfc_dict : dict
        Maps each evaluated cell type to a DataFrame indexed by the target homologous gene
        names with the columns:
        - 'rho_median_context' : Median context normalized gene expression,
        - 'mu_median_context' : Median context expected value gene expression,
        - 'rho_median_target' : Median target normalized gene expression,
        - 'mu_median_target' : Median target expected value gene expression,
        - 'lfc' : Median Log2 fold-change of the relative expression parameter rho,
        - 'p' : Probability of absolute Log2 fold-change values greater than lfc_delta,
        - 'lfc_rand' : Median Log2 fold-change of rho on permuted data,
        - 'p_rand' : Probability of absolute Log2 fold-change values greater than lfc_delta on permuted data.
        The key 'lfc_delta' holds the threshold used. Cell types with no cells left after
        outlier filtering are omitted.

    Raises
    ------
    ValueError
        If target cell labels are unknown and no transferred labels exist.
    """
    device = self.config_dict['device']
    context_key = self.context_config['context_key']
    target_key = self.target_config['target_key']
    context_adata_full = self.mdata.mod[context_key]
    target_adata_full = self.mdata.mod[target_key]
    context_cell_key = context_adata_full.uns['metadata']['cell_key']
    if target_cell_key is None:
        target_cell_key = target_adata_full.uns['metadata']['cell_key']
    context_batch_key = context_adata_full.uns['metadata']['batch_key']
    target_batch_key = target_adata_full.uns['metadata']['batch_key']
    if target_cell_key == 'unknown':
        if 'cell_key_transferred' in target_adata_full.uns['metadata'].keys():
            target_cell_key = target_adata_full.uns['metadata']['cell_key_transferred']
            print(f'Use inferred context labels in {target_cell_key} for differential gene expression analysis.')
        else:
            raise ValueError(f"Target cell labels must be known or transferred from the context dataset via `label_transfer`.")
    # Evaluation mode for all networks; the library encoders are optional (`use_lib_enc`)
    # and may be missing, so absent modules are skipped.
    for module_name in ('context_decoder', 'target_decoder',
                        'context_encoder_inner', 'target_encoder_inner',
                        'context_encoder_outer', 'target_encoder_outer',
                        'context_lib_encoder', 'target_lib_encoder'):
        module = getattr(self, module_name, None)
        if module is not None:
            module.eval()

    def batch_encodings(adata, cell_key, batch_key, cell_types):
        # Map each cell type to the one-hot codes of the batches that contain more than
        # three cells of that type; 'unknown' maps to the codes of all batches.
        batch_labels = adata.obs[batch_key].to_numpy().reshape(-1, 1)
        encoder = OneHotEncoder()
        encoder.fit(batch_labels)
        encodings = {}
        for c in cell_types:
            counts = adata[adata.obs[cell_key] == c].obs[batch_key].value_counts()
            frequent = counts[counts > 3].index.to_numpy()
            encodings[c] = encoder.transform(frequent.reshape(-1, 1)).toarray().astype(np.float32)
        encodings['unknown'] = encoder.transform(np.unique(batch_labels).reshape(-1, 1)).toarray().astype(np.float32)
        return encodings

    target_ind = np.array(self.target_config['homologous_genes'])
    target_gene_names = target_adata_full.var_names.to_numpy()[target_ind]
    context_cell_labels = context_adata_full.obs[context_cell_key].to_numpy()
    context_cell_types = np.unique(context_cell_labels)
    context_cell_index = {c: np.where(context_cell_labels == c)[0] for c in context_cell_types}
    target_cell_labels = target_adata_full.obs[target_cell_key].to_numpy()
    target_cell_types = np.unique(target_cell_labels)
    target_cell_index = {c: np.where(target_cell_labels == c)[0] for c in target_cell_types}
    context_batches = batch_encodings(context_adata_full, context_cell_key, context_batch_key, context_cell_types)
    target_batches = batch_encodings(target_adata_full, target_cell_key, target_batch_key, target_cell_types)
    # `is None` also works when an array of cell types is passed (== would compare elementwise).
    if eval_cell_types is None:
        eval_cell_types = np.intersect1d(context_cell_types, target_cell_types)
    random_perm = np.random.permutation(len(target_gene_names))
    result_columns = ['rho_median_context', 'mu_median_context', 'rho_median_target', 'mu_median_target', 'lfc', 'p', 'lfc_rand', 'p_rand']
    lfc_dict = {}
    for cell_type in eval_cell_types:
        adata_context = context_adata_full[context_cell_index[cell_type]]
        adata_target = target_adata_full[target_cell_index[cell_type]]
        adata_context.obs_names_make_unique()
        adata_target.obs_names_make_unique()
        inliers = self.filter_outliers(adata_context.obsm['z_mu'], confidence_level=confidence_level)[0]
        adata_context = adata_context[inliers].copy()
        inliers = self.filter_outliers(adata_target.obsm['z_mu'], confidence_level=confidence_level)[0]
        adata_target = adata_target[inliers].copy()
        if adata_context.n_obs == 0 or adata_target.n_obs == 0:
            print(f'Skipping {cell_type}: no cells left after outlier filtering.')
            continue
        # Up to 25 nearest context cells (cosine distance on z_mu) per target cell; their
        # library sizes stand in for the context library size of that cell.
        n_lib_neighbors = min(25, adata_context.n_obs)
        knn = NearestNeighbors(n_neighbors=n_lib_neighbors, metric='cosine', algorithm='auto')
        knn.fit(adata_context.obsm['z_mu'])
        _, neighbor_indices = knn.kneighbors(adata_target.obsm['z_mu'])
        steps = int(np.ceil(adata_target.n_obs/b_s))
        sampling_size = max(int(samples / adata_target.n_obs), 1)
        context_l_mu = adata_context.obsm['l_mu']
        context_l_sig = adata_context.obsm['l_sig']
        lfc_list = []
        lfc_list_random = []
        rho_context_list = []
        mu_context_list = []
        rho_target_list = []
        mu_target_list = []
        with torch.no_grad():
            for step in range(steps):
                batch_adata = adata_target[step*b_s:(step+1)*b_s]
                neigh_ind = neighbor_indices[step*b_s:(step+1)*b_s]
                batch_cell_types = batch_adata.obs[target_cell_key].to_numpy()
                context_labels = torch.from_numpy(np.concatenate([context_batches[c] for c in batch_cell_types])).to(device)
                target_labels = torch.from_numpy(np.concatenate([target_batches[c] for c in batch_cell_types])).to(device)
                context_ind_batch = np.array([np.shape(context_batches[c])[0] for c in batch_cell_types])
                target_ind_batch = np.array([np.shape(target_batches[c])[0] for c in batch_cell_types])
                z_mu = batch_adata.obsm['z_mu']
                z_sig = batch_adata.obsm['z_sig']
                l_mu = batch_adata.obsm['l_mu']
                l_sig = batch_adata.obsm['l_sig']
                n_batch, latent_dim = np.shape(z_sig)
                for _ in range(sampling_size):
                    # Reparameterized Gaussian draws (assumes *_sig hold standard deviations).
                    z = np.float32(z_mu + z_sig * np.random.randn(n_batch, latent_dim))
                    target_l = np.exp(np.float32(l_mu + l_sig * np.random.randn(n_batch, 1)))
                    context_l = np.exp(np.float32(context_l_mu[neigh_ind] + context_l_sig[neigh_ind] * np.random.randn(n_batch, n_lib_neighbors, 1)))
                    context_l = context_l.mean(axis=1)
                    # Repeat each latent code once per batch code; averaged back per cell below.
                    context_z = torch.from_numpy(np.concatenate([np.tile(z[j], (i, 1)) for j, i in enumerate(context_ind_batch)])).to(device)
                    target_z = torch.from_numpy(np.concatenate([np.tile(z[j], (i, 1)) for j, i in enumerate(target_ind_batch)])).to(device)
                    context_rho = self.average_slices(self.context_decoder.decode_homologous(context_z, context_labels).cpu().numpy(), context_ind_batch)
                    target_rho = self.average_slices(self.target_decoder.decode_homologous(target_z, target_labels).cpu().numpy(), target_ind_batch)
                    rho_context_list.append(context_rho)
                    mu_context_list.append(context_rho * context_l)
                    rho_target_list.append(target_rho)
                    mu_target_list.append(target_rho * target_l)
                    log_target = np.log2(target_rho+eps)
                    lfc_list.append(log_target - np.log2(context_rho+eps))
                    lfc_list_random.append(log_target - np.log2(context_rho[:, random_perm]+eps))
        lfc = np.concatenate(lfc_list)
        lfc_random = np.concatenate(lfc_list_random)
        df = pd.DataFrame(0, index=target_gene_names, columns=result_columns)
        df['rho_median_context'] = np.median(np.concatenate(rho_context_list), axis=0)
        df['mu_median_context'] = np.median(np.concatenate(mu_context_list), axis=0)
        df['rho_median_target'] = np.median(np.concatenate(rho_target_list), axis=0)
        df['mu_median_target'] = np.median(np.concatenate(mu_target_list), axis=0)
        df['lfc'] = np.median(lfc, axis=0)
        df['p'] = np.mean(np.abs(lfc) > lfc_delta, axis=0)
        df['lfc_rand'] = np.median(lfc_random, axis=0)
        df['p_rand'] = np.mean(np.abs(lfc_random) > lfc_delta, axis=0)
        lfc_dict[cell_type] = df
    lfc_dict['lfc_delta'] = lfc_delta
    return lfc_dict
[docs]
@staticmethod
def mode_histogram(
    x: np.array,
) -> np.float32:
    """
    Return the mid-point of the histogram bin with the highest count.

    Bins follow the Freedman-Diaconis rule (`bins='fd'`). Non-finite values (NaN, +/-inf)
    are ignored, since `np.histogram` cannot auto-detect a finite range when they are present.
    Helper for `return_similarity_df`.

    Parameters
    ----------
    x : np.array
        Array of values for which to calculate the modal value (flattened).

    Returns
    -------
    mode : float
        Modal value of the empirical distribution of the finite entries.
    """
    values = np.asarray(x).ravel()
    values = values[np.isfinite(values)]
    counts, edges = np.histogram(values, bins='fd')
    j = np.argmax(counts)
    return (edges[j] + edges[j+1]) / 2.0
[docs]
def return_similarity_df(
    self,
    max_sample_targ=2000,
    max_sample_cont=50,
    scale: str='none',
) -> pd.DataFrame:
    """
    Compute similarity scores between target and context cell types.

    For every (target type, context type) pair, up to `max_sample_targ` target cells are
    sampled. Each of them is scored with `similarity_metric` against up to
    `max_sample_cont` randomly drawn context cells of the context type. The modal value
    of all scores (see `mode_histogram`) is the pair's raw score.

    Parameters
    ----------
    max_sample_targ : int, default=2000
        Number of samples from the target cell types.
    max_sample_cont : int, default=50
        Number of samples from the context cell types per target cell.
    scale : {'min_max', 'max', 'none'}, default='none'
        Row-wise post-processing of the raw scores:
        - 'none' : negated raw scores,
        - 'max' : row maximum divided by each entry,
        - 'min_max' : (x - row_min) / (row_max - row_min).

    Returns
    -------
    df : DataFrame
        Similarity scores with
        - index: target cell types,
        - columns: context cell types.

    Raises
    ------
    ValueError
        If `scale` is not supported, or if target cell labels are unknown and none were transferred.
    """
    # Validate before the (expensive) pairwise loop instead of silently skipping scaling.
    if scale not in ('min_max', 'max', 'none'):
        raise ValueError(f"scale must be one of 'min_max', 'max' or 'none', got {scale!r}.")
    context_key = self.context_config['context_key']
    target_key = self.target_config['target_key']
    context_cell_key = self.mdata.mod[context_key].uns['metadata']['cell_key']
    target_cell_key = self.mdata.mod[target_key].uns['metadata']['cell_key']
    if target_cell_key == 'unknown':
        if 'cell_key_transferred' in self.mdata.mod[target_key].uns['metadata'].keys():
            target_cell_key = self.mdata.mod[target_key].uns['metadata']['cell_key_transferred']
            print(f'Use inferred context labels in {target_cell_key} for similarity calculation.')
        else:
            raise ValueError(f"Target cell labels must be known or transferred from the context dataset via `label_transfer`.")
    cells_context = self.mdata.mod[context_key].obs[context_cell_key]
    cells_target = self.mdata.mod[target_key].obs[target_cell_key]
    cell_types_context = np.unique(cells_context)
    cell_types_target = np.unique(cells_target)
    df = pd.DataFrame(0, index=cell_types_target, columns=cell_types_context, dtype=float)
    n_context_types = len(cell_types_context)
    n_pairs = len(cell_types_target) * n_context_types
    for i, target_type in enumerate(cell_types_target):
        target_pool = np.where(cells_target == target_type)[0]
        for j, context_type in enumerate(cell_types_context):
            print('\r{}/{} Similarity calculation for the {}-{} pair'.format(
                str(i * n_context_types + j + 1),
                str(n_pairs),
                target_type,
                context_type
            ), end=' '*25, flush=True)
            target_ind = np.random.choice(target_pool, size=min(max_sample_targ, target_pool.shape[0]), replace=False)
            context_pool = np.where(cells_context == context_type)[0]
            # An independent random set of context candidates for every sampled target cell.
            context_ind = np.stack([
                np.random.choice(context_pool, size=min(max_sample_cont, context_pool.shape[0]), replace=False)
                for _ in range(target_ind.shape[0])
            ])
            sim_metric = self.similarity_metric(target_ind, context_ind, b_s=250, b_sc=250, display=False)
            df.loc[target_type, context_type] = self.mode_histogram(sim_metric.flatten())
    row_min = np.array(df.min(1))[:, np.newaxis]
    row_max = np.array(df.max(1))[:, np.newaxis]
    if scale == 'min_max':
        df = (df - row_min) / (row_max - row_min)
    elif scale == 'max':
        df = row_max / df
    else:
        df = - df
    return df
[docs]
@staticmethod
def update_param(
    parameter: float,
    min_value: float,
    max_value: float,
    steps: int
) -> float:
    """
    Advance `parameter` by one step of a linear schedule from `min_value` to `max_value`.

    Each call adds (max_value - min_value) / steps and caps the result at `max_value`.
    If `steps` is 0 or the schedule is degenerate (min_value == max_value), `parameter`
    is returned unchanged.

    Parameters
    ----------
    parameter : float
        Current parameter value.
    min_value : float
        Start of the schedule.
    max_value : float
        End of the schedule and upper cap.
    steps : int
        Number of increments needed to go from `min_value` to `max_value`.

    Returns
    -------
    float
        Incremented and capped parameter.
    """
    schedule_disabled = steps == 0 or min_value == max_value
    if schedule_disabled:
        return parameter
    increment = (max_value - min_value) / steps
    parameter += increment
    # Same tie/NaN behaviour as min(parameter, max_value).
    return max_value if max_value < parameter else parameter
[docs]
def train_context(
    self,
    epochs: int = 40,
    raise_beta: bool = True,
    save_model: bool = True,
    train_decoder_only: bool = False,
    save_key: str = ''
) -> None:
    """
    Pretrain the context scVI model on the context dataset.

    Each step minimises the negative ELBO: the decoder's negative log-likelihood
    plus ``beta`` times the KL divergence of the latent ``z``. When
    ``config_dict['use_lib_enc']`` is set, it also adds ``beta`` times the KL
    divergence of the library latent.

    Parameters
    ----------
    epochs : int, default=40
        Number of training epochs. Each epoch is ``int(n_obs / b_s)`` steps over
        a fresh random permutation of the cells.
    raise_beta : bool, default=True
        If True, call `update_param` once per epoch to move ``context_config['beta']``
        from ``beta_start`` toward ``beta_max`` over ``beta_epochs_raise`` epochs.
    save_model : bool, default=True
        If True, call ``self.save('context', save_key=save_key)`` after training.
    train_decoder_only : bool, default=False
        If True, the encoders are not switched to train/eval mode here. Only
        the decoder is toggled. NOTE(review): whether encoder weights stay frozen
        depends on which parameters ``self.context_optimizer`` holds — confirm.
    save_key : str, default=''
        Filename suffix passed to `save`.
    """
    b_s = self.config_dict['b_s']
    n_obs = self.mdata.mod[self.context_config['context_key']].n_obs
    # Floor division: the last n_obs % b_s cells of each permutation are not visited that epoch.
    steps_per_epoch = int(n_obs/b_s)
    if self.config_dict['use_lib_enc']:
        progBar = Progress_Bar(epochs, steps_per_epoch, ['nELBO', 'nlog_likeli', 'KL-Div z', 'KL-Div l'])
    else:
        progBar = Progress_Bar(epochs, steps_per_epoch, ['nELBO', 'nlog_likeli', 'KL-Div z'])
    print(f'Pretraining on the context dataset for {epochs} epochs (= {epochs*steps_per_epoch} iterations).')
    # The full count matrix is densified once; batches are sliced from it via get_batch.
    x = torch.from_numpy(self.mdata.mod[self.context_config['context_key']].X.toarray())
    s = torch.from_numpy(self.mdata.mod[self.context_config['context_key']].obsm['batch_label_enc'])
    if self.config_dict['use_lib_enc']:
        # Per-cell prior parameters for the library-size latent.
        lib_mu = torch.from_numpy(self.mdata.mod[self.context_config['context_key']].obs['library_log_mean'].to_numpy())
        lib_sig = torch.from_numpy(self.mdata.mod[self.context_config['context_key']].obs['library_log_std'].to_numpy())
    if not train_decoder_only:
        self.context_encoder_outer.train()
        if self.config_dict['use_lib_enc']:
            self.context_lib_encoder.train()
        self.context_encoder_inner.train()
    self.context_decoder.train()
    for epoch in range(epochs):
        perm = self.rng.permutation(n_obs)
        for step in range(steps_per_epoch):
            self.context_optimizer.zero_grad(set_to_none=True)
            # presumably get_batch selects rows perm[step*b_s:(step+1)*b_s] and moves them to the device — confirm.
            x_batch = self.get_batch(x, step, perm=perm, batch_size=b_s)
            s_batch = self.get_batch(s, step, perm=perm, batch_size=b_s)
            if self.config_dict['use_lib_enc']:
                lib_mu_batch = self.get_batch(lib_mu, step, perm=perm, batch_size=b_s)
                lib_sig_batch = self.get_batch(lib_sig, step, perm=perm, batch_size=b_s)
            z_batch, z_kl_div = self.context_encoder_inner(self.context_encoder_outer(x_batch, s_batch))
            if self.config_dict['use_lib_enc']:
                l_batch, l_kl_div = self.context_lib_encoder(x_batch, s_batch, lib_mu_batch, lib_sig_batch)
            else:
                # Without a library encoder, the observed library size (row sum) is used directly.
                l_batch = x_batch.sum(-1).unsqueeze(-1)
            nlog_likeli = self.context_decoder(z_batch, s_batch, l_batch, x_batch)
            nelbo = nlog_likeli + self.context_config['beta'] * z_kl_div
            if self.config_dict['use_lib_enc']:
                nelbo = nelbo + self.context_config['beta'] * l_kl_div
            nelbo.backward()
            self.context_optimizer.step()
            # Despite the name, this attribute is used as a list (append per step).
            self.context_likeli_hist_dict.append(nlog_likeli.item())
            if self.config_dict['use_lib_enc']:
                progBar.update({'nELBO': nelbo.item(), 'nlog_likeli': nlog_likeli.item(), 'KL-Div z': (self.context_config['beta'] * z_kl_div).item(), 'KL-Div l': (self.context_config['beta'] * l_kl_div).item()})
            else:
                progBar.update({'nELBO': nelbo.item(), 'nlog_likeli': nlog_likeli.item(), 'KL-Div z': (self.context_config['beta'] * z_kl_div).item()})
        # KL warm-up: beta advances once per epoch (the value persists in context_config).
        if raise_beta:
            self.context_config['beta'] = self.update_param(self.context_config['beta'], self.context_config['beta_start'], self.context_config['beta_max'], self.context_config['beta_epochs_raise'])
    if not train_decoder_only:
        self.context_encoder_outer.eval()
        if self.config_dict['use_lib_enc']:
            self.context_lib_encoder.eval()
        self.context_encoder_inner.eval()
    self.context_decoder.eval()
    if save_model == True:
        self.save('context',save_key=save_key)
[docs]
def train_target(
    self,
    epochs: int = 40,
    save_model: bool = True,
    raise_beta: bool = True,
    raise_eta: bool = True,
    save_key: str = '',
) -> None:
    """
    Train the target-side scVI model with an alignment term toward the context model.

    Each step minimises
    ``beta * KL(z) + nlog_likeli + eta * align`` (plus ``beta * KL(l)`` when
    ``config_dict['use_lib_enc']`` is set). The alignment term is computed only
    for cells in the batch whose ``obs['top_percent_<context cell key>']`` value
    is below ``top_percent / 100``. For each such cell:

    1. ``k_neigh`` context neighbours (``obsm['ind_neigh_nns']``) are sampled
       from the stored context posteriors (``z_mu``, ``z_sig``).
    2. Each sample is decoded with the target decoder.
    3. The neighbour with the lowest negative log-likelihood of the cell's
       counts is chosen.
    4. With ``alignment == 'inter'``, the target outer-encoder output is pulled
       toward that neighbour's stored context ``obsm['inter']``. With
       ``'latent'``, the target ``z`` is pulled toward the sampled neighbour
       latent.

    Parameters
    ----------
    epochs : int, default=40
        Number of training epochs (``int(n_obs / b_s)`` steps each).
    save_model : bool, default=True
        If True, call ``self.save('target', save_key=save_key)`` after training.
    raise_beta : bool, default=True
        If True, advance ``target_config['beta']`` once per epoch via `update_param`.
    raise_eta : bool, default=True
        If True, advance the alignment weight ``target_config['eta']`` once per epoch.
    save_key : str, default=''
        Filename suffix passed to `save`.
    """
    context_cell_key=self.mdata.mod[self.context_config['context_key']].uns['metadata']['cell_key']
    n_obs = self.mdata.mod[self.target_config['target_key']].n_obs
    k_neigh = self.target_config['k_neigh']
    top_percent = self.target_config['top_percent']
    # Floor division: the last n_obs % b_s cells of each permutation are skipped that epoch.
    steps_per_epoch = int(n_obs/self.config_dict['b_s'])
    if self.config_dict['use_lib_enc']:
        progBar = Progress_Bar(epochs, steps_per_epoch, ['nELBO', 'nlog_likeli', 'KL-Div z', 'KL-Div l', 'Align-Term'])
    else:
        progBar = Progress_Bar(epochs, steps_per_epoch, ['nELBO', 'nlog_likeli', 'KL-Div z', 'Align-Term'])
    print(f'Training on the target dataset for {epochs} epochs (= {epochs*steps_per_epoch} iterations).')
    x = torch.from_numpy(self.mdata.mod[self.target_config['target_key']].X.toarray())
    self.target_encoder_outer.train()
    # NOTE(review): this eval() is overridden by target_encoder_inner.train() a few lines below.
    self.target_encoder_inner.eval()
    if self.config_dict['use_lib_enc']:
        self.target_lib_encoder.train()
    self.target_decoder.train()
    self.target_encoder_inner.train()
    for epoch in range(epochs):
        perm = self.rng.permutation(n_obs)
        for step in range(steps_per_epoch):
            self.target_optimizer.zero_grad(set_to_none=True)
            # AnnData view of the same permuted rows; used for obs/obsm lookups.
            batch_adata = self.mdata.mod[self.target_config['target_key']][perm[step*self.config_dict['b_s']:(step+1)*self.config_dict['b_s']]]
            # assumes get_batch's default batch size equals config_dict['b_s'] so rows match batch_adata — confirm.
            x_batch = self.get_batch(x, step, perm=perm)
            s_batch = torch.from_numpy(batch_adata.obsm['batch_label_enc']).to(self.config_dict['device'])
            if self.config_dict['use_lib_enc']:
                lib_mu_batch = torch.from_numpy(batch_adata.obs['library_log_mean'].to_numpy()).to(self.config_dict['device'])
                lib_sig_batch = torch.from_numpy(batch_adata.obs['library_log_std'].to_numpy()).to(self.config_dict['device'])
            inter = self.target_encoder_outer(x_batch, s_batch)
            z_batch, z_kl_div = self.target_encoder_inner(inter)
            if self.config_dict['use_lib_enc']:
                l_batch, l_kl_div = self.target_lib_encoder(x_batch, s_batch, lib_mu_batch, lib_sig_batch)
            else:
                l_batch = x_batch.sum(-1).unsqueeze(-1)
            nlog_likeli = self.target_decoder(z_batch, s_batch, l_batch, x_batch)
            # Cells eligible for alignment in this batch.
            ind_top = np.where(batch_adata.obs['top_percent_'+context_cell_key].to_numpy()<top_percent/100)[0]
            # Fallback: if no cell qualifies, align one randomly chosen batch position.
            if np.shape(ind_top)[0] < 1: ind_top = np.reshape(np.random.randint(self.config_dict['b_s']), (1,))
            # Row indices into the context modality, shape (len(ind_top), k_neigh).
            ind_neigh = batch_adata.obsm['ind_neigh_nns'][ind_top, :k_neigh]
            neigh_mu = torch.from_numpy(self.mdata.mod[self.context_config['context_key']].obsm['z_mu'][ind_neigh]).to(self.config_dict['device'])
            neigh_sig = torch.from_numpy(self.mdata.mod[self.context_config['context_key']].obsm['z_sig'][ind_neigh]).to(self.config_dict['device'])
            # Reparameterised sample from each neighbour's stored context posterior.
            neigh_z = neigh_mu + neigh_sig * self.target_encoder_inner.sampling_dist.sample(torch.Size([neigh_sig.size(dim=0), neigh_sig.size(dim=1)]))
            # Repeat each eligible cell k_neigh times so it lines up with its flattened neighbours.
            s_interl = torch.repeat_interleave(s_batch[ind_top], repeats=k_neigh, dim=0)
            l_interl = torch.repeat_interleave(l_batch[ind_top], repeats=k_neigh, dim=0)
            x_interl = torch.repeat_interleave(x_batch[ind_top], repeats=k_neigh, dim=0)
            outp = self.target_decoder.decode(neigh_z.view(-1, neigh_z.size(-1)), s_interl)
            nlog_likeli_neighbors = self.target_decoder.calc_nlog_likelihood(outp, l_interl, x_interl).reshape(np.shape(ind_top)[0], k_neigh)
            # Per eligible cell: the neighbour whose decoded sample best explains the cell's counts.
            best_pin_for_x = torch.argmin(nlog_likeli_neighbors, dim=1).cpu().numpy()
            # NOTE(review): any other 'alignment' value leaves sqerror_align unbound (NameError below).
            if self.config_dict['alignment'] == 'inter':
                align_target = torch.from_numpy(self.mdata.mod[self.context_config['context_key']].obsm['inter'][batch_adata.obsm['ind_neigh_nns'][ind_top, best_pin_for_x]]).to(self.config_dict['device'])
                sqerror_align = torch.sum((inter[ind_top] - align_target)**2, dim=-1).mean()
            elif self.config_dict['alignment'] == 'latent':
                sqerror_align = torch.sum((z_batch[ind_top] - neigh_z[np.arange(len(ind_top)), best_pin_for_x])**2, dim=-1).mean()
            nelbo = self.target_config['beta'] * z_kl_div + nlog_likeli + self.target_config['eta'] * sqerror_align
            if self.config_dict['use_lib_enc']:
                nelbo = nelbo + self.target_config['beta'] * l_kl_div
            nelbo.backward()
            self.target_optimizer.step()
            # Despite the name, this attribute is used as a list (append per step).
            self.target_likeli_hist_dict.append(nlog_likeli.item())
            if self.config_dict['use_lib_enc']:
                progBar.update({'nELBO': nelbo.item(), 'nlog_likeli': nlog_likeli.item(), 'KL-Div z': (self.target_config['beta'] * z_kl_div).item(), 'KL-Div l': (self.target_config['beta'] * l_kl_div).item(), 'Align-Term': (self.target_config['eta'] * sqerror_align).item()})
            else:
                progBar.update({'nELBO': nelbo.item(), 'nlog_likeli': nlog_likeli.item(), 'KL-Div z': (self.target_config['beta'] * z_kl_div).item(), 'Align-Term': (self.target_config['eta'] * sqerror_align).item()})
        # Per-epoch schedules for the KL weight and the alignment weight.
        if raise_beta:
            self.target_config['beta'] = self.update_param(self.target_config['beta'], self.target_config['beta_start'], self.target_config['beta_max'], self.target_config['beta_epochs_raise'])
        if raise_eta:
            self.target_config['eta'] = self.update_param(self.target_config['eta'], self.target_config['eta_start'], self.target_config['eta_max'], self.target_config['eta_epochs_raise'])
    self.target_encoder_outer.eval()
    if self.config_dict['use_lib_enc']:
        self.target_lib_encoder.eval()
    self.target_decoder.eval()
    self.target_encoder_inner.eval()
    if save_model == True:
        self.save('target',save_key=save_key)
[docs]
def encode(
    self,
    x: torch.Tensor,
    s: torch.Tensor,
    encoder_outer: Optional[nn.Module] = None,
    encoder_inner: Optional[nn.Module] = None,
    lib_encoder: Optional[nn.Module] = None
) -> Union[
    Tuple[np.ndarray, np.ndarray],
    Tuple[np.ndarray, np.ndarray, np.ndarray],
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Encode data into biological and/or library latent variables.

    Runs in batches of ``config_dict['b_s']`` under ``torch.no_grad()`` with
    every provided encoder switched to eval mode.

    Parameters
    ----------
    x : Tensor, shape (n_cells, n_genes)
        Count matrix.
    s : Tensor, shape (n_cells, n_batches)
        One-hot encoded batch labels.
    encoder_outer : nn.Module, optional
        Outer encoder. Must be given together with `encoder_inner`.
    encoder_inner : nn.Module, optional
        Inner encoder. Must be given together with `encoder_outer`.
    lib_encoder : nn.Module, optional
        Library encoder. If ``config_dict['use_lib_enc']`` is False, it is
        still put in eval mode, but the library latent is taken directly from
        the data: ``l_mu = log(row sum of x)`` and ``l_sig = 1``.

    Returns
    -------
    (z_mu, z_sig, inter)
        If only the encoder pair is provided.
    (l_mu, l_sig)
        If only `lib_encoder` is provided.
    (z_mu, z_sig, inter, l_mu, l_sig)
        If all three are provided.
    ``*_sig`` values are standard deviations, i.e. the exponentiated log-std
    returned by the encoders.

    Raises
    ------
    ValueError
        If only one of `encoder_outer`/`encoder_inner` is given, or if nothing
        to encode is given.
    """
    if (encoder_outer is None) != (encoder_inner is None):
        raise ValueError("`encoder_outer` and `encoder_inner` must be provided together.")
    use_z = encoder_outer is not None
    use_lib = lib_encoder is not None
    if not use_z and not use_lib:
        raise ValueError("Provide `encoder_outer`/`encoder_inner` and/or `lib_encoder`.")

    b_s = self.config_dict['b_s']
    # Integer ceiling division. Keep at least one step so an empty input still yields (empty) arrays.
    steps = max(1, -(-x.size(0) // b_s))

    if use_z:
        encoder_outer.eval()
        encoder_inner.eval()
        z_mu_list, z_sig_list, inter_list = [], [], []
    if use_lib:
        lib_encoder.eval()
        l_mu_list, l_sig_list = [], []
    with torch.no_grad():
        tic = time.time()
        for step in range(steps):
            # Throttle console output to at most one update every 0.5 s.
            if time.time() - tic > 0.5:
                tic = time.time()
                print('\rCalculate latent variables. Step {}/{} '.format(str(step), str(steps)), end='', flush=True)
            x_batch = self.get_batch(x, step)
            s_batch = self.get_batch(s, step)
            if use_z:
                inter = encoder_outer(x_batch, s_batch)
                z_mu, z_log_sig = encoder_inner.encode(inter)
                z_mu_list.append(z_mu.cpu().numpy())
                z_sig_list.append(z_log_sig.exp().cpu().numpy())
                inter_list.append(inter.cpu().numpy())
            if use_lib:
                if self.config_dict['use_lib_enc']:
                    l_mu, l_log_sig = lib_encoder.encode(x_batch, s_batch)
                else:
                    # No learned library encoder: use the observed log library size with log-std 0.
                    lib_size = x_batch.sum(-1)
                    l_mu, l_log_sig = lib_size.log(), torch.zeros_like(lib_size)
                l_mu_list.append(l_mu.cpu().numpy())
                l_sig_list.append(l_log_sig.exp().cpu().numpy())

    if use_z and use_lib:
        return (np.concatenate(z_mu_list), np.concatenate(z_sig_list), np.concatenate(inter_list),
                np.concatenate(l_mu_list), np.concatenate(l_sig_list))
    if use_z:
        return np.concatenate(z_mu_list), np.concatenate(z_sig_list), np.concatenate(inter_list)
    return np.concatenate(l_mu_list), np.concatenate(l_sig_list)
[docs]
def get_representation(
    self,
    eval_model: str,
    save_intermediate: bool = False,
    save_libsize: bool = False
):
    """
    Compute and store biological and (optionally) library latent representations.

    Encodes the whole modality with the chosen model's encoders via `encode`
    and writes the results into that modality's ``.obsm``.

    Parameters
    ----------
    eval_model : {'context', 'target'}
        Which dataset/model pair to encode.
    save_intermediate : bool, default=False
        If True, store the outer encoder output in ``.obsm['inter']``.
    save_libsize : bool, default=False
        If True, also run the library encoder and store its outputs in
        ``.obsm['l_mu']`` / ``.obsm['l_sig']`` (mean and scale exactly as
        returned by `encode`).

    Raises
    ------
    ValueError
        If `eval_model` is neither 'context' nor 'target'.
    """
    if eval_model == 'target':
        dataset_key = self.target_config['target_key']
        encoder_outer = self.target_encoder_outer
        encoder_inner = self.target_encoder_inner
        lib_encoder = self.target_lib_encoder
    elif eval_model == 'context':
        dataset_key = self.context_config['context_key']
        encoder_outer = self.context_encoder_outer
        encoder_inner = self.context_encoder_inner
        lib_encoder = self.context_lib_encoder
    else:
        raise ValueError(f"`eval_model` must be 'context' or 'target', got {eval_model!r}.")

    adata = self.mdata.mod[dataset_key]
    x = torch.from_numpy(adata.X.toarray())
    s = torch.from_numpy(adata.obsm['batch_label_enc'])
    # The encoding does not depend on config_dict['alignment'], so both alignment modes share one call.
    if save_libsize:
        z_mu, z_sig, inter, l_mu, l_sig = self.encode(
            x, s, encoder_outer=encoder_outer, encoder_inner=encoder_inner, lib_encoder=lib_encoder
        )
    else:
        z_mu, z_sig, inter = self.encode(x, s, encoder_outer=encoder_outer, encoder_inner=encoder_inner)

    adata.obsm['z_mu'] = z_mu
    adata.obsm['z_sig'] = z_sig
    if save_intermediate:
        adata.obsm['inter'] = inter
    if save_libsize:
        adata.obsm['l_mu'] = l_mu
        adata.obsm['l_sig'] = l_sig
[docs]
def color_str(value, mode):
    """
    Wrap ``str(value)`` in an ANSI 256-colour escape sequence for console output.

    Parameters
    ----------
    value : Any
        Object to display; converted with ``str``.
    mode : str
        ``"context"`` renders in colour 208 and ``"target"`` in colour 135.
        Any other mode returns the plain text without colour codes.

    Returns
    -------
    str
        The (possibly coloured) text.
    """
    text = str(value)
    if mode == "context":
        return f"\033[38;5;208m{text}\033[0m"
    if mode == "target":
        return f"\033[38;5;135m{text}\033[0m"
    # Unknown mode: fall back to uncoloured text instead of returning None.
    return text
[docs]
class Progress_Bar():
    """
    A console progress bar that tracks multiple metrics over training iterations of scSpecies.

    On every `update` the bar:

    - stores the metric values;
    - refreshes ETA, speed and progress;
    - at most every `sleep_print` seconds, prints one line with each metric
      averaged over the last `avg_over_n_steps` steps.

    Next to each metric it shows the change between the two most recent epoch
    means (green ``+`` = decreased, red ``-`` = increased).

    Parameters
    ----------
    epochs : int
        Total number of epochs.
    steps_per_epoch : int
        Number of steps (batches) in each epoch.
    metrics : list of str
        Names of metrics to track (e.g., ['nELBO', 'nlog_likeli']).
    avg_over_n_steps : int, optional
        Number of recent steps over which to average metric values and step
        time for display.
    sleep_print : float, optional
        Minimum interval in seconds between console updates.
    """

    def __init__(
        self,
        epochs: int,
        steps_per_epoch: int,
        metrics: List[str],
        avg_over_n_steps: int = 100,
        sleep_print: float = 0.5
    ):
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.total_steps = self.epochs * steps_per_epoch
        self.remaining_steps = self.epochs * steps_per_epoch
        self.avg_over_n_steps = avg_over_n_steps
        self.tic = time.time()
        self.sleep_print = sleep_print
        self.iteration = 0
        self.metrics = metrics
        # Keys shown (in this order) at the start of every printed line.
        self.time_metrics = ['Progress', 'ETA', 'Epoch', 'Iteration', 'ms/Iteration']
        self.dict = {
            'Progress': "0.000%",
            'ETA': 0.0,
            'Epoch': int(1),
            'Iteration': int(0),
            'ms/Iteration': 0.0,
            'time': [time.time()]
        }
        # Per metric: the full per-step history, the per-epoch means (seeded with the
        # first observed value), and the change between the last two epoch means.
        self.dict.update({metric: [] for metric in metrics})
        self.dict.update({metric + ' last ep': [] for metric in metrics})
        self.dict.update({metric + ' impr': 0.0 for metric in metrics})

    @staticmethod
    def format_number(number, min_length):
        """
        Format `number` as a fixed-point string with `min_length` decimal places.

        NOTE(review): `ret_sign` and `update` call this helper, but the class did
        not define it. This definition matches the initial ``'Progress'`` string
        ``"0.000%"`` (``format_number(0.0, 3)``) — confirm the intended format.
        """
        return f"{number:.{min_length}f}"

    def ret_sign(self, number, min_length):
        """Return `number` as a coloured signed string (green ``+``, red ``-``) or ``'---'`` for zero."""
        if number > 0.0:
            return '\033[92m{}\033[00m'.format("+" + self.format_number(np.abs(number), min_length))
        if number < 0.0:
            return '\033[91m{}\033[00m'.format("-" + self.format_number(np.abs(number), min_length))
        return '---'

    def update(self, values):
        """
        Record one step's metric values, update timing/epoch statistics and maybe print.

        Parameters
        ----------
        values : dict
            Mapping from metric name (as passed to the constructor) to a scalar value.
        """
        self.remaining_steps -= 1
        for key, value in values.items():
            self.dict[key].append(value)
        if self.dict['Iteration'] == 0:
            # First update: seed the epoch history so the first epoch mean has a baseline.
            for key, value in values.items():
                self.dict[key + ' last ep'].append(value)
        self.dict['Iteration'] += 1
        epoch = int(np.ceil(self.dict['Iteration'] / self.steps_per_epoch))
        if self.dict['Epoch'] < epoch:
            for key in self.metrics:
                # The value just appended is the first step of the new epoch; average the
                # steps_per_epoch values that precede it (the epoch that just finished).
                self.dict[key + ' last ep'].append(np.mean(self.dict[key][-self.steps_per_epoch - 1:-1]))
                self.dict[key + ' impr'] = self.dict[key + ' last ep'][-2] - self.dict[key + ' last ep'][-1]
            self.dict['Epoch'] = epoch
        self.dict['time'].append(time.time())
        # Guard against avg_over_n_steps <= 0, which would divide by zero.
        avg_steps = max(1, min(self.dict['Iteration'], self.avg_over_n_steps))
        avg_time = (self.dict['time'][-1] - self.dict['time'][-avg_steps - 1]) / avg_steps
        self.dict['ETA'] = timedelta(seconds=int(self.remaining_steps * avg_time))
        self.dict['ms/Iteration'] = self.format_number(avg_time * 1000.0, 4)
        self.dict['Progress'] = self.format_number(100.0 * self.dict['Iteration'] / self.total_steps, 3) + '%'
        if time.time() - self.tic > self.sleep_print:
            metric_string = [f'\033[95m{key}\033[00m: {self.dict[key]}' for key in self.time_metrics]
            metric_string += [f'\033[33m{key}\033[00m: {self.format_number(np.mean(self.dict[key][-avg_steps:]), 5)} ({self.ret_sign(self.dict[key+" impr"], 4)})' for key in self.metrics]
            metric_string = "\033[96m - \033[00m".join(metric_string)
            print(f"\r{metric_string}. ", end='', flush=True)
            self.tic = time.time()
[docs]
def neighbors_workaround(
    adata: ad.AnnData,
    use_rep: Optional[str] = None,
    n_neighbors: int = 15,
    metric: str = 'euclidean'
) -> ad.AnnData:
    """
    Build a k-nearest-neighbour graph with scikit-learn and store it on `adata`.

    Intended as a drop-in for ``sc.pp.neighbors`` on Apple M1/M2 machines, where
    that call can crash the kernel. `kneighbors` is queried with the fitted data
    itself, so each cell is usually returned as its own first neighbour at
    distance 0. The connectivity matrix is binary (1 for every k-NN edge), not
    UMAP-weighted.

    Parameters
    ----------
    adata : ad.AnnData
        Annotated data object; modified in place.
    use_rep : str, optional
        Key in ``adata.obsm`` to search in (e.g. ``'X_pca'``). ``None`` uses ``adata.X``.
    n_neighbors : int, default=15
        Number of neighbours per cell (including the cell itself, see above).
    metric : str, default='euclidean'
        Distance metric passed to ``NearestNeighbors``.

    Returns
    -------
    AnnData
        The same `adata`, with ``obsp['distances']``, ``obsp['connectivities']``
        (both CSR, shape ``(n_obs, n_obs)``) and ``uns['neighbors']`` set.
    """
    rep_key = 'X' if use_rep is None else use_rep
    features = adata.X if use_rep is None else adata.obsm[use_rep]
    # Sparse inputs are densified before the neighbour search.
    if hasattr(features, "toarray"):
        features = features.toarray()

    n_cells = features.shape[0]
    knn = NearestNeighbors(n_neighbors=n_neighbors, metric=metric).fit(features)
    knn_dist, knn_idx = knn.kneighbors(features)

    # COO triplets: row i repeated n_neighbors times, paired with its neighbour indices.
    row_idx = np.repeat(np.arange(n_cells), n_neighbors)
    col_idx = knn_idx.ravel()
    dist_values = knn_dist.ravel()
    shape = (n_cells, n_cells)

    adata.obsp['distances'] = csr_matrix((dist_values, (row_idx, col_idx)), shape=shape)
    adata.obsp['connectivities'] = csr_matrix((np.ones_like(dist_values), (row_idx, col_idx)), shape=shape)
    adata.uns['neighbors'] = {
        'params': {
            'n_neighbors': n_neighbors,
            'method': 'umap',
            'use_rep': rep_key,
            'metric': metric,
        },
        'distances_key': 'distances',
        'connectivities_key': 'connectivities',
    }
    return adata