diff --git a/gnomad/utils/intervals.py b/gnomad/utils/intervals.py index d927b9bbe..ff8039742 100644 --- a/gnomad/utils/intervals.py +++ b/gnomad/utils/intervals.py @@ -1,9 +1,19 @@ # noqa: D100 -from typing import List, Union +import logging +from typing import List, Optional, Union import hail as hl +from gnomad.utils.reference_genome import get_reference_genome + +logging.basicConfig( + format="%(asctime)s (%(name)s %(lineno)s): %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + def sort_intervals(intervals: List[hl.Interval]): """ @@ -112,3 +122,71 @@ def _add_padding( return [_add_padding(i) for i in intervals] else: return _add_padding(intervals) + + +def explode_intervals_to_loci( + intervals: Union[hl.Table, hl.expr.IntervalExpression], + interval_field: Optional[str] = None, + keep_intervals: Optional[bool] = False, +) -> Union[hl.Table, hl.expr.ArrayExpression]: + """ + Expand intervals to loci and key by loci, or return loci range expression. + + :param intervals: Hail Table or Interval Expression. + :param interval_field: Name of the interval field. Only required if input is a Hail Table. Default is None. + :param keep_intervals: If True, keep the original intervals as a column in output. Only applies if input is a Hail Table. Default is False. + :return: If input is a Hail Table, returns exploded Table keyed by locus. If input is an IntervalExpression, returns position array expression. + """ + assert isinstance(intervals, hl.Table) or isinstance( + intervals, hl.expr.IntervalExpression + ), "Input must be a Table or IntervalExpression!" + + if isinstance(intervals, hl.Table) and ( + not interval_field or keep_intervals is None + ): + raise ValueError( + "`interval_field` and `keep_intervals` must be defined if input is a Table!" + ) + assert ( + interval_field in intervals.row + ), "`interval_field` must be an annotation present on input Table!" + intervals_expr = ( + intervals + if isinstance(intervals, hl.expr.IntervalExpression) + else intervals[interval_field] + ) + intervals_start_expr = hl.if_else( + intervals_expr.includes_start, + intervals_expr.start.position, + intervals_expr.start.position + 1, + ) + intervals_end_expr = hl.if_else( + intervals_expr.includes_end, + intervals_expr.end.position + 1, + intervals_expr.end.position, + ) + if isinstance(intervals, hl.Table): + intervals = intervals.annotate( + pos=hl.range(intervals_start_expr, intervals_end_expr) + ).explode("pos") + intervals = intervals.key_by( + locus=hl.locus( + intervals[interval_field].start.contig, + intervals.pos, + reference_genome=get_reference_genome(intervals[interval_field]), + ) + ) + + fields_to_drop = ["pos"] + if not keep_intervals: + fields_to_drop.append(interval_field) + + return intervals.drop(*fields_to_drop) + + logger.warning( + "Input is an IntervalExpression, so function will return ArrayExpression of" + " positions within input intervals. To fully explode intervals to loci, we" + " recommend annotating your dataset with the returned ArrayExpression," + " exploding the array, and converting the positions to loci!" + ) + return hl.range(intervals_start_expr, intervals_end_expr)