ViralBERT is a nucleotide language model pre-trained on large-scale metagenomic datasets, designed to tackle a wide range of tasks in viral metagenomics analysis. By leveraging deep contextual representations, ViralBERT enables powerful downstream applications including embedding extraction, binning, and viral identification.
```bash
conda create -n viralbert python=3.10
conda activate viralbert
pip install -r requirements.txt
```

For MPS (Apple Silicon) setups, install the MPS-specific requirements instead:

```bash
conda create -n viralbert python=3.10
conda activate viralbert
pip install -r requirements_mps.txt
```

To run inference, you will need to download the pre-trained weights.
ViralBERT provides three main inference scripts tailored for different analysis needs.
Best for: Dimensionality reduction (t-SNE/UMAP), sequence clustering, or feature engineering for downstream sequence analysis.
Use the base pre-trained model to extract deep semantic vectors from DNA sequences via scripts/inference_pretrain.py.
Usage:

```bash
python scripts/inference_pretrain.py \
    --model_path models/viralbert_base \
    --input_fasta data/my_sequences.fasta \
    --output_dir results/embeddings \
    --device cuda:0
```

Key Arguments:
- `--pooling`: Strategy to derive chunk embeddings (`mean` or `cls`). Default: `mean`.
- `--combine`: Aggregates chunk embeddings into a single sequence vector. Use `--no-combine` to output embeddings for every individual chunk.
- `--batch_size`: Batch size for inference. Default: 32.
Outputs:
- `embeddings.npy`: The embedding matrix (shape: `[N_seqs, Hidden_dim]`).
- `metadata.csv`: Corresponding sequence IDs, lengths, and chunk counts.
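The pooling and combine options above can be sketched with plain NumPy. The array shapes and random values here are illustrative assumptions, not ViralBERT's actual internals:

```python
import numpy as np

# Toy stand-in for transformer hidden states: 3 chunks, 5 tokens each,
# hidden dim 8. In ViralBERT these come from the model; here they are random.
rng = np.random.default_rng(0)
hidden_states = rng.normal(size=(3, 5, 8))  # [n_chunks, seq_len, hidden_dim]

# --pooling mean: average over the token axis of each chunk.
mean_chunk_embeddings = hidden_states.mean(axis=1)       # [n_chunks, hidden_dim]

# --pooling cls: keep only the first (CLS-position) token of each chunk.
cls_chunk_embeddings = hidden_states[:, 0, :]            # [n_chunks, hidden_dim]

# --combine: aggregate chunk embeddings into one vector per sequence.
sequence_embedding = mean_chunk_embeddings.mean(axis=0)  # [hidden_dim]

print(mean_chunk_embeddings.shape, cls_chunk_embeddings.shape, sequence_embedding.shape)
```

With `--combine`, each input sequence contributes one row to `embeddings.npy`; with `--no-combine`, every chunk contributes a row.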
Best for: Recovering viral genomes (vMAGs) from mixed metagenomic data using unsupervised clustering.
The scripts/inference_binning.py pipeline uses a contrastive-learning fine-tuned model to generate embeddings optimized for separating distinct genomes, followed by clustering algorithms (Leiden, HDBSCAN, or KMeans).
Usage:

```bash
python scripts/inference_binning.py \
    --model_path models/viralbert_binning \
    --input_fasta data/metagenome.fasta \
    --output_dir results/binning \
    --device cuda:0
```

Key Arguments:
- Clustering Parameters:
  - `--leiden_k`: Number of neighbors for Leiden graph construction (default: 15).
  - `--hdbscan_min_cluster_size`: Minimum cluster size for HDBSCAN (default: 5).
  - `--kmeans_n_clusters`: Number of clusters for KMeans (an integer or `auto`).
- Embeddings:
  - `--save_embeddings`: If set, saves the intermediate embeddings to a `.npz` file for reuse.
Outputs:
The script generates TSV files mapping `contig_id` to `bin_id` for each combination of embedding type (CLS, Projected, MeanPool) and clustering algorithm, e.g. `prefix_mean_pool_leiden_bin_map.tsv`.
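A bin map is a two-column TSV, so downstream grouping of contigs by bin is straightforward. A minimal sketch, assuming the column headers match the `contig_id`/`bin_id` names described above (the rows here are synthetic):

```python
import csv
import io
from collections import defaultdict

# Synthetic stand-in for the contents of a *_bin_map.tsv file.
tsv_text = "contig_id\tbin_id\ncontig_1\t0\ncontig_2\t1\ncontig_3\t0\n"

# Group contig IDs under their assigned bin.
bins = defaultdict(list)
for row in csv.DictReader(io.StringIO(tsv_text), delimiter="\t"):
    bins[row["bin_id"]].append(row["contig_id"])

print(dict(bins))  # one list of contigs per bin
```

Each resulting list corresponds to one candidate vMAG, which can then be written out as a per-bin FASTA for quality checks.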
Best for: Identifying viral contigs vs. host, or classifying viruses into families/genera.
scripts/inference_classification.py classifies sequences using a supervised fine-tuned model. It handles long sequences by splitting them into chunks and aggregating predictions via Voting or Probability Averaging.
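The two aggregation modes can be sketched as follows. The per-chunk class probabilities below are made up for illustration, and chosen so the two strategies disagree, which can happen when one chunk is very confident:

```python
import numpy as np

# Assumed per-chunk class probabilities for one long sequence
# (3 chunks, 2 classes); not real model output.
chunk_probs = np.array([
    [0.55, 0.45],
    [0.10, 0.90],
    [0.51, 0.49],
])

# Voting: each chunk casts one vote for its argmax class; majority wins.
votes = chunk_probs.argmax(axis=1)
label_vote = np.bincount(votes, minlength=chunk_probs.shape[1]).argmax()

# Probability averaging: mean the probabilities first, then take the argmax.
label_prob = chunk_probs.mean(axis=0).argmax()

print(label_vote, label_prob)  # here: 0 by voting, 1 by averaging
```

Voting treats every chunk equally regardless of confidence, while probability averaging lets a single high-confidence chunk dominate; the output CSV reports both so they can be compared.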
Usage:

```bash
python scripts/inference_classification.py \
    --model_path models/viralbert_classification \
    --input_fasta data/unknown_contigs.fasta \
    --output_path results/classification.csv \
    --device cuda:0
```

Key Arguments:
- `--inference_max_len`: Max token length per chunk (default: 512).
- `--batch_size`: Inference batch size per GPU.
- `-c` / `--continue`: Resume from an existing output file if the process was interrupted.
Outputs: A CSV file containing:
- `predicted_label_vote`: Prediction based on majority vote of chunks.
- `predicted_label_prob`: Prediction based on averaged probabilities.
- `prob_<class>`: Probability scores for each class.
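Downstream, the output CSV can be filtered on the averaged probabilities. A minimal sketch: the column names follow the description above, but the class names (`viral`, `host`), the `sequence_id` column, and the rows themselves are assumptions for illustration:

```python
import csv
import io

# Synthetic stand-in for results/classification.csv; real column names for
# prob_<class> depend on the classes of the fine-tuned model.
csv_text = (
    "sequence_id,predicted_label_vote,predicted_label_prob,prob_viral,prob_host\n"
    "contig_1,viral,viral,0.92,0.08\n"
    "contig_2,host,viral,0.51,0.49\n"
)

# Keep only high-confidence viral calls from the probability-averaged prediction.
confident_viral = [
    row["sequence_id"]
    for row in csv.DictReader(io.StringIO(csv_text))
    if row["predicted_label_prob"] == "viral" and float(row["prob_viral"]) >= 0.9
]
print(confident_viral)
```

Thresholding on `prob_<class>` rather than the hard labels lets you trade recall for precision when screening candidate viral contigs.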
This project is licensed under the MIT License.