# Source code for oqupy.helpers

# Copyright 2022 The TEMPO Collaboration
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Handy helper functions.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

from oqupy.correlations import BaseCorrelations
from oqupy.tempo import TempoParameters


def plot_correlations_with_parameters(
        correlations: BaseCorrelations,
        parameters: TempoParameters,
        ax: Axes = None) -> Axes:
    """Plot the correlation function on a grid that corresponds to some
    tempo parameters.

    The correlation function is evaluated at the tempo grid points
    ``k * parameters.dt`` for ``k = 0, ..., parameters.dkmax`` and drawn
    as markers (real part: diamonds, imaginary part: circles). For
    comparison, it also draws a solid line that has two more sampling
    points per interval. Depending on ``parameters.add_correlation_time``
    the plot additionally shows:

    * ``None`` or ``0.0``: nothing else.
    * infinite: a red dashed line at ``dkmax * dt`` and a shaded
      continuation of the correlation function beyond it.
    * finite: the additional correlation time after ``dkmax * dt``
      (shaded), marked by a black dashed line at its start and a black
      dotted line at its end.

    In every case the correlation function is also drawn as a solid line
    from the end of the sampled region up to 1.5 times that time.

    Parameters
    ----------
    correlations: BaseCorrelations
        The correlation object we are interested in.
    parameters: TempoParameters
        The tempo parameters that determine the grid.
    ax: Axes
        The matplotlib axes to draw into. If ``None``, a new figure and
        axes are created and the figure is shown at the end.

    Returns
    -------
    ax: Axes
        The axes that were drawn into.
    """
    if parameters.add_correlation_time is None:
        add_time = 0.0
        infinity = False
    elif parameters.add_correlation_time == np.inf:
        # NOTE: np.infty was removed in NumPy 2.0; np.inf works everywhere.
        add_time = 0.0
        infinity = True
    else:
        add_time = parameters.add_correlation_time
        infinity = False

    dt = parameters.dt
    dkmax = parameters.dkmax

    # Fine grid with spacing dt/3 covering [0, dkmax*dt]; every third point
    # (indices in `sample`) lies on the actual tempo grid.
    times_infl = dt/3.0 * np.arange((dkmax+1)*3 - 2)
    sample = [3*i for i in range(dkmax+1)]
    # Grid for the additional correlation time after dkmax*dt. It always
    # ends exactly at dkmax*dt + add_time, even if add_time is not a
    # multiple of dt.
    times_add = np.hstack((dt * np.arange(dkmax, dkmax+int(add_time/dt)),
                           np.array([dt * dkmax + add_time])))
    # Continuation from the end of the sampled region to 1.5 times that time.
    times_extra = np.linspace(times_add[-1], times_add[-1]*1.5, 10)

    corr = np.vectorize(correlations.correlation)
    corr_infl = corr(times_infl)
    corr_add = corr(times_add)
    corr_extra = corr(times_extra)

    show = False
    if ax is None:
        fig, ax = plt.subplots()
        show = True

    ax.set_xlabel(r"$\tau$")
    ax.set_ylabel(r"$C(\tau)$")

    ax.plot(
        times_infl, np.real(corr_infl),
        color="C0", linestyle="-", label="real")
    ax.scatter(
        times_infl[sample], np.real(corr_infl[sample]),
        marker="d", color="C0")
    ax.plot(
        times_infl, np.imag(corr_infl),
        color="C1", linestyle="-", label="imag")
    ax.scatter(
        times_infl[sample], np.imag(corr_infl[sample]),
        marker="o", color="C1")

    ax.plot(times_extra, np.real(corr_extra), color="C0", linestyle="-")
    ax.plot(times_extra, np.imag(corr_extra), color="C1", linestyle="-")

    if infinity:
        # Infinite additional correlation time: shade everything beyond
        # the memory cut-off.
        ax.axvline(times_add[0], color="r", linestyle="dashed")
        ax.fill_between(
            times_extra, np.real(corr_extra), 0.0, color="C0", alpha=0.30)
        ax.fill_between(
            times_extra, np.imag(corr_extra), 0.0, color="C1", alpha=0.30)
    elif add_time != 0.0:
        # Finite additional correlation time: draw and shade that interval.
        ax.plot(times_add, np.real(corr_add), color="C0", linestyle="-")
        ax.plot(times_add, np.imag(corr_add), color="C1", linestyle="-")
        ax.fill_between(
            times_add, np.real(corr_add), 0.0, color="C0", alpha=0.30)
        ax.fill_between(
            times_add, np.imag(corr_add), 0.0, color="C1", alpha=0.30)
        ax.axvline(times_add[0], color="k", linestyle="dashed")
        ax.axvline(times_add[-1], color="k", linestyle="dotted")

    ax.legend()

    if show:
        fig.show()

    return ax