123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434 |
- """
- Template for intervaltree
- WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
- """
- from pandas._libs.algos import is_monotonic
# Fused scalar types for the query point of the node ``query`` methods.
# The node template selects one of them through its ``fused_prefix``:
# int64 nodes take ``int_scalar_t``, uint64 nodes take ``uint_scalar_t``
# and float64 nodes take the combined ``scalar_t``.
ctypedef fused int_scalar_t:
    int64_t
    float64_t

ctypedef fused uint_scalar_t:
    uint64_t
    float64_t

# Combination of both fused types above; also the element type of the
# ``target`` memoryviews accepted by ``IntervalTree.get_indexer`` and
# ``IntervalTree.get_indexer_non_unique``.
ctypedef fused scalar_t:
    int_scalar_t
    uint_scalar_t
- # ----------------------------------------------------------------------
- # IntervalTree
- # ----------------------------------------------------------------------
cdef class IntervalTree(IntervalMixin):
    """A centered interval tree

    Based off the algorithm described on Wikipedia:
    https://en.wikipedia.org/wiki/Interval_tree

    we are emulating the IndexEngine interface
    """
    cdef readonly:
        ndarray left, right
        IntervalNode root
        object dtype
        str closed
        object _is_overlapping, _left_sorter, _right_sorter
        Py_ssize_t _na_count

    def __init__(self, left, right, closed='right', leaf_size=100):
        """
        Parameters
        ----------
        left, right : np.ndarray[ndim=1]
            Left and right bounds for each interval. Both are cast to their
            common result dtype. Intervals whose left bound is NaN are not
            stored in the tree; they are only counted in ``_na_count``.
        closed : {'left', 'right', 'both', 'neither'}, optional
            Whether the intervals are closed on the left-side, right-side, both
            or neither. Defaults to 'right'.
        leaf_size : int, optional
            Parameter that controls when the tree switches from creating nodes
            to brute-force search. Tune this parameter to optimize query
            performance.

        Raises
        ------
        ValueError
            If ``closed`` is not one of the supported options.
        """
        if closed not in ['left', 'right', 'both', 'neither']:
            # f-string rather than %-formatting: "%s" % closed raises a
            # TypeError instead of this ValueError when closed is a tuple
            raise ValueError(f"invalid option for 'closed': {closed}")

        left = np.asarray(left)
        right = np.asarray(right)
        self.dtype = np.result_type(left, right)
        self.left = np.asarray(left, dtype=self.dtype)
        self.right = np.asarray(right, dtype=self.dtype)

        # original positions, kept so queries can report them after the
        # NaN intervals have been filtered out below
        indices = np.arange(len(left), dtype='int64')

        self.closed = closed

        # GH 23352: ensure no nan in nodes
        mask = ~np.isnan(self.left)
        self._na_count = len(mask) - mask.sum()
        self.left = self.left[mask]
        self.right = self.right[mask]
        indices = indices[mask]

        # one specialized node class exists per (dtype, closed) combination
        node_cls = NODE_CLASSES[str(self.dtype), closed]
        self.root = node_cls(self.left, self.right, indices, leaf_size)

    @property
    def left_sorter(self) -> np.ndarray:
        """How to sort the left labels; this is used for binary search

        Sorts by ``left`` first and breaks ties with ``right``. Computed on
        first access and cached as ``self._left_sorter``.
        """
        if self._left_sorter is None:
            # np.lexsort uses the LAST key as the primary sort key
            values = [self.right, self.left]
            self._left_sorter = np.lexsort(values)
        return self._left_sorter

    @property
    def right_sorter(self) -> np.ndarray:
        """How to sort the right labels

        Computed on first access and cached as ``self._right_sorter``.
        """
        if self._right_sorter is None:
            self._right_sorter = np.argsort(self.right)
        return self._right_sorter

    @property
    def is_overlapping(self) -> bool:
        """
        Determine if the IntervalTree contains overlapping intervals.
        Cached as self._is_overlapping.
        """
        if self._is_overlapping is not None:
            return self._is_overlapping

        # <= when both sides closed since endpoints can overlap
        op = le if self.closed == 'both' else lt

        # overlap if start of current interval < end of previous interval
        # (current and previous in terms of sorted order by left/start side)
        current = self.left[self.left_sorter[1:]]
        previous = self.right[self.left_sorter[:-1]]
        self._is_overlapping = bool(op(current, previous).any())

        return self._is_overlapping

    @property
    def is_monotonic_increasing(self) -> bool:
        """
        Return True if the IntervalTree is monotonic increasing (only equal or
        increasing values), else False
        """
        # any NaN interval dropped in __init__ makes the tree non-monotonic
        if self._na_count > 0:
            return False

        # the intervals are already sorted iff the sorter itself is increasing
        sort_order = self.left_sorter
        return is_monotonic(sort_order, False)[0]

    def get_indexer(self, scalar_t[:] target) -> np.ndarray:
        """Return the positions corresponding to unique intervals that overlap
        with the given array of scalar targets.

        Returns
        -------
        np.ndarray of intp
            One position per target; -1 where no interval contains the target.

        Raises
        ------
        KeyError
            If a target is contained in more than one interval.
        """

        # TODO: write get_indexer_intervals
        cdef:
            Py_ssize_t old_len
            Py_ssize_t i
            Int64Vector result

        result = Int64Vector()
        old_len = 0
        for i in range(len(target)):
            try:
                self.root.query(result, target[i])
            except OverflowError:
                # overflow -> no match, which is already handled below
                pass

            # compare the vector length before/after the query to find out
            # how many intervals matched this target
            if result.data.n == old_len:
                result.append(-1)
            elif result.data.n > old_len + 1:
                raise KeyError(
                    'indexer does not intersect a unique set of intervals')
            old_len = result.data.n
        return result.to_array().astype('intp')

    def get_indexer_non_unique(self, scalar_t[:] target):
        """Return the positions corresponding to intervals that overlap with
        the given array of scalar targets. Non-unique positions are repeated.

        Returns
        -------
        tuple of (np.ndarray of intp, np.ndarray of intp)
            The matching positions (a single -1 for a target without any
            match) and the positions within ``target`` that had no match.
        """
        cdef:
            Py_ssize_t old_len
            Py_ssize_t i
            Int64Vector result, missing

        result = Int64Vector()
        missing = Int64Vector()
        old_len = 0
        for i in range(len(target)):
            try:
                self.root.query(result, target[i])
            except OverflowError:
                # overflow -> no match, which is already handled below
                pass

            # an unchanged length means the query found nothing
            if result.data.n == old_len:
                result.append(-1)
                missing.append(i)
            old_len = result.data.n
        return (result.to_array().astype('intp'),
                missing.to_array().astype('intp'))

    def __repr__(self) -> str:
        return (f'<IntervalTree[{self.dtype},{self.closed}]: '
                f'{self.root.n_elements} elements>')

    # compat with IndexEngine interface
    def clear_mapping(self) -> None:
        pass
cdef take(ndarray source, ndarray indices):
    """Return the elements of the 1D ``source`` found at ``indices``.
    """
    # the last argument is the axis to take along
    return PyArray_Take(source, indices, 0)
cdef sort_values_and_indices(all_values, all_indices, subset):
    """Select ``subset`` from the values and indices, ordered by value.

    Returns the selected values in ascending order together with the
    selected indices rearranged by the same permutation.
    """
    picked_indices = take(all_indices, subset)
    picked_values = take(all_values, subset)
    # a single permutation is applied to both arrays to keep them aligned
    order = PyArray_ArgSort(picked_values, 0, NPY_QUICKSORT)
    return take(picked_values, order), take(picked_indices, order)
- # ----------------------------------------------------------------------
- # Nodes
- # ----------------------------------------------------------------------
@cython.internal
cdef class IntervalNode:
    # State shared by all specialized node classes; the attributes read
    # below but not declared here (left_node, right_node, pivot,
    # center_left_values) are declared by the specialized subclasses.
    cdef readonly:
        int64_t n_elements, n_center, leaf_size
        bint is_leaf_node

    def __repr__(self) -> str:
        # terminal nodes have no pivot or children to report
        if self.is_leaf_node:
            return (
                f"<{type(self).__name__}: {self.n_elements} elements (terminal)>"
            )

        n_left = self.left_node.n_elements
        n_right = self.right_node.n_elements
        # whatever was not handed to a child overlaps the pivot
        n_center = self.n_elements - n_left - n_right
        return (
            f"<{type(self).__name__}: "
            f"pivot {self.pivot}, {self.n_elements} elements "
            f"({n_left} left, {n_right} right, {n_center} overlapping)>"
        )

    def counts(self):
        """
        Inspect counts on this node
        useful for debugging purposes

        A leaf reports its number of elements; any other node reports
        ``(n_center, (left_counts, right_counts))`` recursively.
        """
        if self.is_leaf_node:
            return self.n_elements

        center_count = len(self.center_left_values)
        left_counts = self.left_node.counts()
        right_counts = self.right_node.counts()
        return (center_count, (left_counts, right_counts))
- # we need specialized nodes and leaves to optimize for different dtype and
- # closed values
{{py:

# Build one spec tuple per (dtype, closed) combination:
#   (dtype, dtype_title, closed, closed_title, cmp_left, cmp_right,
#    cmp_left_converse, cmp_right_converse, fused_prefix)
# cmp_left / cmp_right are the operators testing ``left <op> point`` and
# ``point <op> right``; each converse swaps '<' and '<='. fused_prefix
# picks the fused scalar type accepted by the generated ``query`` method.
nodes = []
for dtype, fused_prefix in [('float64', ''),
                            ('int64', 'int_'),
                            ('uint64', 'uint_')]:
    for closed, cmp_left, cmp_right in [
        ('left', '<=', '<'),
        ('right', '<', '<='),
        ('both', '<=', '<='),
        ('neither', '<', '<')]:
        nodes.append((dtype, dtype.title(),
                      closed, closed.title(),
                      cmp_left,
                      cmp_right,
                      '<' if cmp_left == '<=' else '<=',
                      '<' if cmp_right == '<=' else '<=',
                      fused_prefix))

}}
{{for dtype, dtype_title, closed, closed_title, cmp_left, cmp_right,
      cmp_left_converse, cmp_right_converse, fused_prefix in nodes}}


@cython.internal
cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode(IntervalNode):
    """Node of an IntervalTree for {{dtype}} bounds closed on '{{closed}}'

    A non-terminal node categorizes intervals by those that fall to the left,
    those that fall to the right, and those that overlap with the pivot.
    A node holding at most ``leaf_size`` intervals is terminal instead: it
    stores its intervals directly and is searched linearly.
    """
    cdef readonly:
        {{dtype_title}}Closed{{closed_title}}IntervalNode left_node, right_node
        {{dtype}}_t[:] center_left_values, center_right_values, left, right
        int64_t[:] center_left_indices, center_right_indices, indices
        {{dtype}}_t min_left, max_right
        {{dtype}}_t pivot

    def __init__(self,
                 ndarray[{{dtype}}_t, ndim=1] left,
                 ndarray[{{dtype}}_t, ndim=1] right,
                 ndarray[int64_t, ndim=1] indices,
                 int64_t leaf_size):
        """
        Parameters
        ----------
        left, right : ndarray[{{dtype}}_t, ndim=1]
            Bounds of the intervals handled by this node and its children.
        indices : ndarray[int64_t, ndim=1]
            The value reported by ``query`` for each interval.
        leaf_size : int64_t
            Largest number of intervals stored in a terminal node.
        """

        self.n_elements = len(left)
        self.leaf_size = leaf_size

        # min_left and max_right are used to speed-up query by skipping
        # query on sub-nodes. If this node has size 0, query is cheap,
        # so these values don't matter.
        if left.size > 0:
            self.min_left = left.min()
            self.max_right = right.max()
        else:
            self.min_left = 0
            self.max_right = 0

        if self.n_elements <= leaf_size:
            # make this a terminal (leaf) node
            self.is_leaf_node = True
            self.left = left
            self.right = right
            self.indices = indices
            self.n_center = 0
        else:
            # calculate a pivot so we can create child nodes
            self.is_leaf_node = False
            # halve before adding so the midpoints cannot overflow
            self.pivot = np.median(left / 2 + right / 2)
            if np.isinf(self.pivot):
                # an infinite pivot cannot split anything: restart from 0
                # and pull the pivot back into the range of the data
                self.pivot = cython.cast({{dtype}}_t, 0)
                if self.pivot > np.max(right):
                    self.pivot = np.max(left)
                if self.pivot < np.min(left):
                    self.pivot = np.min(right)

            left_set, right_set, center_set = self.classify_intervals(
                left, right)

            self.left_node = self.new_child_node(left, right,
                                                 indices, left_set)
            self.right_node = self.new_child_node(left, right,
                                                  indices, right_set)

            # the center intervals are stored twice, sorted by left and by
            # right bound, so that query can stop at the first mismatch
            self.center_left_values, self.center_left_indices = \
                sort_values_and_indices(left, indices, center_set)
            self.center_right_values, self.center_right_indices = \
                sort_values_and_indices(right, indices, center_set)
            self.n_center = len(self.center_left_indices)

    @cython.wraparound(False)
    @cython.boundscheck(False)
    cdef classify_intervals(self, {{dtype}}_t[:] left, {{dtype}}_t[:] right):
        """Classify the given intervals based upon whether they fall to the
        left, right, or overlap with this node's pivot.

        Returns three arrays with the positions (into ``left`` / ``right``)
        of the left, right and overlapping intervals, in that order.
        """
        cdef:
            Int64Vector left_ind, right_ind, overlapping_ind
            Py_ssize_t i

        left_ind = Int64Vector()
        right_ind = Int64Vector()
        overlapping_ind = Int64Vector()

        for i in range(self.n_elements):
            if right[i] {{cmp_right_converse}} self.pivot:
                # the interval ends before it could contain the pivot
                left_ind.append(i)
            elif self.pivot {{cmp_left_converse}} left[i]:
                # the interval starts after it could contain the pivot
                right_ind.append(i)
            else:
                overlapping_ind.append(i)

        return (left_ind.to_array(),
                right_ind.to_array(),
                overlapping_ind.to_array())

    cdef new_child_node(self,
                        ndarray[{{dtype}}_t, ndim=1] left,
                        ndarray[{{dtype}}_t, ndim=1] right,
                        ndarray[int64_t, ndim=1] indices,
                        ndarray[int64_t, ndim=1] subset):
        """Create a new child node.

        The child holds the intervals at the positions in ``subset`` and
        inherits this node's ``leaf_size``.
        """
        left = take(left, subset)
        right = take(right, subset)
        indices = take(indices, subset)
        return {{dtype_title}}Closed{{closed_title}}IntervalNode(
            left, right, indices, self.leaf_size)

    @cython.wraparound(False)
    @cython.boundscheck(False)
    @cython.initializedcheck(False)
    cpdef query(self, Int64Vector result, {{fused_prefix}}scalar_t point):
        """Recursively query this node and its sub-nodes for intervals that
        overlap with the query point.

        The index of every matching interval is appended to ``result``.
        A NaN ``point`` never matches.
        """
        cdef:
            int64_t[:] indices
            {{dtype}}_t[:] values
            Py_ssize_t i

        if point != point:
            # NaN is contained in no interval. Without this guard a NaN
            # point is neither below nor above the pivot, so a non-leaf
            # node would report all of its center intervals while a leaf
            # node reports none.
            return

        if self.is_leaf_node:
            # Once we get down to a certain size, it doesn't make sense to
            # continue the binary tree structure. Instead, we use linear
            # search.
            for i in range(self.n_elements):
                if self.left[i] {{cmp_left}} point {{cmp_right}} self.right[i]:
                    result.append(self.indices[i])
        else:
            # There are child nodes. Based on comparing our query to the pivot,
            # look at the center values, then go to the relevant child.
            if point < self.pivot:
                # center intervals reach up to the pivot, so only their
                # left bounds (sorted ascending) need to be checked
                values = self.center_left_values
                indices = self.center_left_indices
                for i in range(self.n_center):
                    if not values[i] {{cmp_left}} point:
                        break
                    result.append(indices[i])
                if point {{cmp_right}} self.left_node.max_right:
                    self.left_node.query(result, point)
            elif point > self.pivot:
                # mirror image: walk the right bounds from the largest down
                values = self.center_right_values
                indices = self.center_right_indices
                for i in range(self.n_center - 1, -1, -1):
                    if not point {{cmp_right}} values[i]:
                        break
                    result.append(indices[i])
                if self.right_node.min_left {{cmp_left}} point:
                    self.right_node.query(result, point)
            else:
                # neither below nor above the pivot: report all center
                # intervals
                result.extend(self.center_left_indices)


NODE_CLASSES['{{dtype}}',
             '{{closed}}'] = {{dtype_title}}Closed{{closed_title}}IntervalNode

{{endfor}}