Commit 6eec805

Merge branch 'split_reads' into 'master'

Split reads

See merge request algorithm/remora!272

marcus1487 committed Nov 13, 2023
2 parents 857c945 + 434189a, commit 6eec805

Showing 10 changed files with 301 additions and 154 deletions.
45 changes: 24 additions & 21 deletions README.rst
@@ -124,6 +124,18 @@ Chunk raw data are loaded from each core dataset at specified proportions to con
In a break from Remora <3.0, datasets allow "infinite iteration", where each core dataset is drawn from indefinitely and independently to supply training chunks.
For validation from a fixed set of chunks, finite iteration is also supported.
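
Conceptually, infinite iteration repeatedly picks one of the core datasets according to its weight and draws the next chunk from it. A toy sketch of that sampling scheme (illustrative only, not Remora's implementation; ``iter_chunks_forever`` is a hypothetical per-dataset iterator):

.. code-block:: python

    import numpy as np

    def iter_weighted_chunks(datasets, weights, seed=42):
        """Yield chunks indefinitely, choosing the source dataset by weight."""
        rng = np.random.default_rng(seed)
        probs = np.asarray(weights, dtype=float)
        probs /= probs.sum()
        # one "infinite" iterator per core dataset (hypothetical helper)
        iterators = [ds.iter_chunks_forever() for ds in datasets]
        while True:
            idx = rng.choice(len(iterators), p=probs)
            yield next(iterators[idx])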

+To generate a dataset config from the datasets created above, one can use the following command.
+
+.. code-block:: bash
+
+    remora \
+      dataset make_config \
+      train_dataset.jsn \
+      can_chunks \
+      mod_chunks \
+      --dataset-weights 1 1 \
+      --log-filename train_dataset.log
Model Training
--------------

@@ -137,16 +149,17 @@ For example a model can be trained with the following command.
    train_dataset.jsn \
    --model remora/models/ConvLSTM_w_ref.py \
    --device 0 \
-   --chunk-context 50 50 \
    --output-path train_results

This command will produce a "best" model in torchscript format for use in Bonito, ``remora infer``, or ``remora validate`` commands.
-Models can be exported for use in Dorado with the ``remora model export`` command.
+Models can be exported for use in Dorado with the ``remora model export`` command, e.g. ``remora model export train_results/model_best.pt``.
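
For instance, exporting the best checkpoint might look like the following (a sketch; the trailing output-directory argument is an assumption, so check ``remora model export --help`` for the exact interface):

.. code-block:: bash

    remora \
      model export \
      train_results/model_best.pt \
      dorado_remora_model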

Model Inference
---------------

-For testing purposes inference within Remora is provided.
-Note that for large scale using the exported Dorado model during basecalling is recommended.
+For testing purposes, inference within Remora is provided.
+For large-scale applications, using the exported Dorado model during basecalling is recommended.

.. code-block:: bash
@@ -156,26 +169,32 @@ Note that for large scale using the exported Dorado model during basecalling is
      can_mappings.bam \
      --model train_results/model_best.pt \
      --out-file can_infer.bam \
      --log-filename can_infer.log \
      --device 0

    remora \
      infer from_pod5_and_bam \
      mod_signal.pod5 \
      mod_mappings.bam \
      --model train_results/model_best.pt \
      --out-file mod_infer.bam \
      --log-filename mod_infer.log \
      --device 0
Finally, Remora provides tools to validate these results.
Ground truth `BED files <http://useast.ensembl.org/info/website/upload/bed.html>`_ reference positions where each read should be called as the modified or canonical base listed in the BED name field.
Note that in the test files the control BED file has a ``C`` in the name field, while the modified BED file has ``m`` (the single-letter code for 5mC) in the name field.
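
For reference, a ground-truth BED line uses the standard six columns (chrom, start, end, name, score, strand); the values below are illustrative only:

.. code-block:: text

    chr20   1000000   1000001   m   0   +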

WARNING: There is a bug in pysam which causes all-context (e.g. ``--motif C 0``) modified model calls to produce invalid results with this command.
This issue is reported `here <https://github.com/pysam-developers/pysam/issues/1123>`_.
We are investigating solutions to bypass this issue, including dropping this command.

.. code-block:: bash
    remora \
      validate from_modbams \
      --bam-and-bed can_infer.bam can_ground_truth.bed \
      --bam-and-bed mod_infer.bam mod_ground_truth.bed \
      --full-output-filename validation_results.txt \
      --explicit-mod-tag-used
Pre-trained Models
------------------
@@ -219,23 +238,7 @@ Note that only a single POD5 file per sample is allowed as input and that the BA
:width: 600
:alt: Plot reference region image (reverse strand)

-The Remora API has a simple interface to access and manipulate a nanopore read including signal, basecalls, reference mapping and links between each of these.
-The ``remora.io.Read`` object is the core object for joining these data types.
-The ``remora.io.Read.from_pod5_and_alignment`` class method is the simplest interface to initialize the object.
-This method takes ``pod5.Read`` and ``pysam.AlignedSegment`` objects as input.
-Remora also provides a method to generate an in-memory index of a BAM file (``remora.io.ReadIndexedBam``) for random access by read ID.
-
-Note that the input BAM file should contain ``mv`` (move table) and ``MD`` tags in order to access signal and reference information respectively.
-See the "Data Preparation" section above for details.
-
-The ``notebooks/read_plotting.ipynb`` notebook included with this repository exemplifies some of the functionality provided via the ``io.Read`` object.
-
-A ``remora.data_chunks.RemoraRead`` object can be extracted from an ``io.Read`` object with the ``into_remora_read`` method.
-The ``RemoraRead`` object is more specialized, containing just the information needed to create chunks for input into Remora modified base models.
-The ``RemoraRead`` object can be generated from either basecalled sequence or reference sequence via the ``use_reference_anchor`` argument.
-The ``remora.inference.call_read_mods`` function runs a Remora model on a ``RemoraRead``, returning the probabilities for each modeled base and the positions of those bases.
-
-The ``remora.io.Read`` API also enables access to per-read, per-site raw signal metrics for more advanced statistical analysis.
+The Remora API to access, manipulate and visualize nanopore reads, including signal, basecalls, and reference mapping, is described in more detail in the ``notebooks`` section of this repository.
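
As a rough sketch of the workflow described in the removed text above (exact signatures are assumptions based on the names mentioned; the notebooks are authoritative):

.. code-block:: python

    import pod5
    from remora import io, inference
    from remora.model_util import load_model  # assumed helper for loading a trained model

    # random access to alignments by read ID (BAM must carry mv and MD tags)
    bam_index = io.ReadIndexedBam("can_mappings.bam")
    with pod5.DatasetReader("can_signal.pod5") as pod5_dr:
        pod5_read = next(pod5_dr.reads())
        # hypothetical accessor; use whatever lookup ReadIndexedBam exposes
        alignment = bam_index.get_first_alignment(str(pod5_read.read_id))
        # join signal and alignment into a single read object
        io_read = io.Read.from_pod5_and_alignment(pod5_read, alignment)

    # reference- or basecall-anchored Remora read
    remora_read = io_read.into_remora_read(use_reference_anchor=True)
    model, model_metadata = load_model("train_results/model_best.pt")
    probs, positions = inference.call_read_mods(remora_read, model, model_metadata)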

Terms and Licence
-----------------
2 changes: 1 addition & 1 deletion src/remora/constants.py
@@ -4,7 +4,7 @@
DEFAULT_SUPER_BATCH_SAMPLE_FRAC = 1.0
DEFAULT_CHUNKS_PER_EPOCH = 10_000_000
DEFAULT_NUM_TEST_CHUNKS = 10_000
-DEFAULT_CHUNK_CONTEXT = (50, 50)
+DEFAULT_CHUNK_CONTEXT = (200, 200)
DEFAULT_MIN_SAMPLES_PER_BASE = 5
DEFAULT_KMER_CONTEXT_BASES = (4, 4)
DEFAULT_KMER_LEN = sum(DEFAULT_KMER_CONTEXT_BASES) + 1
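
This raises the default chunk context from (50, 50) to (200, 200) around each focus base. A narrower context can still be requested at training time via the explicit flag shown in the README (a sketch; other required arguments elided):

.. code-block:: bash

    remora \
      model train \
      train_dataset.jsn \
      --chunk-context 50 50 \
      --output-path train_results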
74 changes: 37 additions & 37 deletions src/remora/inference.py
@@ -7,15 +7,16 @@
from threading import Thread
from collections import defaultdict

-import pod5
import pysam
import numpy as np
from tqdm import tqdm
+from pod5 import DatasetReader

from remora import constants, log, RemoraError
from remora.data_chunks import CoreRemoraDataset, DatasetMetadata
from remora.io import (
    ReadIndexedBam,
+   get_read_ids,
    iter_signal,
    extract_alignments,
    DuplexRead,
@@ -30,7 +31,6 @@
    NamedMPQueue,
    format_mm_ml_tags,
    softmax_axis1,
-   get_read_ids,
    Motif,
    revcomp,
    human_format,
@@ -71,19 +71,24 @@ def prepare_reads(read_errs, model_metadata, ref_anchored):
        "extra_arrays": {"read_focus_bases": ("int64", "")},
    }
    for io_read, err in read_errs:
-        if io_read is None:
-            out_read_errs.append((None, None, err))
+        if err is not None:
+            io_read.prune(drop_move_tag=False)
+            out_read_errs.append((io_read, None, err))
            continue
        try:
            remora_read = io_read.into_remora_read(ref_anchored)
        except RemoraError as e:
-            LOGGER.debug(f"{io_read.read_id} Remora read prep error: {e}")
-            out_read_errs.append((None, None, f"Remora read prep error: {e}"))
+            io_read.prune(drop_move_tag=False)
+            LOGGER.debug(f"{io_read.child_read_id} Read prep error: {e}")
+            out_read_errs.append((io_read, None, f"Read prep error: {e}"))
            continue
        except Exception as e:
-            LOGGER.debug(f"{io_read.read_id} Unexpected error: {e}")
-            out_read_errs.append((None, None, f"Unexpected error: {e}"))
+            io_read.prune(drop_move_tag=False)
+            LOGGER.debug(f"{io_read.child_read_id} Unexpected error: {e}")
+            out_read_errs.append((io_read, None, f"Unexpected error: {e}"))
            continue
+        # after creating remora read strip IO read of large arrays
+        io_read.prune(drop_move_tag=False)
        remora_read.set_motif_focus_bases(motifs)
        remora_read.refine_signal_mapping(model_metadata["sig_map_refiner"])
        chunks = list(
@@ -95,16 +100,9 @@
            )
        )
        if len(chunks) == 0:
-            LOGGER.debug(f"{io_read.read_id} No mod calls")
-            out_read_errs.append((None, None, "No mod calls"))
+            LOGGER.debug(f"{io_read.child_read_id} No mod calls")
+            out_read_errs.append((io_read, None, "No mod calls"))
            continue
-        # clear larger memory arrays (for quicker queue transfer)
-        # access sig len to save the value
-        io_read.sig_len
-        io_read.dacs = None
-        io_read.mv_table = None
-        io_read.query_to_signal = None
-        io_read.ref_to_signal = None
        # prepare in memory dataset to perform chunk extraction
        num_chunks = len(chunks)
        md_kwargs["allocate_size"] = num_chunks
@@ -139,9 +137,7 @@ def prep_nn_input(read_errs):
    # TODO for basecall-anchored calls only call on first read and apply to
    # other mappings
    if len(read_errs) == 0:
-        return [
-            (None, None, "No valid mappings"),
-        ]
+        return [(None, None, "No valid mappings")]
    read_nn_inputs = []
    for io_read, read_dataset, err in read_errs:
        if err is not None:
@@ -178,14 +174,14 @@ def new_arrays():
    for read_nn_inputs in prepped_nn_inputs:
        for io_read, r_chunks, err in read_nn_inputs:
            if err is not None:
-                b_reads.append([None, None, None, err])
+                b_reads.append([io_read, None, None, err])
                continue
            num_chunks = r_chunks["read_focus_bases"].size
            # fill out and yield full batches
            rb_consumed = 0
            # while this read extends through a whole batch continue to
            # supply batches from this read
-            while b_pos + num_chunks - rb_consumed > batch_size:
+            while b_pos + num_chunks - rb_consumed >= batch_size:
                rb_en = rb_consumed + batch_size - b_pos
                b_sigs[b_pos:] = r_chunks["signal"][rb_consumed:rb_en]
                b_enc_kmers[b_pos:] = r_chunks["enc_kmers"][rb_consumed:rb_en]
@@ -284,9 +280,8 @@ def unbatch_reads(curr_read, b_nn_out, b_read_pos, b_reads):
                comp_reads.append(curr_read)
            comp_reads.append((io_read, None, None, err))
            curr_read = None
-            continue
        # end of read from previous batch
-        if b_st is None:
+        elif b_st is None:
            if curr_read is None:
                LOGGER.debug("Unbatching encountered None read")
                raise RemoraError("Unbatching encountered None read")
@@ -304,10 +299,15 @@
                np.concatenate([r_read_pos, b_read_pos[:b_en]]),
                None,
            )
-            continue
-        if curr_read is not None:
-            comp_reads.append(curr_read)
-        curr_read = (io_read, b_nn_out[b_st:b_en], b_read_pos[b_st:b_en], None)
+        else:
+            if curr_read is not None:
+                comp_reads.append(curr_read)
+            curr_read = (
+                io_read,
+                b_nn_out[b_st:b_en],
+                b_read_pos[b_st:b_en],
+                None,
+            )
    return comp_reads, curr_read


@@ -339,7 +339,7 @@ def unbatch(*args, **kwargs):
def post_process_reads(read_mapping, model_metadata, ref_anchored):
    io_read, nn_out, r_poss, err = read_mapping
    if err is not None:
-        return None, err
+        return io_read, err
    r_probs = softmax_axis1(nn_out)[:, 1:].astype(np.float64)
    seq = io_read.ref_seq if ref_anchored else io_read.seq
    mod_tags = mods_tags_to_str(
@@ -351,11 +351,6 @@ def post_process_reads(read_mapping, model_metadata, ref_anchored):
            can_base=model_metadata["can_base"],
        )
    )
-    io_read.full_align["tags"] = [
-        tag
-        for tag in io_read.full_align["tags"]
-        if not (tag.startswith("MM") or tag.startswith("ML"))
-    ]
    io_read.full_align["tags"].extend(mod_tags)
    if ref_anchored:
        io_read.full_align["cigar"] = f"{len(io_read.ref_seq)}M"
@@ -385,7 +380,7 @@ def infer_from_pod5_and_bam(
    ref_anchored=False,
):
    bam_idx = ReadIndexedBam(in_bam_path, skip_non_primary, req_tags={"mv"})
-    with pod5.DatasetReader(Path(pod5_path)) as pod5_dr:
+    with DatasetReader(Path(pod5_path)) as pod5_dr:
        read_ids, num_reads = get_read_ids(bam_idx, pod5_dr, num_reads)

    signals = BackgroundIter(
@@ -483,6 +478,8 @@ def infer_from_pod5_and_bam(
        final_reads.out_q,
    ]
    errs = defaultdict(int)
+    if bam_idx.num_non_primary > 0:
+        errs["Non-primary alignment skipped"] = bam_idx.num_non_primary
    pysam_save = pysam.set_verbosity(0)
    sig_called = 0
    in_bam = out_bam = pbar = prev_rid = None
@@ -504,16 +501,19 @@
                    [f"{q.name}: {q.qsize()}/{q.maxsize}" for q in all_qs]
                )
            )
-            if err is not None:
+            if io_read is None:
+                # should not reach this block
                errs[err] += 1
                prev_rid = None
+                LOGGER.debug("None io_read encountered")
                pbar.update()
                continue
            if prev_rid != io_read.read_id:
                pbar.update()
                sig_called += io_read.sig_len
                sps, mag = human_format(sig_called / pbar.format_dict["elapsed"])
                pbar.set_postfix_str(f"{sps:>5.1f} {mag}samps/s", refresh=False)
+            if err is not None:
+                errs[err] += 1
            out_bam.write(
                pysam.AlignedSegment.from_dict(
                    io_read.full_align, out_bam.header