import contextlib
import io
import os
import random
import warnings
from collections import Counter
from pathlib import Path
from typing import Union, Optional, List

import anndata as ad
import muon as mu
import numpy as np
import pandas as pd
import requests
import scanpy as sc
import torch
from mygene import MyGeneInfo
from scipy.sparse import csr_matrix
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import OneHotEncoder
[docs]
def map_homologs(
    gene_list: list[str],
    target_NCBI_Taxon_ID: int,
    context_NCBI_Taxon_ID: int,
    *,
    client: Optional[object] = None,
) -> list[str]:
    """
    Map gene symbols of the target species to homologous gene symbols of the
    context species using MyGeneInfo (HomoloGene annotations).

    Parameters
    ----------
    gene_list : list[str]
        Gene symbols in the target species to be translated.
    target_NCBI_Taxon_ID : int
        NCBI Taxonomy ID of the target species (the species of `gene_list`).
    context_NCBI_Taxon_ID : int
        NCBI Taxonomy ID of the context species (the species to translate to).
    client : object, optional
        Object exposing a MyGeneInfo-compatible `querymany` method. A fresh
        `MyGeneInfo()` is created when omitted. Passing a client allows
        re-using one connection across calls.

    Returns
    -------
    list[str]
        One entry per input gene, in input order: the homologous gene symbol
        in the context species, or 'non_hom_<i>' (with `i` the position in
        `gene_list`) when no homolog could be resolved.
    """
    gene_list = list(gene_list)
    if not gene_list:
        # Nothing to translate; avoid a pointless remote query.
        return []

    mg = MyGeneInfo() if client is None else client

    results = mg.querymany(
        gene_list,
        scopes='symbol',
        species=target_NCBI_Taxon_ID,
        fields='homologene',
        as_dataframe=False
    )

    # A symbol query can yield several hits. Accumulate the homolog IDs of all
    # of them so that a later hit without HomoloGene data does not erase the
    # homologs found by an earlier hit.
    homolog_ids = {}
    for res in results:
        hits = homolog_ids.setdefault(res['query'], [])
        if res.get('notfound') or 'homologene' not in res:
            continue
        for gene in res['homologene']['genes']:
            # Each entry is a (taxon ID, Entrez gene ID) pair.
            if gene[0] == context_NCBI_Taxon_ID and gene[1] not in hits:
                hits.append(gene[1])

    ids_to_lookup = {eid for hits in homolog_ids.values() for eid in hits}

    id_to_symbol = {}
    if ids_to_lookup:
        lookup = mg.querymany(
            sorted(ids_to_lookup),
            scopes='entrezgene',
            fields='symbol',
            species=context_NCBI_Taxon_ID,
            as_dataframe=False
        )
        for hit in lookup:
            if 'symbol' not in hit:
                continue
            try:
                eid = int(hit['query'])
            except (KeyError, ValueError):
                # Malformed hit; skip it instead of failing the whole mapping.
                continue
            id_to_symbol[eid] = hit['symbol']

    mapped = []
    for i, gene in enumerate(gene_list):
        # First homolog (in HomoloGene order) whose symbol could be resolved.
        symbols = [id_to_symbol[eid] for eid in homolog_ids.get(gene, []) if eid in id_to_symbol]
        mapped.append(symbols[0] if symbols else 'non_hom_' + str(i))
    return mapped
[docs]
def map_homologs_silent(
    gene_list: list[str],
    target_NCBI_Taxon_ID: int,
    context_NCBI_Taxon_ID: int
) -> list[str]:
    """
    Run `map_homologs` while discarding everything written to stdout and stderr.

    Parameters
    ----------
    gene_list : list[str]
        Gene symbols in the target species to be translated.
    target_NCBI_Taxon_ID : int
        NCBI Taxonomy ID of the target species.
    context_NCBI_Taxon_ID : int
        NCBI Taxonomy ID of the context species.

    Returns
    -------
    list[str]
        The return value of `map_homologs` for the same arguments.
    """
    # Both streams are redirected into the same throw-away buffer.
    sink = io.StringIO()
    with contextlib.redirect_stdout(sink):
        with contextlib.redirect_stderr(sink):
            return map_homologs(gene_list, target_NCBI_Taxon_ID, context_NCBI_Taxon_ID)
[docs]
def get_key(
    gene,
    homology_targsp_df,
    homology_context_df,
    i
) -> str:
    """
    Look up the context-species symbol that shares a homology class with a
    target-species gene, using homology tables in the format of
    informatics.jax.org/downloads/reports/HOM_AllOrganism.rpt.

    Parameters
    ----------
    gene : str
        Gene symbol to look up in `homology_targsp_df`.
    homology_targsp_df : pandas.DataFrame
        Homology table rows of the target species; must provide the columns
        'Symbol' and 'DB Class Key'.
    homology_context_df : pandas.DataFrame
        Homology table rows of the context species (same columns).
    i : int
        Index of the gene in the original list, used to name unmapped genes.

    Returns
    -------
    str
        Symbol of the first context row whose 'DB Class Key' equals the key of
        the first target row matching `gene`; 'non_hom_<i>' if either lookup
        finds nothing.
    """
    fallback = 'non_hom_' + str(i)

    target_rows = homology_targsp_df[homology_targsp_df['Symbol'] == gene]
    if target_rows.empty:
        return fallback

    # Only the first matching row of each table is considered.
    class_key = target_rows['DB Class Key'].values[0]
    context_rows = homology_context_df[homology_context_df['DB Class Key'] == class_key]
    if context_rows.empty:
        return fallback

    return context_rows['Symbol'].values[0]
[docs]
def download_datasets():
    """
    Download the liver cell .h5ad datasets into the ./data directory.

    Files that are already present are skipped. Each download is streamed to
    a temporary '<name>.part' file and renamed only after it completed, so an
    interrupted download never leaves a truncated file that a later call
    would mistake for a finished one.

    Raises
    ------
    requests.HTTPError
        If any of the dataset URLs returns a bad status.
    requests.RequestException
        On connection problems or when the server stops responding.
    """
    data_urls = {
        "human_liver.h5ad": "https://zenodo.org/records/15522251/files/human_liver.h5ad?download=1",
        "mouse_liver.h5ad": "https://zenodo.org/records/15522251/files/mouse_liver.h5ad?download=1",
        "hamster_liver.h5ad": "https://zenodo.org/records/15522251/files/hamster_liver.h5ad?download=1",
    }
    data_path = Path("data")
    data_path.mkdir(parents=True, exist_ok=True)

    for fname, url in data_urls.items():
        out_file = data_path / fname
        if out_file.exists():
            print(f"{fname} already exists. Skipping download.")
            continue

        print(f"Downloading {fname} …")
        tmp_file = out_file.with_name(out_file.name + ".part")
        try:
            # Stream in chunks instead of buffering the whole dataset in RAM;
            # the timeout bounds connection setup and every individual read.
            with requests.get(url, stream=True, timeout=60) as resp:
                resp.raise_for_status()
                with open(tmp_file, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
            tmp_file.replace(out_file)
        finally:
            # No-op after a successful rename; removes the partial file otherwise.
            tmp_file.unlink(missing_ok=True)

        size_mb = out_file.stat().st_size / 1024 / 1024
        print(f"{fname} downloaded. Size: {size_mb:.2f} MB")
    print("All datasets have been downloaded to the ./data directory.")
[docs]
def set_random_seed(
    seed: int
):
    """
    Seed every random number generator used by this module.

    Parameters
    ----------
    seed : int
        Value written to the PYTHONHASHSEED environment variable and used to
        seed Python's `random`, NumPy and PyTorch (CPU and, if available, CUDA).
    """
    # CPU-side generators.
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    # CUDA generators plus deterministic cuDNN settings.
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
[docs]
class create_mdata():
    """
    Builder for the MuData container that scSpecies uses to align a context
    dataset with one or more target datasets.

    The constructor preprocesses the context AnnData. Every call to
    `setup_target_adata` preprocesses one target AnnData, translates its gene
    symbols to the context species and runs a data-level nearest-neighbor
    search against the context. `return_mdata` assembles (and optionally
    saves) the final MuData object.

    Attributes
    ----------
    dataset_collection : dict
        Maps dataset names to their processed AnnData objects.
    context_dataset_name, context_cell_key, context_NCBI_Taxon_ID
        Name, cell-type `obs` key and NCBI taxonomy ID of the context dataset.
    min_non_zero_genes, min_cell_type_size, min_batch_size
        Filtering thresholds applied to the context and to every target dataset.
    """

    def __init__(
        self,
        adata: ad.AnnData,
        batch_key: str,
        cell_key: str,
        dataset_name: str = 'mouse',
        NCBI_Taxon_ID: int = 10090,
        n_top_genes: Union[int, None] = None,
        min_non_zero_genes: float = 0.025,
        min_cell_type_size: int = 20,
        min_batch_size: int = 20,
    ):
        """
        Initialize the builder and preprocess the context dataset.

        Steps:
        1. One-hot encode experimental batches (dropping small batches).
        2. Calculate library size prior parameters for scVI.
        3. Optionally subset to the top highly variable genes.
        4. Filter out cells with few detected genes and rare cell types.

        Parameters
        ----------
        adata : ad.AnnData
            AnnData used as the context in scSpecies. It is copied (backed
            objects are loaded into memory); the input is not modified.
        batch_key : str
            Observation key for experimental batch labels.
        cell_key : str
            Observation key for cell-type annotation.
        dataset_name : str, optional
            Tag for the context dataset (default 'mouse').
        NCBI_Taxon_ID : int, optional
            Taxonomy ID of the context species (default mouse - 10090).
        n_top_genes : int or None, optional
            Number of HVGs to retain (None to skip) (default None).
        min_non_zero_genes : float, optional
            Min fraction of nonzero genes per cell (default 0.025).
        min_cell_type_size : int, optional
            Min cells per cell type; cell types with fewer cells are removed (default 20).
        min_batch_size : int, optional
            Min cells per batch; batches with fewer cells are removed (default 20).

        Effects
        -------
        - Sets the muon option `pull_on_update=False`.
        - Ensures a `data/` directory exists.
        - Annotates `adata.uns['metadata']` with context dataset info.
        - Appends `_<dataset_name>` to every observation name.
        - Stores the processed AnnData (with CSR `X`) in `self.dataset_collection`.
        """
        # `.copy()` raises for backed AnnData unless a filename is given, so
        # backed objects are loaded into memory instead.
        adata = adata.to_memory() if adata.isbacked else adata.copy()
        mu.set_options(pull_on_update=False)

        self.min_non_zero_genes = min_non_zero_genes
        self.min_cell_type_size = min_cell_type_size
        self.min_batch_size = min_batch_size

        Path("data").mkdir(parents=True, exist_ok=True)

        self.context_dataset_name = dataset_name
        self.context_cell_key = cell_key
        self.context_NCBI_Taxon_ID = NCBI_Taxon_ID

        adata.uns['metadata'] = {
            'name': dataset_name,
            'batch_key': batch_key,
            'cell_key': cell_key,
            'NCBI_Taxon_ID': NCBI_Taxon_ID,
            'function': 'context',
        }
        adata.obs['dataset'] = dataset_name
        adata.obs.index = adata.obs.index.astype(str) + f"_{dataset_name}"

        adata = self.encode_batch_labels(adata, self.min_batch_size)
        adata = self.compute_lib_prior_params(adata)
        if n_top_genes is not None:
            adata = self.subset_to_hvg(adata, n_top_genes)
        adata = self.filter_cells(adata, self.min_non_zero_genes, self.min_cell_type_size)

        adata.obs_names_make_unique()
        adata.X = csr_matrix(adata.X)
        self.dataset_collection = {dataset_name: adata}
        print('Done!\n'+'-'*90)

    def setup_target_adata(
        self,
        adata: ad.AnnData,
        batch_key: str,
        cell_key: Union[str, None] = None,
        eval_nns_keys: Union[List[str], None] = None,
        dataset_name: str = 'human',
        NCBI_Taxon_ID: int = 9606,
        n_top_genes: Union[int, None] = None,
        compute_log1p: bool = True,
        nn_kwargs: Optional[dict] = None,
    ):
        """
        Preprocess a target AnnData and align it against the context.

        Steps:
        1. One-hot encode experimental batches (dropping small batches).
        2. Calculate library size prior parameters for scVI.
        3. Optionally subset to the top HVGs; filter low-coverage cells and rare cell types.
        4. Translate target gene symbols to context homologs.
        5. Run a data-level nearest-neighbor search on the shared homologous
           gene set and transfer context labels.

        Parameters
        ----------
        adata : ad.AnnData
            Target dataset. It is copied; the input is not modified.
        batch_key : str
            Observation key for experimental batch labels.
        cell_key : str or None
            Observation key for cell types (None if unknown).
        eval_nns_keys : list of str or None
            Context `obs` keys whose labels are transferred. Defaults to the
            context cell-type key.
        dataset_name : str, optional
            Tag for the target dataset (default 'human').
        NCBI_Taxon_ID : int, optional
            Taxonomy ID of the target species (default human - 9606).
        n_top_genes : int or None, optional
            Number of HVGs to keep (None to skip) (default None).
        compute_log1p : bool, optional
            Use log1p counts for the neighbor search if True (default True).
        nn_kwargs : dict, optional
            Arguments for `sklearn.neighbors.NearestNeighbors`; missing keys
            default to `{'n_neighbors': 250, 'metric': 'cosine'}`. The dict
            passed in is not modified.

        Raises
        ------
        ValueError
            If context and target share no homologous gene.

        Warns
        -----
        UserWarning
            If fewer than 250 homologous genes are shared, or if `n_neighbors`
            had to be reduced to the number of context cells.

        Effects
        -------
        - Stores neighbor indices (rows of the context dataset) in
          `adata.obsm['ind_neigh_nns']` with shape (n_obs, n_neighbors).
        - Adds the label transfer columns of `pred_labels_nns_hom_genes`.
        - Inserts the processed AnnData (with CSR `X`) into `self.dataset_collection`.
        """
        adata = adata.to_memory() if adata.isbacked else adata.copy()

        adata.uns['metadata'] = {
            'name': dataset_name,
            'batch_key': batch_key,
            'cell_key': 'unknown' if cell_key is None else cell_key,
            'NCBI_Taxon_ID': NCBI_Taxon_ID,
            'function': 'target',
        }
        adata.obs['dataset'] = dataset_name
        adata.obs.index = adata.obs.index.astype(str) + f"_{dataset_name}"

        adata = self.encode_batch_labels(adata, self.min_batch_size)
        adata = self.compute_lib_prior_params(adata)
        if n_top_genes is not None:
            adata = self.subset_to_hvg(adata, n_top_genes)
        adata = self.filter_cells(adata, self.min_non_zero_genes, self.min_cell_type_size)
        adata = self.translate_gene_list(adata)

        context_adata = self.dataset_collection[self.context_dataset_name]
        _, context_ind, target_ind = np.intersect1d(
            context_adata.var_names.to_numpy(),
            np.asarray(adata.var['var_names_transl']),
            return_indices=True,
        )

        n_shared = len(context_ind)
        if n_shared == 0:
            raise ValueError("No homologous genes found. scSpecies cannot be used.")
        if n_shared < 250:
            # A low gene overlap is a quality concern, not a fatal error.
            warnings.warn("Only \033[35m{}\033[0m homologous genes found. Data-level neighbor search may yield noisy results.".format(str(n_shared)))
        else:
            print("Found \033[35m{}\033[0m shared homologous genes between context and target dataset".format(str(n_shared)))

        # Merge into a fresh dict so the caller's `nn_kwargs` stays untouched.
        nn_params = {"n_neighbors": 250, "metric": "cosine"}
        if nn_kwargs:
            nn_params.update(nn_kwargs)
        if nn_params["n_neighbors"] > context_adata.n_obs:
            warnings.warn("n_neighbors={} exceeds the number of context cells; using {} instead.".format(
                nn_params["n_neighbors"], context_adata.n_obs))
            nn_params["n_neighbors"] = context_adata.n_obs

        print('Perform the data-level nearest neigbor search on the homologous gene set.')
        # Select the shared genes before densifying; `csr_matrix(...)` also
        # accepts a dense target matrix.
        context_neigh = csr_matrix(context_adata.X)[:, context_ind].toarray()
        target_neigh = csr_matrix(adata.X)[:, target_ind].toarray()
        if compute_log1p:
            context_neigh = np.log1p(context_neigh)
            target_neigh = np.log1p(target_neigh)

        neigh = NearestNeighbors(**nn_params)
        neigh.fit(context_neigh)
        _, indices_whole = neigh.kneighbors(target_neigh)
        # Keep the 2-D (n_obs, n_neighbors) shape even for one cell/neighbor.
        adata.obsm['ind_neigh_nns'] = indices_whole.astype(np.int32)

        if eval_nns_keys is None:
            eval_nns_keys = [self.context_cell_key]
        adata = self.pred_labels_nns_hom_genes(adata, eval_nns_keys)

        adata.X = csr_matrix(adata.X)
        self.dataset_collection[dataset_name] = adata
        print('Done!\n'+'-'*90)

    def translate_gene_list(
        self,
        adata: ad.AnnData
    ) -> ad.AnnData:
        """
        Translate the `var_names` of a target AnnData to homologous
        context-species gene symbols.

        If context and target species are identical the names are kept. If
        both species are among human (9606), mouse (10090), rat (10116) and
        zebrafish (7955), the `HOM_AllOrganism.rpt` table is used via
        `get_key` (downloaded to `data/` on first use). Any other species pair
        falls back to `map_homologs_silent`.

        Parameters
        ----------
        adata : anndata.AnnData
            Target AnnData; `adata.uns['metadata']` must contain 'name' and
            'NCBI_Taxon_ID'.

        Returns
        -------
        anndata.AnnData
            The same object with `adata.var['var_names_transl']` set. Genes
            without a homolog are named 'non_hom_<i>'.
        """
        print('Translating homologous gene names between {} context and {} target dataset.'.format(
            self.context_dataset_name, adata.uns['metadata']['name']))

        gene_list = list(adata.var_names)
        target_taxon_id = adata.uns['metadata']['NCBI_Taxon_ID']
        # Species for which the homology report is used.
        report_species = (9606, 10090, 10116, 7955)

        if self.context_NCBI_Taxon_ID == target_taxon_id:
            transl_gene_list = gene_list

        elif self.context_NCBI_Taxon_ID in report_species and target_taxon_id in report_species:
            out_dir = Path("data")
            out_dir.mkdir(parents=True, exist_ok=True)
            out_file = out_dir / "HOM_AllOrganism.rpt"

            if out_file.exists():
                homology_df = pd.read_csv(out_file, sep="\t")
            else:
                print("Downloading gene translation dictionary.")
                homology_df = pd.read_csv("https://www.informatics.jax.org/downloads/reports/HOM_AllOrganism.rpt", sep="\t")
                homology_df.to_csv(out_file, sep="\t", index=False)
                print(f'Gene translation dictionary saved to {out_file}')

            homology_contsp_df = homology_df[homology_df['NCBI Taxon ID'] == self.context_NCBI_Taxon_ID]
            homology_targsp_df = homology_df[homology_df['NCBI Taxon ID'] == target_taxon_id]
            transl_gene_list = [
                get_key(gene, homology_targsp_df, homology_contsp_df, i)
                for i, gene in enumerate(gene_list)
            ]

        else:
            transl_gene_list = map_homologs_silent(gene_list, target_taxon_id, self.context_NCBI_Taxon_ID)

        num_hom_genes = sum('non_hom' not in gene for gene in transl_gene_list)
        print('Could map \033[33m{}\033[0m of {} target gene symbols to context species gene symbols'.format(
            str(num_hom_genes), str(len(transl_gene_list))))

        adata.var['var_names_transl'] = transl_gene_list
        return adata

    @staticmethod
    def filter_cells(
        adata: ad.AnnData,
        min_non_zero_genes: float,
        min_cell_type_size: int
    ) -> ad.AnnData:
        """
        Filter cells by detected-gene fraction and by cell-type size.

        Parameters
        ----------
        adata : anndata.AnnData
            Annotated data matrix; `adata.uns['metadata']['cell_key']` names
            the cell-type column or is 'unknown'. The low-coverage filter is
            applied to this object in place.
        min_non_zero_genes : float
            Minimum fraction of genes that must have nonzero counts in a cell.
        min_cell_type_size : int
            Minimum number of cells required to retain a cell type.

        Returns
        -------
        anndata.AnnData
            The filtered data. Cells with fewer than
            `min_non_zero_genes * n_vars` detected genes are removed and,
            unless the cell key is 'unknown', so are cells of cell types with
            fewer than `min_cell_type_size` cells.
        """
        old_n_obs = adata.n_obs
        cell_key = adata.uns['metadata']['cell_key']

        sc.pp.filter_cells(adata, min_genes=adata.n_vars*min_non_zero_genes)

        if cell_key != 'unknown':
            type_sizes = adata.obs[cell_key].value_counts()
            # `>=`: only cell types with *fewer* cells than the minimum go.
            kept_types = type_sizes[type_sizes >= min_cell_type_size].index
            # Copy so later in-place edits act on a real object, not a view.
            adata = adata[adata.obs[cell_key].isin(kept_types)].copy()

        print('Filtering cells. Kept {}, removed {}.'.format(str(adata.n_obs), str(int(old_n_obs-adata.n_obs))))
        return adata

    @staticmethod
    def compute_lib_prior_params(
        adata: ad.AnnData
    ) -> ad.AnnData:
        """
        Compute per-batch scVI library size prior parameters for each cell.

        Parameters
        ----------
        adata : anndata.AnnData
            Annotated data matrix with counts in `adata.X` and the batch
            column named by `adata.uns['metadata']['batch_key']`.

        Returns
        -------
        anndata.AnnData
            The same object with float32 columns `obs['library_log_mean']` and
            `obs['library_log_std']`: mean and standard deviation of the log
            total counts of the cell's batch. Cells whose total count has no
            finite logarithm (e.g. all-zero cells) are excluded from the
            statistics; a batch without any usable cell keeps mean 0 and std 1.
        """
        print('Compute prior parameters for the library encoder.')
        batch_key = adata.uns['metadata']['batch_key']

        # Sparse matrices return an (n, 1) matrix here; flatten to 1-D.
        total_counts = np.asarray(adata.X.sum(axis=1), dtype=np.float64).ravel()
        with np.errstate(divide='ignore', invalid='ignore'):
            log_sum = np.log(total_counts)
        # This runs before cell filtering, so empty cells may still be present;
        # a single -inf would otherwise poison the whole batch.
        usable = np.isfinite(log_sum)

        library_log_mean = np.zeros(adata.n_obs, dtype=np.float32)
        library_log_std = np.ones(adata.n_obs, dtype=np.float32)

        for batch in np.unique(adata.obs[batch_key]):
            in_batch = (adata.obs[batch_key] == batch).to_numpy()
            batch_log_sum = log_sum[in_batch & usable]
            if batch_log_sum.size == 0:
                continue
            library_log_mean[in_batch] = batch_log_sum.mean()
            library_log_std[in_batch] = batch_log_sum.std()

        adata.obs['library_log_mean'] = library_log_mean
        adata.obs['library_log_std'] = library_log_std
        return adata

    def pred_labels_nns_hom_genes(
        self,
        adata: ad.AnnData,
        context_label_keys: Optional[List[str]] = None,
        k: int = 25,
    ) -> ad.AnnData:
        """
        Predict target labels by majority vote over the data-level nearest
        neighbors in the context dataset and rank cells by vote agreement.

        Parameters
        ----------
        adata : anndata.AnnData
            Target dataset with context row indices in `adata.obsm['ind_neigh_nns']`.
        context_label_keys : list of str, optional
            Keys in the context dataset's `obs` holding the labels to transfer
            (e.g. cell type). Defaults to the context cell-type key.
        k : int
            Number of nearest neighbors considered for the majority vote.

        Returns
        -------
        anndata.AnnData
            The same object; for every label key it gains
            - `obs['pred_nns_<key>']`: the most frequent label among the first
              `k` neighbors (ties go to the label of the closest neighbor),
            - `obs['top_percent_<key>']`: among all cells predicted as the same
              label, the cells are ordered by their vote count and receive
              `1 - rank / n`; the cell with the highest count gets 0 and lower
              counts get values closer to 1.
        """
        if context_label_keys is None:
            context_label_keys = [self.context_cell_key]

        context_adata = self.dataset_collection[self.context_dataset_name]
        ind_neigh_topk = adata.obsm['ind_neigh_nns'][:, :k]

        for context_label_key in context_label_keys:
            print('Evaluating data level NNS and calculating cells with the highest agreement for context labels key {}.'.format(context_label_key))
            candidate_labels = context_adata.obs[context_label_key].to_numpy()

            # Majority vote per target cell. Counter keeps first-seen order,
            # so `max` resolves ties in favour of the nearest neighbor.
            winners = []
            votes = []
            for neighbor_rows in ind_neigh_topk:
                label, count = max(Counter(candidate_labels[neighbor_rows]).items(), key=lambda item: item[1])
                winners.append(label)
                votes.append(count)
            winners = np.array(winners, dtype=object)
            votes = np.array(votes)

            # Rank the cells of each predicted label by ascending vote count
            # (stable, i.e. ties keep the original cell order).
            top_percent = np.zeros(adata.n_obs, dtype=np.float64)
            for label in set(winners.tolist()):
                members = np.flatnonzero(winners == label)
                ranked = members[np.argsort(votes[members], kind='stable')]
                top_percent[ranked] = 1 - np.arange(1, ranked.size + 1) / ranked.size

            adata.obs['top_percent_'+context_label_key] = top_percent
            adata.obs['pred_nns_'+context_label_key] = winners

        return adata

    @staticmethod
    def encode_batch_labels(
        adata: ad.AnnData,
        min_batch_size: Union[int, None] = None
    ) -> ad.AnnData:
        """
        One-hot encode experimental batch labels, excluding small batches.

        Parameters
        ----------
        adata : anndata.AnnData
            Annotated data matrix; `adata.uns['metadata']` must contain
            'batch_key', 'cell_key' and 'name'.
        min_batch_size : int or None
            Smallest batch size to keep; batches with fewer cells are removed.
            None is treated as 0 (keep everything).

        Returns
        -------
        anndata.AnnData
            A copy restricted to the kept batches with
            - `obsm['batch_label_enc']`: float32 one-hot batch matrix,
            - `uns['batch_dict']`: for every cell type the one-hot rows of the
              batches containing more than 3 cells of that type (possibly
              zero rows), plus the key 'unknown' holding the rows of all kept
              batches.

        Raises
        ------
        ValueError
            If no cell is left after removing the small batches.
        """
        if min_batch_size is None:
            min_batch_size = 0

        metadata = adata.uns['metadata']
        batch_key = metadata['batch_key']
        cell_key = metadata['cell_key']
        name = metadata['name']

        batch_counts = adata.obs[batch_key].value_counts()
        # Unused categories show up with a count of 0; they are not batches.
        present = batch_counts[batch_counts > 0]
        to_remove = present[present < min_batch_size].index
        adata = adata[~adata.obs[batch_key].isin(to_remove)].copy()

        if adata.n_obs == 0:
            raise ValueError("No cells left after removing batches with fewer than {} cells.".format(min_batch_size))

        batch_labels = adata.obs[batch_key].to_numpy().reshape(-1, 1)
        kept_batches = np.unique(batch_labels)
        print('Registering experimental batches for the {} dataset. Kept {}, removed {}.'.format(
            name, str(len(kept_batches)), str(len(to_remove))))

        enc = OneHotEncoder()
        enc.fit(batch_labels)

        def one_hot(batches):
            # Encode batch names; an empty selection yields a (0, n_batches)
            # array because the encoder does not accept zero samples.
            batches = np.asarray(batches).reshape(-1, 1)
            if batches.shape[0] == 0:
                return np.zeros((0, len(enc.categories_[0])), dtype=np.float32)
            return enc.transform(batches).toarray().astype(np.float32)

        adata.obsm['batch_label_enc'] = one_hot(batch_labels)

        batch_dict = {}
        if cell_key != 'unknown':
            cell_labels = adata.obs[cell_key]
            # `astype('category')` also supports plain string/object columns.
            for cell_type in cell_labels.astype('category').cat.categories.to_numpy():
                per_batch = adata.obs.loc[cell_labels == cell_type, batch_key].value_counts()
                batch_dict[cell_type] = one_hot(per_batch[per_batch > 3].index.to_numpy())
        batch_dict['unknown'] = one_hot(kept_batches)

        adata.uns['batch_dict'] = batch_dict
        return adata

    @staticmethod
    def subset_to_hvg(
        adata: ad.AnnData,
        n_top_genes: int,
    ) -> ad.AnnData:
        """
        Subset the dataset to its top highly variable genes (Seurat v3 flavor).

        Parameters
        ----------
        adata : anndata.AnnData
            Annotated data matrix with raw counts in `adata.X`; subset in place.
        n_top_genes : int
            Number of top highly variable genes to select.

        Returns
        -------
        anndata.AnnData
            The same object restricted to the selected genes; `adata.X` still
            holds the untransformed counts.
        """
        print('Subsetting the {} dataset to the {} most highly variable genes using seurat.'.format(adata.uns['metadata']['name'], str(n_top_genes)))
        batch_key = adata.uns['metadata']['batch_key']

        # flavor='seurat_v3' models raw counts, so it must not be fed
        # log-transformed data; X is therefore passed unchanged.
        sc.pp.highly_variable_genes(
            adata,
            batch_key=batch_key,
            n_top_genes=n_top_genes,
            subset=True,
            flavor='seurat_v3',
        )
        return adata

    def return_mdata(
        self,
        return_mdata: bool = True,
        save: bool = True,
        save_path: Path = Path("data"),
        save_name: str = 'mudata'
    ) -> Optional[mu.MuData]:
        """
        Assemble the MuData object and optionally save and/or return it.

        Parameters
        ----------
        return_mdata : bool, optional
            If True, return the MuData object (default True).
        save : bool, optional
            If True, write the MuData object to disk (default True).
        save_path : pathlib.Path, optional
            Directory in which to save the file; created if missing (default Path("data")).
        save_name : str, optional
            Filename stem for the .h5mu file; '.h5mu' is appended (default 'mudata').

        Returns
        -------
        mu.MuData or None
            The MuData built from `self.dataset_collection` if `return_mdata`
            is True, otherwise None.
        """
        # Built up front so it is available regardless of `save`.
        mdata = mu.MuData(self.dataset_collection)

        if save:
            save_path = Path(save_path)
            if not save_path.exists():
                save_path.mkdir(parents=True, exist_ok=True)
                print(f"\nCreated directory '{save_path}'.")

            file_path = save_path / f"{save_name}.h5mu"
            mdata.write(str(file_path))
            print(f"Saved mdata to {file_path}.")

        if return_mdata:
            return mdata
        return None