This repository prepares genomic datasets and trains a contrastive DNA embedding model geared toward variant pathogenicity and mutation robustness. It is tailored for the Dacon DNA sequence learning challenge (see competition description: https://dacon.io/competitions/official/236630/overview/description).
The pipeline:
- Build a compact ClinVar SNV dataset centered in 1024bp windows.
- Exclude any sequences that leak into the provided `test.csv` via exact window match at the SNV position.
- Optionally compute Hamming-distance pairs among test sequences to derive labels.
- Train a GPN-based model with contrastive objectives using distributed `torchrun`.
- OS: Ubuntu 24.04.
- Model backend: Hugging Face Transformers + PyTorch.

Create and activate a conda environment:

    conda create -n dna-embed python=3.12 -y
    conda activate dna-embed

Install packages used in this repo:

    cd DNA-embedding
    pip install -r requirements.txt
    pip install git+https://github.com/songlab-cal/gpn.git

Download hg38 reference FASTA and ClinVar VCF (GRCh38):

    cd data
    wget https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
    wget https://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh38/clinvar.vcf.gz
    gunzip -k hg38.fa.gz
    gunzip -k clinvar.vcf.gz

`prepare_clinvar_dataset.py` extracts 1024bp windows centered on SNVs, places the reference base at index 511 (0-based), and records the alternate base and label. Only Benign and Pathogenic (including Likely Pathogenic) entries are kept.

    python prepare_clinvar_dataset.py --vcf clinvar.vcf --fasta hg38.fa --out clinvar_compact.csv --window 1024

Outputs:

- `clinvar_compact.csv` with columns: `ref_seq, mut_idx, alt, label`
  - `mut_idx` is typically `511` for window 1024.
  - `label` is `1` for Benign and `-1` for Pathogenic/Likely Pathogenic.
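
The window layout described above can be illustrated with a minimal sketch. It assumes the chromosome is already loaded as a Python string and that the VCF position is 1-based; `extract_window` and `make_row` are hypothetical helper names, not functions taken from `prepare_clinvar_dataset.py`.

```python
def extract_window(chrom_seq, pos_1based, window=1024):
    """Return (ref_seq, mut_idx) with the SNV at index window // 2 - 1.

    Returns None when the window would run off either end of the chromosome.
    """
    mut_idx = window // 2 - 1            # 511 for window=1024
    start = (pos_1based - 1) - mut_idx   # 0-based start of the window
    end = start + window
    if start < 0 or end > len(chrom_seq):
        return None
    return chrom_seq[start:end].upper(), mut_idx


def make_row(chrom_seq, pos_1based, ref, alt, label, window=1024):
    """Build one CSV-like row, or None if the variant cannot be placed."""
    out = extract_window(chrom_seq, pos_1based, window)
    if out is None:
        return None
    ref_seq, mut_idx = out
    if ref_seq[mut_idx] != ref.upper():
        return None  # assumption: skip variants whose REF disagrees with the FASTA
    return {"ref_seq": ref_seq, "mut_idx": mut_idx,
            "alt": alt.upper(), "label": label}
```

With a 1024bp window, the reference base lands at index 511 (0-based), which is position 512 in 1-based terms.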

`sequence_matcher.py` performs two tasks:
- Removes ClinVar rows whose `(ref_seq, alt)` pair matches a window in `test.csv` and would therefore leak test data into training.
- Produces labels for `test.csv` pairs using precomputed Hamming-distance neighbor pairs.

Run it after placing `test.csv` and `clinvar_compact.csv` in `data/` and precomputing matches into `test/results/match_clinvar.csv`:

    python sequence_matcher.py

Outputs:

- `clinvar_compact_removed.csv`: filtered ClinVar without matched leakage entries.
- `matched_pairs_labeled.csv`: pairs from `test.csv` with labels inferred via ClinVar mapping on the SNV position.
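
One way to picture the leakage filter is the sketch below. It assumes a row leaks when either its reference window or its alternate window (the reference with `alt` substituted at `mut_idx`) appears verbatim among the test sequences; the exact matching rule in `sequence_matcher.py` may differ, and `apply_alt` and `remove_leaks` are hypothetical names.

```python
def apply_alt(ref_seq, mut_idx, alt):
    """Return the alternate window: ref_seq with alt written at mut_idx."""
    return ref_seq[:mut_idx] + alt + ref_seq[mut_idx + 1:]


def remove_leaks(clinvar_rows, test_seqs):
    """Keep only rows whose ref and alt windows are both absent from test_seqs."""
    test_set = set(test_seqs)  # set lookup keeps the filter linear in the row count
    kept = []
    for row in clinvar_rows:
        alt_seq = apply_alt(row["ref_seq"], row["mut_idx"], row["alt"])
        if row["ref_seq"] in test_set or alt_seq in test_set:
            continue  # exact window match: treat as leakage and drop
        kept.append(row)
    return kept
```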

`test/match.py` computes Hamming distances among all unordered pairs in `data/test.csv` and emits two CSVs:
- `match_clinvar.csv`: pairs at distance 1 with a differing base at position 512 (1-based), used to infer ClinVar labels.
- `match_mut.csv`: generic mutation pairs.
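
The pair classification can be sketched in pure Python. This is an illustration only: the repository uses a compiled `_hdist` module for speed. Which pairs count as "generic mutation pairs" is not specified here, so the sketch assumes every other non-identical pair is kept together with its distance. `hamming` and `classify_pairs` are hypothetical names.

```python
from itertools import combinations


def hamming(a, b):
    """Number of positions at which two equal-length sequences differ."""
    if len(a) != len(b):
        raise ValueError("sequences must have equal length")
    return sum(x != y for x, y in zip(a, b))


def classify_pairs(seqs, snv_idx=511):
    """Split all unordered pairs into SNV-like pairs and other mutation pairs.

    snv_idx is 0-based, so the default 511 is position 512 in 1-based terms.
    """
    clinvar_like, mutation = [], []
    for (i, a), (j, b) in combinations(enumerate(seqs), 2):
        d = hamming(a, b)
        if d == 1 and a[snv_idx] != b[snv_idx]:
            clinvar_like.append((i, j))
        elif d > 0:
            mutation.append((i, j, d))  # assumption: keep all other differing pairs
    return clinvar_like, mutation
```

The all-pairs loop is quadratic in the number of sequences, which is the usual motivation for a compiled distance routine.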

It uses a compiled `_hdist` module for speed. To build and run on your machine:

    cd ~/DNA-embedding/test
    python setup.py build_ext --inplace
    python match.py

Training is orchestrated by `script/train.sh` using `torchrun` with multi-GPU DDP. The main entry point is `train.py`, which:
- Loads and prepares three datasets: `ClinVarRefAltDataset` for benign and pathogenic SNVs, and `ContrastiveMutateDataset` for mutation severity regression
- Uses `BalancedAlternatingDataset` for interleaved multi-task training
- Tracks evaluation metrics via `EvaluationCallback`, which runs at epoch end and logs cosine distance (`cd`), pathogenic-benign distance difference (`cdd`), and Pearson correlation (`pcc`)
- Automatically copies training configuration to output directory for reproducibility

To start training:

    cd ~/DNA-embedding
    bash script/train.sh

Logs are written to `output/<RUN_NAME>/training.log`, and checkpoints under `output/<RUN_NAME>/joint`.

Evaluation is computed in `eval_metrics.py` after embeddings are generated:
- `cd` (mean cosine distance): Average of cosine distance $\mathrm{cd} = \frac{1-\cos}{2}$ over all ClinVar-matched benign and pathogenic pairs found in `data/matched_pairs_labeled.csv`. Lower values indicate closer ref/alt pairs overall.
- `cdd` (pathogenic minus benign distance): Difference between mean pathogenic and mean benign distances. Higher `cdd` means pathogenic pairs are farther than benign pairs, which is desired.
- `pcc` (Pearson correlation with mutations): Pearson correlation between cosine distance and mutation count (from `test/results/match_mut.csv`). Higher correlation indicates that embeddings reflect increasing dissimilarity with more mutations.
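
The three metrics follow directly from their definitions. The sketch below uses only the standard library and plain lists as embeddings; the function names and input layout are assumptions for illustration and are not taken from `eval_metrics.py`.

```python
import math


def cosine_distance(u, v):
    """Return (1 - cos(u, v)) / 2, which lies in [0, 1]."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return (1.0 - dot / (norm_u * norm_v)) / 2.0


def mean(xs):
    return sum(xs) / len(xs)


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length lists."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


def clinvar_metrics(pairs):
    """pairs: (ref_emb, alt_emb, label) with label +1 benign, -1 pathogenic.

    Returns (cd, cdd): the overall mean distance, and the mean pathogenic
    distance minus the mean benign distance.
    """
    dists = [(cosine_distance(r, a), y) for r, a, y in pairs]
    cd = mean([d for d, _ in dists])
    benign = mean([d for d, y in dists if y == 1])
    pathogenic = mean([d for d, y in dists if y == -1])
    return cd, pathogenic - benign


def mutation_pcc(pairs):
    """pairs: (emb_a, emb_b, mutation_count). Returns the pcc metric."""
    dists = [cosine_distance(a, b) for a, b, _ in pairs]
    counts = [k for _, _, k in pairs]
    return pearson(dists, counts)
```

Identical vectors give a distance of 0, orthogonal vectors 0.5, and opposite vectors 1, so `cd` and the target `k/512` share the same `[0, 1]` range.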

The `generate_embeddings.py` script generates embeddings using a fine-tuned model checkpoint. It supports both vanilla GPN (512D) and `WrapperModel` with projection head (2048D):

    python generate_embeddings.py --checkpoint <path_to_checkpoint> --input <input_csv> --output <output_csv>

or

    bash script/generate.sh

By default, this loads the model from `output/model.safetensors`.

- `--use_vanilla_gpn`: Use vanilla GPN with mean pooling (512D) instead of projection head (2048D)

The base GPN encoder is wrapped with a lightweight projection head that explicitly focuses on the single-nucleotide variant (SNV) position while preserving global context. Together with the datasets and losses below, this yields embeddings where benign SNVs are close to their reference and pathogenic SNVs are farther away.

- Projection head (`model.py` → `ProjectionHead`):
  - Multi-Head Attention over pooled hidden states to capture global sequence context.
  - A local convolutional feature around the SNV index (default `snv_pos=511`) to emphasize the mutation effect on neighboring codons.
  - Explicit SNV token feature (the embedding at the SNV index) concatenated with global pooled features.
  - A final dense layer maps the concatenated features to the embedding space (default 2048-D).
- Training objective (`model.py` → `ContrastiveTrainer`), connected to the datasets in the next subsection:
  - For ClinVar ref/alt pairs (`batch_type == 0`; see `ClinVarRefAltDataset`), optimize cosine similarity such that:
    - Benign pairs (label = +1) are encouraged to be close to their reference (high cosine similarity; margin threshold applies).
    - Pathogenic pairs (label = -1) are pushed away from the reference (low cosine similarity).
  - For mutation severity batches (`batch_type == 1`; see `ContrastiveMutateDataset`), regress the cosine distance `cd = (1 - cos) / 2` to the normalized mutation count `k/512`.

These SNV-focused features (local convolution + exact SNV token) strengthen sensitivity at the mutation locus while pooled attention preserves global sequence context.
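
The two branches of the training objective can be written per pair as scalar functions. This is a sketch under stated assumptions, not the code in `ContrastiveTrainer`: the ClinVar branch is assumed to take the same form as PyTorch's `CosineEmbeddingLoss`, and the regression branch is assumed to be a squared error. Both function names are hypothetical.

```python
def clinvar_pair_loss(cos_sim, label, margin=0.9):
    """Cosine-margin loss for one ref/alt pair (assumed form).

    label +1 (benign): penalize any gap below cosine similarity 1.
    label -1 (pathogenic): penalize only similarity above the margin.
    """
    if label == 1:
        return 1.0 - cos_sim
    return max(0.0, cos_sim - margin)


def mutation_regression_loss(cos_sim, k, max_k=512):
    """Squared error between cd = (1 - cos) / 2 and the target k / max_k.

    The squared-error choice is an assumption; only the target k/512 is
    stated in this README.
    """
    cd = (1.0 - cos_sim) / 2.0
    target = k / max_k
    return (cd - target) ** 2
```

Under this form, a pathogenic pair with cosine similarity 0.5 contributes no loss at margin 0.9, while one at 0.95 is still pushed apart.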

- `ClinVarRefAltDataset`: Yields (ref, alt) pairs filtered by label (±1). Contrastively encourages benign variants to match reference and pathogenic variants to diverge.
- `ContrastiveMutateDataset`: Samples random 1024bp windows from `hg38.fa` with `k` point mutations (1 ≤ k ≤ 512). Regresses cosine distance to normalized mutation count (`k/512`).
- `BalancedAlternatingDataset`: Interleaves batches from multiple datasets in round-robin fashion, reshuffling each epoch for balanced multi-task training.
- `ContrastiveDataCollator`: Maintains pair structure (ref, alt) by stacking inputs and flattening into batches suitable for contrastive objectives.
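
Two of the ideas above, applying `k` point mutations and interleaving batches round-robin, can be sketched with the standard library. The helper names are hypothetical, and truncating each epoch to the shortest batch list is an assumption about what "balanced" means here.

```python
import random

BASES = "ACGT"


def mutate(seq, k, rng):
    """Return seq with exactly k positions changed to a different base."""
    positions = rng.sample(range(len(seq)), k)  # k distinct positions
    out = list(seq)
    for p in positions:
        out[p] = rng.choice([b for b in BASES if b != out[p]])
    return "".join(out)


def round_robin(batch_lists, rng):
    """Interleave batches from several tasks as (task_id, batch) tuples.

    Each task's batches are reshuffled on every call (once per epoch), and
    the schedule stops at the shortest list so every task contributes the
    same number of batches.
    """
    shuffled = [rng.sample(batches, len(batches)) for batches in batch_lists]
    n = min(len(batches) for batches in shuffled)
    schedule = []
    for i in range(n):
        for task_id, batches in enumerate(shuffled):
            schedule.append((task_id, batches[i]))
    return schedule
```

Because the `k` positions are distinct and each replacement base differs from the original, the Hamming distance between a window and its mutated copy is exactly `k`, which makes `k/512` a well-defined regression target.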

- Combine a cosine-margin contrastive loss (controlled by `--cos_loss_margin`, e.g., `0.9`) for ref/alt embeddings and a regression loss on mutation level (`k/512`) from `ContrastiveMutateDataset`.
- Gradient clipping and cosine LR scheduling are configured in `train.sh`.