"""
This code is intended purely to generate the plots accompanying the post at https://dbarker.uk/posts/volatility-standard-error
and to serve as illustration of the concepts detailed therein.
It is provided "as is", and without warranty of any kind.
"""

# %%

import matplotlib.axes as ma
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as sts

from danpy.plotting import style, save_show

# Apply the shared plotting theme before any figure is created.
style()

# Fix NumPy's global seed: scipy's ``rvs`` falls back to the global
# RandomState when no ``random_state`` is given, so the figure is reproducible.
np.random.seed(5)

# Simulation size: T observations per path (rows), D independent paths (columns).
T, D = 1_000, 10_000


def calc_errs(xs: np.ndarray, ns: np.ndarray) -> np.ndarray:
    """Empirical standard error of the volatility estimate for each sampling period.

    For each sampling period ``n`` the rows of ``xs`` are split into
    ``len(xs) // n`` consecutive, non-overlapping blocks of ``n`` rows starting
    at row 0, and each block is summed. The ``ddof=1`` standard deviation of
    those block sums, divided by ``sqrt(n)``, gives one volatility estimate per
    column. The returned value is the ``ddof=1`` standard deviation of those
    estimates across columns.

    Parameters
    ----------
    xs : np.ndarray
        Increments with rows as time steps and columns as independent paths.
    ns : np.ndarray
        Sampling periods (positive integers).

    Returns
    -------
    np.ndarray
        Float array shaped like ``ns``; ``err[i]`` corresponds to ``ns[i]``.
        An entry is NaN when fewer than two blocks (or columns) are available.
    """
    # Leading zero row so that zs[j] - zs[i] == xs[i:j].sum(axis=0). Without it
    # the first observation is silently dropped, and when n divides len(xs) one
    # fewer block than len(xs) // n is used.
    zs = np.concatenate(
        [np.zeros_like(xs[:1], dtype=float), np.cumsum(xs, axis=0)], axis=0
    )

    err = np.zeros_like(ns, dtype=float)

    for i, n in enumerate(ns):
        # zs rows 0, n, 2n, ..., (len(xs)//n)*n; consecutive differences are
        # exactly the len(xs)//n block sums.
        block_sums = np.diff(zs[::n], axis=0)
        std = block_sums.std(ddof=1, axis=0) / np.sqrt(n)
        err[i] = std.std(ddof=1)

    return err


# Sampling periods n: number of consecutive observations aggregated per sample.
ns = np.array(range(1, 150), dtype=int)
# Number of complete aggregated samples of length n that fit in T observations.
ms = np.floor_divide(T, ns)


def do_plot(xs: np.ndarray, v: float, k: float, ax: ma.Axes | None = None):
    """Overlay empirical and theoretical standard errors of the volatility estimate.

    Prints the target vs. observed standard deviation and (Pearson) kurtosis of
    ``xs``. Then draws on ``ax``:

    - the theoretical standard-error curve over the module-level sampling periods ``ns``;
    - the empirical values from ``calc_errs``;
    - a horizontal line at the theoretical value for n = 1.

    Parameters
    ----------
    xs : np.ndarray
        Increments with rows as time steps and columns as independent paths.
    v : float
        Variance of a single increment.
    k : float
        Excess kurtosis of a single increment (0 for a Gaussian).
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes.
    """
    if ax is None:
        ax = plt.gca()

    n_obs = xs.shape[0]
    # Complete aggregated samples per period, derived from xs itself so the
    # theoretical curve always matches the data actually passed in.
    m_samples = n_obs // ns

    err = calc_errs(xs, ns)

    s = float(np.sqrt(v))
    kurt = float(k) + 3  # Pearson kurtosis, comparable to sts.kurtosis(fisher=False).
    flat = xs.ravel()  # View for contiguous xs; avoids copying the full array twice.
    # fmt: off
    print(f"Stddev.:  {s:.3f}  (observed: {np.std(flat, ddof=1):.3f})")
    print(f"Kurtosis: {kurt:.3f}  (observed: {sts.kurtosis(flat, fisher=False):.3f})")
    # fmt: on

    # Theoretical standard error: s * sqrt((k + 2n) / (4 m n)), k = excess kurtosis.
    theory = s * np.sqrt((k + 2 * ns) / (4 * m_samples * ns))
    # Same expression at n = 1 (m = n_obs). It must carry the factor s like the
    # curve does; previously it only agreed because the demo variances are 1.
    asymptote = s * np.sqrt((k + 2) / (4 * n_obs))

    # fmt: off
    ax.plot(ns, theory, color="C1", label="Theoretical")
    ax.scatter(ns, err, marker=".", edgecolor="none", color="C2", s=40, zorder=10, label="Empirical")
    ax.axhline(asymptote, label="Asymptote")
    # fmt: on


fig = plt.figure(figsize=(9, 4))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2, sharey=ax1)

# Normal, excess kurtosis = 0.
f = sts.norm(loc=0, scale=1)
v, k = f.stats(moments="vk")
xs = f.rvs(size=(T, D))
do_plot(xs, v, k, ax1)
# Title derived from the distribution's computed kurtosis rather than hard-coded.
ax1.set_title(rf"Gaussian ($\kappa_X = {float(k) + 3:g}$)")

# Student's T, excess kurtosis > 0 (finite only for nu > 4).
# Rescaled to unit variance so both panels share the same sigma.
nu = 5
f = sts.t(nu, scale=1 / np.sqrt(nu / (nu - 2)))
v, k = f.stats(moments="vk")
xs = f.rvs(size=(T, D))
do_plot(xs, v, k, ax2)
# Derive nu and kappa from the actual parameters so the title stays correct if nu changes.
ax2.set_title(rf"Student's T ($\nu = {nu}, \kappa_X = {float(k) + 3:g}$)")

ax2.tick_params("y", labelleft=False)

ax1.set_ylabel(r"Standard Error - $\varepsilon(S_X)$")
ax1.set_xlabel(r"Sampling Period - $n$")
ax2.set_xlabel(r"Sampling Period - $n$")

ax1.legend()
ax2.legend()

ax1.minorticks_on()
ax2.minorticks_on()
# The y-axis is shared, so setting it on ax1 applies to both panels. Set it
# before tight_layout so the layout accounts for the final tick labels.
ax1.set_ylim(bottom=0)
plt.tight_layout()

save_show("std_err_empirical.svg")

# %%
