import glob
import os

import numpy
from matplotlib import pyplot as plt
from obspy import read
import numpy as np
from obspy import Trace
from scipy import fftpack


def glitch_catcher(
        tr: Trace | numpy.ndarray,
        window_length: float,
        threshold: float,
        output: bool = False,
        *,
        sampling_rate: float | None = None,
) -> bool:
    """
    Catch glitches - returns True if a glitch is found.

    A glitch is any sample-to-sample jump whose absolute size exceeds
    ``threshold`` times the mean absolute jump over the trailing window of
    ``window_length`` seconds (the window includes the jump itself). A
    windowed average is used rather than a whole-trace average to try and
    avoid flagging earthquake signals.

    Parameters
    ---
        :param tr:
        Trace (or bare data array) to look for glitches in
        :param window_length:
        Window length to compute average differences over in seconds
        :param threshold:
        Threshold multiplier of average differences to declare a glitch
        :param output:
        If True, print the indices of the differences that exceed the
        threshold
        :param sampling_rate:
        Sampling rate in Hz used to convert ``window_length`` to samples.
        Required when ``tr`` is a numpy array; for a Trace it defaults to
        ``tr.stats.sampling_rate``.

    Returns
    ---
        True if at least one glitch is found, otherwise False. Data with no
        more differences than one window holds returns False.

    Raises
    ---
        ValueError: if ``tr`` is an array and no ``sampling_rate`` is given,
        or if the window is shorter than one sample.
    """
    if isinstance(tr, Trace):
        data = tr.data
        if sampling_rate is None:
            sampling_rate = tr.stats.sampling_rate
    else:
        data = tr
        if sampling_rate is None:
            raise ValueError(
                "sampling_rate is required when tr is a numpy array"
            )

    window_samples = int(window_length * sampling_rate)
    if window_samples < 1:
        raise ValueError(
            f"window_length={window_length} s is shorter than one sample "
            f"at {sampling_rate} Hz"
        )

    # Work in float64 so integer data (e.g. int32 miniSEED samples) cannot
    # overflow or wrap in the differences or in the running sum below.
    diffs = np.abs(np.diff(np.asarray(data, dtype=np.float64)))
    if len(diffs) <= window_samples:
        # Not enough differences to form even one full window; the original
        # arithmetic degenerated to an all-NaN baseline and found nothing.
        return False

    # Trailing moving average via a running sum: after padding, entry j
    # (for j >= window_samples) is mean(diffs[j - window_samples + 1 : j + 1]).
    running = np.cumsum(diffs)
    moving = (
        running[window_samples:] - running[:-window_samples]
    ) / window_samples
    # The first window_samples entries have no full window; fill them with
    # the mean of the moving average so every diff has a baseline.
    moving = np.concatenate([np.full(window_samples, moving.mean()), moving])

    glitches = diffs > moving * threshold
    if not glitches.any():
        return False
    if output:
        print(f"Found large differences at {np.where(glitches)}")
    return True


# Batch job: run glitch_catcher on the first trace of every miniSEED file in
# data/seeds and save a plot per trace to data/replots-glitches (red line and
# a "glitched" title when a glitch is found, blue otherwise).
# Paths are built with os.path.join so they work on Windows and POSIX alike.
lox = os.path.join("data", "replots-glitches")
os.makedirs(lox, exist_ok=True)  # savefig fails if the directory is missing
fil = sorted(glob.glob(os.path.join("data", "seeds", "*.mseed")))

# Detection parameters: window length in seconds and threshold multiplier.
win = 8
thr = 10

for n in fil:
    # Only the first trace in each file is checked.
    dta = read(n).traces[0]
    dta_nam = dta.stats.station
    print(dta_nam)

    gli = glitch_catcher(dta, win, thr)
    plt.plot(dta.data, lw=.5, color="red" if gli else "blue")
    plt.title(f"window: {win}, threshold: {thr}\n{'glitched' if gli else ''}")
    # NOTE(review): plots are named by station only, so several files (or
    # channels) from the same station overwrite each other's plot.
    plt.savefig(os.path.join(lox, f"{dta_nam}.png"))
    plt.close()






