@@ -3863,15 +3863,26 @@ def load_text(
38633863 return tc .tree_sequence ()
38643864
38653865
3866- class TreeIterator :
3867- """
3868- Simple class providing forward and backward iteration over a tree sequence.
3869- """
3870-
3871- def __init__ (self , tree ):
3872- self .tree = tree
3873- self .more_trees = True
3866+ class ObjectIterator :
3867+ # Simple class providing forward and backward iteration over a
3868+ # mutable object with ``next()`` and ``prev()`` methods, e.g.
3869+ # a Tree or a Variant. ``interval`` allows the bounds of the
3870+ # iterator to be specified, and should already have
3871+ # been checked using _check_genomic_range(left, right).
3872+ # If ``return_copies`` is True, the iterator will return
3873+ # immutable copies of each object (this is likely to be significantly
3874+ # less efficient).
3875+ # It can be useful to define __len__ on one of these iterators,
3876+ # which e.g. allows progress bars to provide useful feedback.
3877+
3878+ def __init__ (self , obj , interval , return_copies = False ):
3879+ self ._obj = obj
3880+ self .min_pos = interval [0 ]
3881+ self .max_pos = interval [1 ]
3882+ self .return_copies = return_copies
38743883 self .forward = True
3884+ self .started = False
3885+ self .finished = False
38753886
38763887 def __iter__ (self ):
38773888 return self
@@ -3880,17 +3891,113 @@ def __reversed__(self):
38803891 self .forward = False
38813892 return self
38823893
3894+ def obj_left (self ):
3895+ # Used to work out where to stop iterating when going backwards.
3896+ # Override with code to return the left coordinate of self._obj
3897+ raise NotImplementedError ()
3898+
3899+ def obj_right (self ):
3900+ # Used to work out when to stop iterating when going forwards.
3901+ # Override with code to return the right coordinate of self._obj
3902+ raise NotImplementedError ()
3903+
3904+ def seek_to_start (self ):
3905+ # Override to set the object position to self.min_pos
3906+ raise NotImplementedError ()
3907+
3908+ def seek_to_end (self ):
3909+ # Override to set the object position just before self.max_pos
3910+ raise NotImplementedError ()
3911+
38833912 def __next__ (self ):
3884- if self .forward :
3885- self .more_trees = self .more_trees and self .tree .next ()
3886- else :
3887- self .more_trees = self .more_trees and self .tree .prev ()
3888- if not self .more_trees :
3913+ if not self .finished :
3914+ if not self .started :
3915+ if self .forward :
3916+ self .seek_to_start ()
3917+ else :
3918+ self .seek_to_end ()
3919+ self .started = True
3920+ else :
3921+ if self .forward :
3922+ if not self ._obj .next () or self .obj_left () >= self .max_pos :
3923+ self .finished = True
3924+ else :
3925+ if not self ._obj .prev () or self .obj_right () < self .min_pos :
3926+ self .finished = True
3927+ if self .finished :
38893928 raise StopIteration ()
3890- return self .tree
3929+ return self ._obj .copy () if self .return_copies else self ._obj
3930+
3931+
3932+ class TreeIterator (ObjectIterator ):
3933+ """
3934+ An iterator over some or all of the :class:`trees<Tree>`
3935+ in a :class:`TreeSequence`.
3936+ """
3937+
3938+ def obj_left (self ):
3939+ return self ._obj .interval .left
3940+
3941+ def obj_right (self ):
3942+ return self ._obj .interval .right
3943+
3944+ def seek_to_start (self ):
3945+ self ._obj .seek (self .min_pos )
3946+
3947+ def seek_to_end (self ):
3948+ self ._obj .seek (np .nextafter (self .max_pos , - np .inf ))
38913949
38923950 def __len__ (self ):
3893- return self .tree .tree_sequence .num_trees
3951+ """
3952+ The number of trees over which a newly created iterator will iterate.
3953+ """
3954+ ts = self ._obj .tree_sequence
3955+ if self .min_pos == 0 and self .max_pos == ts .sequence_length :
3956+ # a common case: don't incur the cost of searching through the breakpoints
3957+ return ts .num_trees
3958+ breaks = ts .breakpoints (as_array = True )
3959+ left_index = breaks .searchsorted (self .min_pos , side = "right" )
3960+ right_index = breaks .searchsorted (self .max_pos , side = "left" )
3961+ return right_index - left_index + 1
3962+
3963+
3964+ class VariantIterator (ObjectIterator ):
3965+ """
3966+ An iterator over some or all of the :class:`variants<Variant>`
3967+ in a :class:`TreeSequence`.
3968+ """
3969+
3970+ def __init__ (self , variant , interval , copy ):
3971+ super ().__init__ (variant , interval , copy )
3972+ if interval [0 ] == 0 and interval [1 ] == variant .tree_sequence .sequence_length :
3973+ # a common case: don't incur the cost of searching through the positions
3974+ self .min_max_sites = [0 , variant .tree_sequence .num_sites ]
3975+ else :
3976+ self .min_max_sites = variant .tree_sequence .sites_position .searchsorted (
3977+ interval
3978+ )
3979+ if self .min_max_sites [0 ] >= self .min_max_sites [1 ]:
3980+ # upper bound is exclusive: we don't include a site at self.max_pos
3981+ self .finished = True
3982+
3983+ def obj_left (self ):
3984+ return self ._obj .site .position
3985+
3986+ def obj_right (self ):
3987+ return self ._obj .site .position
3988+
3989+ def seek_to_start (self ):
3990+ self ._obj .decode (self .min_max_sites [0 ])
3991+
3992+ def seek_to_end (self ):
3993+ self ._obj .decode (self .min_max_sites [1 ] - 1 )
3994+
3995+ def __len__ (self ):
3996+ """
3997+ The number of variants (i.e. sites) over which a newly created iterator will
3998+ iterate.
3999+ """
4000+ return self .min_max_sites [1 ] - self .min_max_sites [0 ]
38944001
38954002
38964003class SimpleContainerSequence :
@@ -4077,7 +4184,7 @@ def aslist(self, **kwargs):
40774184 :return: A list of the trees in this tree sequence.
40784185 :rtype: list
40794186 """
4080- return [tree . copy () for tree in self .trees (** kwargs )]
4187+ return [tree for tree in self .trees (copy = True , ** kwargs )]
40814188
40824189 @classmethod
40834190 def load (cls , file_or_path , * , skip_tables = False , skip_reference_sequence = False ):
@@ -4970,6 +5077,9 @@ def trees(
49705077 sample_lists = False ,
49715078 root_threshold = 1 ,
49725079 sample_counts = None ,
5080+ left = None ,
5081+ right = None ,
5082+ copy = None ,
49735083 tracked_leaves = None ,
49745084 leaf_counts = None ,
49755085 leaf_lists = None ,
@@ -5001,28 +5111,39 @@ def trees(
50015111 are roots. To efficiently restrict the roots of the tree to
50025112 those subtending meaningful topology, set this to 2. This value
50035113 is only relevant when trees have multiple roots.
5114+ :param float left: The left-most coordinate of the region over which
5115+ to iterate. Default: ``None`` treated as 0.
5116+ :param float right: The right-most coordinate of the region over which
5117+ to iterate. Default: ``None`` treated as ``.sequence_length``. This
5118+ value is exclusive, so that a tree whose ``interval.left`` is exactly
5119+ equivalent to ``right`` will not be included in the iteration.
5120+ :param bool copy: Return an immutable copy of each tree. This will be
5121+ inefficient. Default: ``None`` treated as False.
50045122 :param bool sample_counts: Deprecated since 0.2.4.
50055123 :return: An iterator over the Trees in this tree sequence.
5006- :rtype: collections.abc.Iterable, :class:`Tree`
5124+ :rtype: TreeIterator
50075125 """
50085126 # tracked_leaves, leaf_counts and leaf_lists are deprecated aliases
50095127 # for tracked_samples, sample_counts and sample_lists respectively.
50105128 # These are left over from an older version of the API when leaves
50115129 # and samples were synonymous.
5130+ interval = self ._check_genomic_range (left , right )
50125131 if tracked_leaves is not None :
50135132 tracked_samples = tracked_leaves
50145133 if leaf_counts is not None :
50155134 sample_counts = leaf_counts
50165135 if leaf_lists is not None :
50175136 sample_lists = leaf_lists
5137+ if copy is None :
5138+ copy = False
50185139 tree = Tree (
50195140 self ,
50205141 tracked_samples = tracked_samples ,
50215142 sample_lists = sample_lists ,
50225143 root_threshold = root_threshold ,
50235144 sample_counts = sample_counts ,
50245145 )
5025- return TreeIterator (tree )
5146+ return TreeIterator (tree , interval = interval , return_copies = copy )
50265147
50275148 def coiterate (self , other , ** kwargs ):
50285149 """
@@ -5309,8 +5430,8 @@ def variants(
53095430 :param int right: End with the last site before this position. If ``None``
53105431 (default) assume ``right`` is the sequence length, so that the last
53115432 variant corresponds to the last site in the tree sequence.
5312- :return: An iterator over all variants in this tree sequence.
5313- :rtype: iter(:class:`Variant`)
5433+ :return: An iterator over the specified variants in this tree sequence.
5434+ :rtype: VariantIterator
53145435 """
53155436 interval = self ._check_genomic_range (left , right )
53165437 if impute_missing_data is not None :
@@ -5327,26 +5448,13 @@ def variants(
53275448 copy = True
53285449 # See comments for the Variant type for discussion on why the
53295450 # present form was chosen.
5330- variant = tskit .Variant (
5451+ variant_object = tskit .Variant (
53315452 self ,
53325453 samples = samples ,
53335454 isolated_as_missing = isolated_as_missing ,
53345455 alleles = alleles ,
53355456 )
5336- if left == 0 and right == self .sequence_length :
5337- start = 0
5338- stop = self .num_sites
5339- else :
5340- start , stop = np .searchsorted (self .sites_position , interval )
5341-
5342- if copy :
5343- for site_id in range (start , stop ):
5344- variant .decode (site_id )
5345- yield variant .copy ()
5346- else :
5347- for site_id in range (start , stop ):
5348- variant .decode (site_id )
5349- yield variant
5457+ return VariantIterator (variant_object , interval = interval , copy = copy )
53505458
53515459 def genotype_matrix (
53525460 self ,
0 commit comments