# Source code for oqupy.util

# Copyright 2022 The TEMPO Collaboration
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for utilities.
"""

import sys
import copy as cp
from typing import Any, List, Optional, Text
from threading import Timer
from time import time
from datetime import timedelta

import numpy as np
from numpy import ndarray

from oqupy.config import PROGRESS_TYPE

# -- numpy utils --------------------------------------------------------------

def create_delta(
        tensor: ndarray,
        index_scrambling: List[int]) -> ndarray:
    """
    Creates deltas in numpy tensor.

    A zero-initialised tensor is created whose ``k``-th axis has the length
    of axis ``index_scrambling[k]`` of `tensor`. Every element
    ``tensor[a]`` is then written to position
    ``tuple(a[i] for i in index_scrambling)`` of that new tensor. Listing
    an axis of `tensor` more than once in `index_scrambling` therefore
    places the data on the corresponding diagonal of the result, while all
    other entries stay zero.

    .. warning::
        This is a computationally inefficient method to perform the task.

    .. todo::
        Make it better.

    Parameters
    ----------
    tensor: ndarray
        The input tensor.
    index_scrambling: List[int]
        For each axis of the result, the axis of `tensor` it is tied to.

    Returns
    -------
    ndarray
        The new tensor with the same dtype as `tensor`.
    """
    tensor_shape = tensor.shape
    ret_shape = tuple(tensor_shape[i] for i in index_scrambling)
    ret_ndarray = np.zeros(ret_shape, dtype=tensor.dtype)

    # np.ndindex visits the indices in C order (last axis fastest), so when
    # several input elements map onto the same output position the last one
    # in C order wins. A tensor with a zero-length axis simply produces no
    # iterations instead of an out-of-bounds access on index (0, ..., 0).
    for tensor_indices in np.ndindex(tensor_shape):
        ret_indices = tuple(tensor_indices[i] for i in index_scrambling)
        ret_ndarray[ret_indices] = tensor[tensor_indices]
    return ret_ndarray
def increase_list_of_index(
        a: List,
        shape: List,
        index: Optional[int] = -1) -> bool:
    """
    Advance a multi-index by one step, odometer style.

    The entry ``a[index]`` is incremented in place. Whenever an entry reaches
    the corresponding entry of `shape` it is reset to zero and the carry
    moves on to the entry one position to the left.

    Parameters
    ----------
    a: List
        The multi-index; it is modified in place.
    shape: List
        The exclusive upper bound for each entry of `a`.
    index: int
        The position at which the increment starts (default: last entry).

    Returns
    -------
    bool
        ``True`` if `a` now holds a valid next index, ``False`` if the carry
        ran past the first entry (that entry is then left at its overflowed
        value).
    """
    first_position = -len(shape)
    position = index
    while True:
        a[position] += 1
        if a[position] < shape[position]:
            return True
        if position == first_position:
            # Overflow of the leading entry: nothing left to carry into.
            return False
        a[position] = 0
        position -= 1
def add_singleton(
        tensor: ndarray,
        index: Optional[int] = -1,
        copy: Optional[bool] = True) -> ndarray:
    """
    Add an axis of length one to a numpy tensor.

    Parameters
    ----------
    tensor: ndarray
        The tensor to extend.
    index: int
        Position at which the ``1`` is inserted into the shape. It follows
        the semantics of ``list.insert``, so the default ``-1`` places the
        new axis *before* the last existing axis.
    copy: bool
        If ``True`` a copy of `tensor` is reshaped and returned; if ``False``
        the shape of `tensor` itself is changed in place and `tensor` is
        returned.

    Returns
    -------
    ndarray
        The tensor with the additional singleton axis.
    """
    result = cp.copy(tensor) if copy else tensor
    extended_shape = list(result.shape)
    # Slice assignment at [index:index] is the same as list.insert(index, 1).
    extended_shape[index:index] = [1]
    result.shape = tuple(extended_shape)
    return result
def is_diagonal_matrix(tensor: ndarray):
    """
    Check whether a square matrix has only zeros off its diagonal.

    Parameters
    ----------
    tensor: ndarray
        The matrix to test; it must be two-dimensional and square
        (checked with assertions).

    Returns
    -------
    bool
        ``True`` if every off-diagonal entry is zero.
    """
    assert len(tensor.shape) == 2
    n_rows, n_cols = tensor.shape
    assert n_rows == n_cols

    # Dropping the last element of the flattened matrix and viewing the rest
    # as (n - 1) rows of length (n + 1) puts one diagonal element at the start
    # of every row; everything after that first column is off-diagonal.
    flat_without_last = tensor.reshape(-1)[:-1]
    shifted = flat_without_last.reshape(n_rows - 1, n_cols + 1)
    off_diagonal = shifted[:, 1:]
    return ~np.any(off_diagonal)
# -- input parsing -----------------------------------------------------------
def check_convert(
        variable: Any,
        conv_type: Any,
        name: Optional[Text] = None,
        msg: Optional[Text] = None):
    """
    Attempt to convert variable into a specific type.

    Parameters
    ----------
    variable: Any
        The value to convert.
    conv_type: Any
        A callable (typically a type such as ``int``) that performs the
        conversion when called with `variable`.
    name: str
        Optional name of the variable, used in the error message.
    msg: str
        Optional additional text appended to the error message.

    Returns
    -------
    Any
        The result of ``conv_type(variable)``.

    Raises
    ------
    TypeError
        If the conversion raises any exception; the original exception is
        attached as the cause.
    """
    try:
        converted_variable = conv_type(variable)
    except Exception as e:  # conv_type is arbitrary, so any error may occur
        # The name is wrapped in backticks exactly once and omitted entirely
        # when it is not given.
        subject = f"Variable `{name}`" if name is not None else "Variable"
        # Not every callable has a __name__ (e.g. functools.partial objects).
        type_name = getattr(conv_type, "__name__", repr(conv_type))
        err_str = f"{subject} must be type `{type_name}`."
        if msg:
            err_str += f" {msg}"
        raise TypeError(err_str) from e
    return converted_variable
def check_true(
        expr: bool,
        msg: Text = None):
    """
    Check that a specific expression is true.

    Parameters
    ----------
    expr: bool
        The expression to check.
    msg: str
        Optional text for the error message.

    Raises
    ------
    ValueError
        If `expr` is falsy; the message is `msg`, or empty if no `msg` was
        given.
    """
    if expr:
        return
    raise ValueError("" if msg is None else msg)
def check_isinstance(
        variable: Any,
        types: Any,
        name: Text = None,
        msg: Text = None):
    """
    Check that a variable is an instance of one of the given types.

    Parameters
    ----------
    variable: Any
        The value to check.
    types: Any
        A single type or a tuple of accepted types.
    name: str
        Optional name of the variable, used in the error message.
    msg: str
        Optional additional text appended to the error message.

    Raises
    ------
    TypeError
        If `variable` is not an instance of any of the given types.
    """
    types_list = types if isinstance(types, tuple) else (types, )
    if isinstance(variable, types_list):
        return

    name_str = "" if name is None else f"`{name}`"
    msg_str = "" if msg is None else msg
    quoted_names = [f"`{accepted.__name__}`" for accepted in types_list]
    types_str = " or ".join(quoted_names)
    raise TypeError(f"Variable {name_str} is not of the type "
                    + f"{types_str}. {msg_str}")
# -- progress bar -------------------------------------------------------------
class BaseProgress:
    """
    Common interface of all progress displays.

    Subclasses implement `enter`, `exit` and `update`; the context manager
    protocol (``with ... as progress:``) is mapped onto `enter` and `exit`
    here.
    """

    def enter(self):
        """Start the progress display (to be implemented by subclasses)."""
        raise NotImplementedError()

    def exit(self):
        """Finish the progress display (to be implemented by subclasses)."""
        raise NotImplementedError()

    def update(self, step=None):
        """Report that `step` was reached (to be implemented by subclasses).
        """
        raise NotImplementedError()

    def __enter__(self):
        """Enter the context; the value bound by ``as`` is that of `enter`.
        """
        return self.enter()

    def __exit__(self, exception_type, exception_value, traceback):
        """Leave the context by calling `exit`; exceptions are not
        suppressed. """
        self.exit()
class ProgressSilent(BaseProgress):
    """
    Progress display that prints nothing.

    It only records the values it is given, so it can be used wherever a
    progress object is expected but no output is wanted.
    """

    def __init__(self, max_value, title = None):
        """Create a ProgressSilent object.

        Parameters
        ----------
        max_value:
            The total number of steps (stored, not displayed).
        title:
            An optional title (stored, not displayed).
        """
        self.step = None
        self.title = title
        self.max_value = max_value

    def enter(self):
        """Context enter; returns this object. """
        return self

    def exit(self):
        """Context exit; nothing to do. """

    def update(self, step=None):
        """Remember `step` as the current step without printing anything. """
        self.step = step
class ProgressSimple(BaseProgress):
    """
    Class to display the computation progress step by step.

    Every call to `update` prints one line with the current step, the total
    elapsed time since `enter`, and the time since the previous line.
    """

    def __init__(self, max_value, title = None):
        """Create a ProgressSimple object.

        Parameters
        ----------
        max_value:
            The total number of steps shown in each progress line.
        title:
            An optional title printed once by `enter`.
        """
        self.max_value = max_value
        self.title = title
        self.step = None
        # All output goes to this stream (standard output at creation time).
        self._file = sys.stdout
        self._start_time = None
        self._previouse_time = None

    def enter(self):
        """Context enter: print the title (if any) and start the clock. """
        if self.title is not None:
            print(self.title, file=self._file, flush=True)
        self._start_time = time()
        self._previouse_time = time()
        return self

    def exit(self):
        """Context exit: print the total elapsed time. """
        current_time = time()
        total_t = current_time - self._start_time
        print("Total elapsed time: {:9.1f}s".format(total_t),
              file=self._file, flush=True)

    def update(self, step=None):
        """Update the progress and print one progress line.

        Parameters
        ----------
        step:
            The step that was reached. If ``None`` the previously reported
            step is kept (``0`` is shown if no step was reported yet).
        """
        current_time = time()
        dt = current_time - self._previouse_time
        total_t = current_time - self._start_time
        if step is not None:
            self.step = step
        # The "{:5d}" format below cannot render None, so fall back to 0
        # as long as no step has been reported.
        shown_step = 0 if self.step is None else self.step
        print("Step {:5d} of {:5d}, total time: {:9.1f}s (+{:8.2f}s)".format(
            shown_step, self.max_value, total_t, dt),
              file=self._file, flush=True)
        self._previouse_time = current_time
# Number of characters ("#" and "-") used to draw the bar of `ProgressBar`.
PROGRESS_BAR_LENGTH = 40
class ProgressBar(BaseProgress):
    """
    Class to display the computation progress with a nice progress bar.

    The bar is redrawn on every call to `update`. While the display is
    active (between `enter` and `exit`) a background timer additionally
    redraws it after one second without an update, so that the elapsed time
    keeps advancing during long steps.
    """

    def __init__(self, max_value, title = None):
        """Create a ProgressBar object.

        Parameters
        ----------
        max_value:
            The step count that corresponds to 100%.
        title:
            An optional title printed once by `enter`.
        """
        self._timer = None
        self._start_time = time()
        self._file = sys.stdout
        self.max_value = max_value
        self.title = title
        self._length = PROGRESS_BAR_LENGTH
        self._step = None

    def _start_timer(self, callback):
        """Arm a one second timer that calls `callback` when it fires. """
        timer = Timer(1.0, callback)
        # A daemon thread cannot keep the interpreter alive if a timer is
        # still pending when the program finishes.
        timer.daemon = True
        self._timer = timer
        timer.start()

    def enter(self):
        """Context enter: print the title (if any) and start the clock. """
        if self.title is not None:
            print(self.title, file=self._file, flush=True)
        # Measure the elapsed time from here rather than from construction.
        self._start_time = time()
        self._start_timer(self._print_status)
        return self

    def _print_status(self):
        """Redraw the progress line in place (carriage return, no newline).
        """
        if self._step is None:
            step = 0
        else:
            step = self._step
        try:
            frac = float(step)/float(self.max_value)
        except ZeroDivisionError:
            frac = 1.0
        delta_t = time() - self._start_time
        time_string = "{:0>8}".format(str(timedelta(seconds=int(delta_t))))
        # Keep the drawn bar within its fixed width even if the reported
        # step lies outside of the range 0 ... max_value.
        done_int = min(max(int(frac*self._length), 0), self._length)
        bar_string = "\r{:5.1f}% {:4d} of {:4d} [{}{}] {}"
        bar_string = bar_string.format(frac*100,
                                       step,
                                       self.max_value,
                                       "#" * done_int,
                                       "-" * (self._length - done_int),
                                       time_string)
        self._file.write(bar_string)
        self._file.flush()

    def exit(self):
        """Context exit: stop the timer, draw the final bar and print the
        elapsed time. """
        timer = self._timer
        # Clearing the reference stops later `update` calls from re-arming.
        self._timer = None
        if timer is not None:
            timer.cancel()
        self._print_status()
        delta_t = time() - self._start_time
        print("\nElapsed time: {:.1f}s".format(delta_t),
              file=self._file, flush=True)

    def update(self, step=None):
        """Update the progress and redraw the bar.

        Parameters
        ----------
        step:
            The step that was reached. If ``None`` (as used by the timer)
            the previously reported step is kept.
        """
        timer = self._timer
        # Only restart the timer while the display is active; outside of
        # `enter` ... `exit` there is no timer and the bar is just redrawn.
        if timer is not None:
            timer.cancel()
            self._start_timer(self.update)
        if step is not None:
            self._step = step
        self._print_status()
# Maps each progress type name accepted by `get_progress` to its class.
PROGRESS_DICT = { "silent": ProgressSilent, "simple": ProgressSimple, "bar": ProgressBar, }
def get_progress(progress_type: Text = None) -> BaseProgress:
    """
    Look up the progress class that belongs to a progress type name.

    Parameters
    ----------
    progress_type: str
        One of the keys of `PROGRESS_DICT`. If ``None`` the configured
        default `PROGRESS_TYPE` is used.

    Returns
    -------
    The progress class (not an instance) registered under that name.
    """
    selected = PROGRESS_TYPE if progress_type is None else progress_type
    assert selected in PROGRESS_DICT, \
        "Unknown progress_type='{}', know are {}".format(
            selected, PROGRESS_DICT.keys())
    return PROGRESS_DICT[selected]