Source code for quimb.utils

"""Miscellaneous
"""
import itertools
import collections
from importlib.util import find_spec


# Prefer the ``cytoolz`` implementations when that package can be imported,
# otherwise fall back to the same-named functions from ``toolz``. Each helper
# is bound exactly once per branch.
try:
    import cytoolz
    last = cytoolz.last
    concat = cytoolz.concat
    frequencies = cytoolz.frequencies
    partition_all = cytoolz.partition_all
    merge_with = cytoolz.merge_with
    valmap = cytoolz.valmap
    partitionby = cytoolz.partitionby
    concatv = cytoolz.concatv
    compose = cytoolz.compose
    identity = cytoolz.identity
    isiterable = cytoolz.isiterable
    unique = cytoolz.unique
    keymap = cytoolz.keymap
except ImportError:
    import toolz
    last = toolz.last
    concat = toolz.concat
    frequencies = toolz.frequencies
    partition_all = toolz.partition_all
    merge_with = toolz.merge_with
    valmap = toolz.valmap
    partitionby = toolz.partitionby
    concatv = toolz.concatv
    compose = toolz.compose
    identity = toolz.identity
    isiterable = toolz.isiterable
    unique = toolz.unique
    keymap = toolz.keymap


_CHECK_OPT_MSG = "Option `{}` should be one of {}, but got '{}'."


def check_opt(name, value, valid):
    """Validate that an option takes one of its allowed values.

    Parameters
    ----------
    name : str
        Name of the option, only used in the error message.
    value : object
        The value that was supplied.
    valid : container
        The allowed values, tested with ``value in valid``.

    Raises
    ------
    ValueError
        If ``value`` is not found in ``valid``.
    """
    if value in valid:
        return
    raise ValueError(_CHECK_OPT_MSG.format(name, valid, value))


def find_library(x):
    """Check if library is installed.

    Parameters
    ----------
    x : str
        Name of library, may be a dotted submodule name.

    Returns
    -------
    bool
        If library is available.
    """
    try:
        return find_spec(x) is not None
    except ImportError:
        # for a dotted name ``find_spec`` first imports the parent package,
        # which raises (rather than returning None) if the parent is missing
        return False
def raise_cant_find_library_function(x, extra_msg=None):
    """Build a stand-in callable for a function from a missing library.

    This lets an optional dependency be reported only at the point where it
    is actually used, instead of when the module is imported.

    Parameters
    ----------
    x : str
        Name of library
    extra_msg : str, optional
        Additional text appended to the error message.

    Returns
    -------
    callable
        A function accepting any arguments that, when called, raises an
        ``ImportError`` naming the required library.
    """

    def function_that_will_raise(*_, **__):
        message = f"The library {x} is not installed. "
        raise ImportError(
            message if extra_msg is None else message + extra_msg
        )

    return function_that_will_raise
FOUND_TQDM = find_library('tqdm')

if FOUND_TQDM:
    from tqdm import tqdm

    class continuous_progbar(tqdm):
        """A progress bar driven by a continuous value rather than a count
        of completed steps: it is updated with a float lying in a range
        fixed at construction.

        Parameters
        ----------
        args : (stop) or (start, stop)
            End point of the tracked range, and its start point as well if
            exactly two values are given (otherwise the start is 0).
        total : int
            How many discrete steps the continuous range is shown with.
        kwargs
            Supplied to ``tqdm.tqdm``
        """

        def __init__(self, *args, total=100, **kwargs):
            """Set up the bar and record the continuous range to track.
            """
            kwargs.setdefault('ascii', True)
            super().__init__(total=total, unit="%", **kwargs)

            if len(args) != 2:
                self.start, self.stop = 0, args[0]
            else:
                self.start, self.stop = args

            self.range = self.stop - self.start
            self.step = 1

        def cupdate(self, x):
            """'Continuous' update of progress bar.

            Parameters
            ----------
            x : float
                Current position within the range
                ``[self.start, self.stop]``.
            """
            # position of ``x`` expressed in bar steps; the evaluation order
            # (multiply, then divide) is kept so float rounding is unchanged
            scaled = (self.total + 1) * (x - self.start) / self.range
            num_update = int(scaled - self.step)

            if num_update <= 0:
                return

            self.update(num_update)
            self.step += num_update

    def progbar(*args, **kwargs):
        """Create a ``tqdm`` progress bar, defaulting to ``ascii=True``.
        """
        kwargs.setdefault('ascii', True)
        return tqdm(*args, **kwargs)

else:  # pragma: no cover
    extra_msg = "This is needed to show progress bars."
    progbar = raise_cant_find_library_function("tqdm", extra_msg)
    continuous_progbar = raise_cant_find_library_function("tqdm", extra_msg)


def deprecated(fn, old_name, new_name):
    """Wrap ``fn`` so that each call first emits a warning stating that
    ``old_name`` is deprecated in favor of ``new_name``.
    """

    def new_fn(*args, **kwargs):
        import warnings

        message = (f"The {old_name} function is deprecated in favor "
                   f"of {new_name}")
        warnings.warn(message, Warning)
        return fn(*args, **kwargs)

    return new_fn


def int2tup(x):
    """Return ``x`` as a tuple: tuples are passed through unchanged, a
    single ``int`` is wrapped, and anything else is given to ``tuple``.
    """
    if isinstance(x, tuple):
        return x
    if isinstance(x, int):
        return (x,)
    return tuple(x)
def ensure_dict(x):
    """Return a new ``dict`` built from ``x``, or an empty ``dict`` when
    ``x is None``.
    """
    return {} if x is None else dict(x)
def pairwise(iterable):
    """Iterate over consecutive, overlapping pairs of ``iterable``, i.e.
    ``(x0, x1), (x1, x2), ...``.
    """
    firsts, seconds = itertools.tee(iterable)
    # shift the second copy forward by one item (no error if it is empty)
    next(seconds, None)
    return zip(firsts, seconds)
def print_multi_line(*lines, max_width=None):
    """Print several lines of text together, wrapping them jointly into
    successive blocks if they are wider than ``max_width``.

    Parameters
    ----------
    lines : str
        The lines to print. If none are given, nothing is printed.
    max_width : int, optional
        Maximum number of characters per printed line. Defaults to the
        terminal width reported by ``shutil.get_terminal_size``.
    """
    if not lines:
        # nothing to show, and ``max`` below would fail on an empty sequence
        return

    if max_width is None:
        import shutil
        max_width, _ = shutil.get_terminal_size()

    max_line_length = max(len(ln) for ln in lines)

    if max_line_length <= max_width:
        for ln in lines:
            print(ln)
        return

    # reserve room for ellipses and pad, but always keep a positive width
    max_width = max(max_width - 10, 1)
    n_blocks = (max_line_length - 1) // max_width + 1
    middle = len(lines) // 2

    dots = "..."
    # same width as the ellipsis so that every line of a block stays aligned
    pad = "   "
    separator = f"{dots:^{max_width}}"

    for i in range(n_blocks):
        window = slice(i * max_width, (i + 1) * max_width)

        if i == n_blocks - 1:
            # final block: nothing follows, so no continuation markers
            for ln in lines:
                print(pad, ln[window])
            continue

        # continuation markers are only shown on the middle line
        for j, ln in enumerate(lines):
            marker = dots if j == middle else pad
            print(marker, ln[window], marker)
        print(separator)
def save_to_disk(obj, fname, **dump_opts):
    """Write ``obj`` to the file ``fname`` with ``joblib.dump``, forwarding
    ``dump_opts`` and returning whatever ``joblib.dump`` returns.
    """
    from joblib import dump

    return dump(obj, fname, **dump_opts)
def load_from_disk(fname, **load_opts):
    """Read an object back from the file ``fname`` with ``joblib.load``,
    forwarding ``load_opts``.
    """
    # NOTE(review): joblib persistence is pickle based as far as we know, so
    # this should only be pointed at trusted files - confirm with joblib docs
    from joblib import load

    return load(fname, **load_opts)
class Verbosify:  # pragma: no cover
    """Callable wrapper that prints the arguments a function is called with
    before calling it. Simply for illustrating a MPI example in the docs.

    Parameters
    ----------
    fn : callable
        The function to wrap.
    highlight : str, optional
        If given, only the keyword argument with this name is printed.
    mpi : bool, optional
        If true, prefix each message with the ``mpi4py`` world rank.
    """

    def __init__(self, fn, highlight=None, mpi=False):
        self.fn, self.highlight, self.mpi = fn, highlight, mpi

    def __call__(self, *args, **kwargs):
        pre_msg = ""
        if self.mpi:
            from mpi4py import MPI
            pre_msg = f"{MPI.COMM_WORLD.Get_rank()}: "

        if self.highlight is not None:
            print(f"{pre_msg}{self.highlight}={kwargs[self.highlight]}")
        else:
            print(f"{pre_msg} args {args}, kwargs {kwargs}")

        return self.fn(*args, **kwargs)
class oset:
    """An ordered set which stores elements as the keys of dict (ordered as
    of python 3.6). 'A few times' slower than using a set directly for small
    sizes, but makes everything deterministic.

    The set operations taking ``*others`` expect each of them to be an
    ``oset``, and all accept zero arguments, in which case they leave the
    elements unchanged.

    Parameters
    ----------
    it : iterable, optional
        Initial elements, kept in order of first appearance.
    """

    __slots__ = ('_d',)

    def __init__(self, it=()):
        self._d = dict.fromkeys(it)

    @classmethod
    def _from_dict(cls, d):
        # wraps ``d`` directly, without copying and without calling __init__
        obj = object.__new__(oset)
        obj._d = d
        return obj

    @classmethod
    def from_dict(cls, d):
        """Public method makes sure to copy incoming dictionary.
        """
        return oset._from_dict(d.copy())

    @staticmethod
    def _shared_keys(others):
        """Container of the keys found in every one of ``others`` (which
        must be non-empty)."""
        if len(others) == 1:
            # a single dict already supports fast membership tests
            return others[0]._d
        return set.intersection(*(set(o._d) for o in others))

    @staticmethod
    def _combined_keys(others):
        """Container of the keys found in any one of ``others`` (which must
        be non-empty)."""
        if len(others) == 1:
            return others[0]._d
        return set.union(*(set(o._d) for o in others))

    def copy(self):
        return oset.from_dict(self._d)

    def add(self, k):
        self._d[k] = None

    def discard(self, k):
        self._d.pop(k, None)

    def remove(self, k):
        del self._d[k]

    def clear(self):
        self._d.clear()

    def update(self, *others):
        for o in others:
            self._d.update(o._d)

    def union(self, *others):
        u = self.copy()
        u.update(*others)
        return u

    def intersection_update(self, *others):
        if not others:
            # intersecting with nothing keeps every element
            return
        si = self._shared_keys(others)
        self._d = {k: None for k in self._d if k in si}

    def intersection(self, *others):
        if not others:
            return self.copy()
        si = self._shared_keys(others)
        return oset._from_dict({k: None for k in self._d if k in si})

    def difference_update(self, *others):
        if not others:
            # nothing to remove
            return
        su = self._combined_keys(others)
        self._d = {k: None for k in self._d if k not in su}

    def difference(self, *others):
        if not others:
            return self.copy()
        su = self._combined_keys(others)
        return oset._from_dict({k: None for k in self._d if k not in su})

    def popleft(self):
        """Remove and return the first (oldest) element."""
        k = next(iter(self._d))
        del self._d[k]
        return k

    def popright(self):
        """Remove and return the last (newest) element."""
        return self._d.popitem()[0]

    def __eq__(self, other):
        if isinstance(other, oset):
            return self._d == other._d
        return False

    def __or__(self, other):
        return self.union(other)

    def __ior__(self, other):
        self.update(other)
        return self

    def __and__(self, other):
        return self.intersection(other)

    def __iand__(self, other):
        self.intersection_update(other)
        return self

    def __sub__(self, other):
        return self.difference(other)

    def __isub__(self, other):
        self.difference_update(other)
        return self

    def __len__(self):
        return self._d.__len__()

    def __iter__(self):
        return self._d.__iter__()

    def __contains__(self, x):
        return self._d.__contains__(x)

    def __repr__(self):
        return f"oset({list(self._d)})"
class LRU(collections.OrderedDict):
    """An ordered dict holding at most ``maxsize`` items: both reading and
    writing a key move it to the end, and when an insertion makes the dict
    too long the item at the front is dropped. Taken from python
    collections OrderedDict docs.

    Parameters
    ----------
    maxsize : int
        Maximum number of items to keep.
    args, kwds
        Supplied to ``collections.OrderedDict``.
    """

    def __init__(self, maxsize, *args, **kwds):
        # must be set first, initial items are inserted via ``__setitem__``
        self.maxsize = maxsize
        super().__init__(*args, **kwds)

    def __getitem__(self, key):
        found = super().__getitem__(key)
        # a successful lookup marks the key as most recently used
        self.move_to_end(key)
        return found

    def __setitem__(self, key, value):
        already_stored = key in self
        if already_stored:
            self.move_to_end(key)
        super().__setitem__(key, value)

        if len(self) > self.maxsize:
            # evict whichever key is now at the front
            del self[next(iter(self))]
def gen_bipartitions(it):
    """Generate all unique bipartitions of ``it``. Unique meaning
    ``(1, 2), (3, 4)`` is considered the same as ``(3, 4), (1, 2)``.

    Parameters
    ----------
    it : sized iterable
        The items to split, ``len(it)`` must be supported.

    Yields
    ------
    left, right : list, list
        The two non-empty groups, ``left`` always holds the first item.
    """
    n = len(it)
    if not n:
        return

    # each mask assigns item ``pos`` to the right group when bit
    # ``n - 1 - pos`` is set; masks stay below ``2**(n - 1)`` so the first
    # item is never moved, which avoids generating mirrored duplicates
    for mask in range(1, 2**(n - 1)):
        left, right = [], []
        for pos, x in enumerate(it):
            if (mask >> (n - 1 - pos)) & 1:
                right.append(x)
            else:
                left.append(x)
        yield left, right