Source code for scatcluster.processing.scattering

import os
import pickle

import obspy

import cupy as cp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib import dates as mdates
from obspy.clients.filesystem.sds import Client
from obspy.core import UTCDateTime
from obspy.core.stream import Stream
from scatseisnet.network import ScatteringNetwork
from scatseisnet.operation import segmentize
from scipy import stats as sp_stats
from tqdm import tqdm


[docs]class Scattering:
def reduce_type(self):
    """Return the pooling callable registered under ``self.network_pooling``.

    The pooling operation is performed on the last axis by the code that receives the callable.

    Returns:
        The NumPy / SciPy reduction function for the configured name, or ``None`` when the name is unknown.
    """
    # NOTE(review): 'pmean' resolves to the harmonic mean, exactly like 'hmean' - confirm this is intended.
    reducers = dict(
        avg=np.mean,
        max=np.max,
        median=np.median,
        std=np.std,
        gmean=sp_stats.gmean,
        hmean=sp_stats.hmean,
        pmean=sp_stats.hmean,
        kurtosis=sp_stats.kurtosis,
        skew=sp_stats.skew,
        entropy=sp_stats.entropy,
        sem=sp_stats.sem,
        differential_entropy=sp_stats.differential_entropy,
        median_abs_deviation=sp_stats.median_abs_deviation,
    )
    return reducers.get(self.network_pooling)
def load_data_times(self):
    """Read the saved time vector into ``self.data_times``.

    The file is expected at
    ``{data_savepath}data/{data_network}_{data_station}_{data_location}_{network_name}_times.npy``.
    Failures are reported with ``print`` and never raised; in that case ``self.data_times`` is left untouched.
    """
    try:
        times_file = (f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
                      f'{self.network_name}_times.npy')
        self.data_times = np.load(times_file)
    except FileNotFoundError:
        print(f'File not found: {times_file}')
    except Exception as error:
        print(f'An error occurred while loading data times: {error}')
def build_day_list(self) -> None:
    """Fill ``self.data_day_list`` with the days to process.

    Every calendar day from ``data_starttime`` to ``data_endtime`` (both included) is listed as a
    ``'YYYY-MM-DD'`` string, except the days named in ``data_exclude_days``. Any error is printed instead of
    raised, and ``self.data_day_list`` is then not assigned.
    """
    try:
        first = UTCDateTime(self.data_starttime)
        last = UTCDateTime(self.data_endtime)
        excluded = [UTCDateTime(day).strftime('%Y-%m-%d') for day in self.data_exclude_days]
        every_day = pd.date_range(first.strftime('%Y%m%d'), last.strftime('%Y%m%d')).strftime('%Y-%m-%d').tolist()
        self.data_day_list = [day for day in every_day if day not in excluded]
    except Exception as error:
        print(f'An error occurred while building day list: {error}')
def build_channel_list(self) -> None:
    """Store the channel code of every sample trace in ``self.channel_list``.

    When no sample stream has been loaded yet (``self.sample_stream`` is ``None``), the sample data is
    processed first, without plotting.
    """
    if self.sample_stream is None:
        self.process_sample_data(plot_spectra=False)
    channels = []
    for trace in self.sample_stream:
        channels.append(trace.stats.channel)
    self.channel_list = channels
def stream_process(self, stream: Stream, highpass_freq: float = 0.5, taper_fraction: float = 0.05) -> Stream:
    """Pre-process an ObsPy stream before the scattering coefficients are calculated.

    The stream is modified in place by four successive calls: ``detrend('demean')``, a high-pass filter,
    ``detrend('constant')`` and a taper.

    Args:
        stream (Stream): ObsPy stream to process.
        highpass_freq (float): Value passed as ``freq`` to the high-pass filter. Defaults to 0.5, the
            value that used to be hard-coded.
        taper_fraction (float): Value passed to ``stream.taper``. Defaults to 0.05, the value that used
            to be hard-coded.

    Returns:
        Stream: The same stream object, processed.
    """
    # Remove the mean before filtering
    stream.detrend(type='demean')
    # High-pass filter
    stream.filter(type='highpass', freq=highpass_freq)
    # Remove whatever offset is left after filtering
    stream.detrend(type='constant')
    # Taper the ends to limit edge effects
    stream.taper(taper_fraction)
    return stream
def load_data(self, starttime: UTCDateTime, endtime: UTCDateTime, channel: str) -> Stream:
    """Load the waveforms between two times, pre-process them and pad them to the requested window.

    The source is chosen from ``self.data_client_path``: a path containing ``'local:'`` is read with
    ``obspy.read`` (after removing that marker), a path containing ``'sds.chris'`` is read through the SDS
    ``Client``. Any other path is rejected.

    Args:
        starttime (UTCDateTime): Start of the window.
        endtime (UTCDateTime): End of the window.
        channel (str): Channel selector passed to the SDS client.

    Returns:
        Stream: The processed stream, merged and zero-padded to the window. An empty ``Stream`` is returned
        (and a message printed) when anything goes wrong.
    """
    # NOTE(review): `channel` is not used for 'local:' sources, and the SDS branch only triggers on the
    # literal substring 'sds.chris' - confirm both are intended.
    try:
        source = self.data_client_path
        if 'local:' in source:
            waveforms = obspy.read(source.replace('local:', ''))
            waveforms.trim(starttime, endtime)
        elif 'sds.chris' in source:
            sds_client = Client(source)
            waveforms = sds_client.get_waveforms(network=self.data_network,
                                                 station=self.data_station,
                                                 location=self.data_location,
                                                 channel=channel,
                                                 starttime=starttime,
                                                 endtime=endtime)
        else:
            raise ValueError('Unknown data client path')

        waveforms = self.stream_process(waveforms)
        # Merge overlapping/gappy traces, then pad with zeros up to the exact window
        waveforms.merge(method=1, fill_value=0)
        waveforms.trim(starttime, endtime, pad=True, fill_value=0)
        return waveforms
    except Exception as error:
        print(f'>> Skipping {starttime}-{endtime} as there was an error in loading data from SDS Client due to {error}')
        return Stream()
def network_build_scatcluster(self) -> None:
    """Build the scattering network, keep it on ``self.net`` and pickle it.

    ``network_segment`` and ``network_step`` are multiplied by ``network_sampling_rate`` and truncated to
    integers to obtain ``network_samples_per_segment`` and ``network_samples_per_step``. The network is
    written to ``{data_savepath}networks/{data_network}_{data_station}_{data_location}_{network_name}.pickle``.
    """
    samples_per_segment = int(self.network_segment * self.network_sampling_rate)
    samples_per_step = int(self.network_step * self.network_sampling_rate)
    self.network_samples_per_segment = samples_per_segment
    self.network_samples_per_step = samples_per_step

    self.net = ScatteringNetwork(*self.network_banks,
                                 bins=self.network_samples_per_segment,
                                 sampling_rate=self.network_sampling_rate)

    # Persist the network so it can be reloaded later
    pickle_path = (f'{self.data_savepath}networks/{self.data_network}_{self.data_station}_{self.data_location}_'
                   f'{self.network_name}.pickle')
    with open(pickle_path, 'wb') as handle:
        pickle.dump(self.net, handle, protocol=pickle.HIGHEST_PROTOCOL)
def plot_network_filter_banks(self, savefig: bool = True, **kwargs) -> None:
    """Plot the wavelets of the network filter banks in the time and frequency domains.

    One row of axes is drawn per bank: the wavelets over time on the left, their spectra on the right,
    each curve shifted vertically by its entry in ``bank.ratios``.

    Args:
        savefig (bool): When True, save the figure to
            ``{data_savepath}figures/{data_network}_{data_station}_{data_location}_{network_name}_filter_banks.png``.
        **kwargs: Forwarded to ``plt.subplots``. ``figsize`` defaults to ``(10, 10)``.

    Note:
        The height ratios and the legend placement address the second bank/row explicitly, so a network
        with exactly two banks is expected.
    """

    def _to_host(array):
        # CuPy arrays live on the GPU; matplotlib needs NumPy arrays.
        return array.get() if isinstance(array, cp.ndarray) else array

    banks = self.net.banks

    # Create axes
    octaves = [bank.octaves for bank in banks]
    height_ratios = 2, (octaves[1] + 2) / (octaves[0] + 2)
    grid = {'height_ratios': height_ratios, 'wspace': 0.1, 'hspace': 0.1}
    if kwargs.get('figsize') is None:
        kwargs['figsize'] = (10, 10)
    _, axes = plt.subplots(len(banks), 2, gridspec_kw=grid, sharey='row', **kwargs)

    # Loop over network layers
    for ax_enum, (ax, bank) in enumerate(zip(axes, banks)):
        # Work on host-side copies and leave the bank itself untouched. The previous code wrote the
        # converted arrays back onto the bank and stored the frequency axis in `bank.centers`.
        wavelets = _to_host(bank.wavelets)
        spectra = _to_host(bank.spectra)
        widths = _to_host(bank.widths)
        centers = _to_host(bank.centers)
        frequencies = _to_host(bank.frequencies)
        times = _to_host(bank.times)
        ratios = _to_host(bank.ratios)

        # Limit the view to twice the temporal width of the largest wavelet
        width_max = min(2 * widths.max(), times.max())

        # Temporal
        in_view = np.abs(times) < width_max
        t = times[in_view]
        for octave_enum, (wavelet, octave, width) in enumerate(zip(wavelets, ratios, widths)):
            y = wavelet[in_view] / np.abs(wavelet[in_view].max()) / 3
            ax[0].plot(t, y.real + octave, color='C0', zorder=1)
            ax[0].plot(t, y.imag + octave, color='C0', zorder=0, alpha=0.4)
            # Alternate the label position so neighbouring labels do not overlap
            ax[0].text(-1 * width_max * (1 if octave_enum % 2 == 0 else 1.1),
                       octave,
                       f'{width*4:.2f}',
                       fontsize='small')

        # Spectral
        above_first = frequencies > frequencies[1]
        f = frequencies[above_first]
        for octave_enum, (spectrum, octave, center) in enumerate(zip(spectra, ratios, centers)):
            y = spectrum[above_first]
            y = y / np.abs(y.max())
            ax[1].plot(f, np.abs(y) + octave, color='C0')
            ax[1].text((10**-2) * (1 if octave_enum % 2 == 0 else 1.1), octave, f'{center:.2f}', fontsize='small')

        # Labels
        ax[0].grid(axis='x')
        ax[0].set_xlabel('Time (seconds)')
        ax[0].set_ylabel(f'Order {ax_enum+1}\nOctaves (base 2 log)')
        ax[1].grid(axis='x')
        ax[1].set_xlabel('Frequency (Hz)')
        ax[1].set_xscale('log')
        ax[0].text(-1 * width_max, 0.2, 'Temporal\nWidth (s)', fontsize='small')
        ax[1].text(10**-2, 0.2, 'Centre\nFreq. (Hz)', fontsize='small')

        # Axes
        ax[1].set_ylim(-bank.octaves - 1, 1)
        ax[0].set_yticks(-np.arange(bank.octaves + 1))
        ax[0].set_yticklabels(np.arange(bank.octaves + 1))
        ax[1].tick_params(axis='y', left=False, labelleft=False)

    # Legend
    axes[1][0].legend([r'Re $\varphi(t)$', r'Im $\varphi(t)$'], loc=1)
    axes[1][1].legend([r'$\hat\varphi(\omega)$'], loc=1)
    plt.suptitle('ScatCluster Parametrization'
                 f'\nSegment:{self.network_segment}s Step: {self.network_step}\n Banks: {self.network_banks_name}')
    plt.subplots_adjust(top=0.9)
    if savefig:
        plt.savefig(f'{self.data_savepath}figures/{self.data_network}_{self.data_station}_{self.data_location}_'
                    f'{self.network_name}_filter_banks.png')
def load_sample_data(self) -> Stream:
    """Load the waveforms of the sample window.

    Returns:
        Stream: Result of ``self.load_data`` for ``data_sample_starttime`` .. ``data_sample_endtime`` on
        ``data_channel``.
    """
    sample_start = UTCDateTime(self.data_sample_starttime)
    sample_end = UTCDateTime(self.data_sample_endtime)
    return self.load_data(starttime=sample_start, endtime=sample_end, channel=self.data_channel)
def plot_sample_spectra(self) -> None:
    """Plot each sample trace above its first-order scattering coefficients.

    One column of axes is drawn per entry of ``self.channel_list``: the raw trace on top and, below it, a
    ``pcolormesh`` of the base-10 logarithm of the first-order coefficients against the centre frequencies
    of the first bank. The figure is saved to
    ``{data_savepath}figures/{data_network}_{data_station}_{data_location}_{network_name}_sample_transform.png``
    and then shown.
    """
    frequencies = self.net.banks[0].centers
    timestamps = pd.to_datetime(self.sample_times, unit='D')
    timestamps_scats = pd.to_datetime(self.sample_times_scatterings, unit='D')

    # squeeze=False keeps `ax` two-dimensional, so the [row, column] indexing below also works when there
    # is a single channel.
    _, ax = plt.subplots(2, len(self.channel_list), sharex=True, sharey='row', figsize=(20, 5), squeeze=False)
    for channel_num, channel in enumerate(self.channel_list):
        first_order = self.sample_scattering_coefficients[0][:, channel_num, :].squeeze().T
        first_order = np.real(np.log10(first_order))
        ax[0, channel_num].plot(timestamps, self.sample_data[channel_num], rasterized=True)
        ax[0, channel_num].set_title(channel)
        ax[1, channel_num].pcolormesh(timestamps_scats, frequencies, first_order, rasterized=True)
        ax[1, channel_num].set_yscale('log')
        ax[1, channel_num].tick_params('x', labelrotation=90)

    ax[0, 0].set_ylabel('Sample Trace')
    ax[1, 0].set_ylabel('First Order Scat. Coefficients\nFrequency (Hz)')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.suptitle('Sample Trace ScatCluster Transform')
    plt.savefig(f'{self.data_savepath}figures/{self.data_network}_{self.data_station}_{self.data_location}_'
                f'{self.network_name}_sample_transform.png')
    plt.show()
def process_sample_data(self, plot_spectra: bool = True) -> None:
    """Run the full scattering pipeline on the sample window.

    Steps: (1) load and pre-process the sample stream, (2) extract ``sample_times``, ``sample_data`` and
    ``channel_list`` from it, (3) segmentize data and times (the first time of each segment becomes
    ``sample_times_scatterings``), (4) transform the segments into ``sample_scattering_coefficients``,
    (5) optionally plot the result.

    Args:
        plot_spectra (bool): When True, call ``plot_sample_spectra`` at the end.
    """
    stream = self.load_sample_data()
    self.sample_stream = stream
    self.sample_times = stream[0].times('matplotlib')
    self.sample_data = np.array([trace.data for trace in stream])
    self.channel_list = [trace.stats.channel for trace in stream]

    self.sample_data_segments = segmentize(self.sample_data, self.network_samples_per_segment,
                                           self.network_samples_per_step)
    time_segments = segmentize(self.sample_times, self.network_samples_per_segment, self.network_samples_per_step)
    self.sample_times_scatterings = time_segments[:, 0]

    pooling = self.reduce_type()
    self.sample_scattering_coefficients = self.net.transform(self.sample_data_segments, pooling)

    if plot_spectra:
        self.plot_sample_spectra()
def plot_seismic(self, sample: bool = False):
    """Plot the seismic traces, one subplot row per channel.

    The data is loaded on first use: when the relevant array (``sample_data`` or ``data_all``) is ``None``
    the stream is read and the times, data and channel list are cached on the instance.

    Args:
        sample (bool): If True, plot the sample window. Otherwise plot the full
            ``data_starttime`` .. ``data_endtime`` range.
    """
    if sample:
        if self.sample_data is None:
            self.sample_stream = self.load_sample_data()
            self.sample_times = self.sample_stream[0].times('matplotlib')
            self.sample_data = np.array([trace.data for trace in self.sample_stream])
            self.channel_list = [trace.stats.channel for trace in self.sample_stream]
        times = self.sample_times
        data = self.sample_data
    else:
        if self.data_all is None:  # pylint: disable=access-member-before-definition
            self.data_stream = self.load_data(starttime=UTCDateTime(self.data_starttime),
                                              endtime=UTCDateTime(self.data_endtime),
                                              channel=self.data_channel)
            self.data_times = self.data_stream[0].times('matplotlib')
            self.data_all = np.array([trace.data for trace in self.data_stream])
            self.channel_list = [trace.stats.channel for trace in self.data_stream]
        times = self.data_times
        data = self.data_all
    channel_list = self.channel_list

    # One row per channel (previously fixed at three rows, which broke for more than three channels).
    # squeeze=False guarantees a 2-D axes array even for a single channel.
    _, axes_grid = plt.subplots(len(channel_list), 1, figsize=(20, 10), sharex=True, sharey=True, squeeze=False)
    axes = axes_grid[:, 0]
    for channel_enum, channel in enumerate(channel_list):
        axes[channel_enum].plot(times, data[channel_enum, :])
        axes[channel_enum].set_ylabel(f'{channel}')

    # The x axis is shared, so formatting the first axis is enough
    dateticks = mdates.AutoDateLocator()
    datelabels = mdates.ConciseDateFormatter(dateticks)
    axes[0].xaxis.set_major_locator(dateticks)
    axes[0].xaxis.set_major_formatter(datelabels)
    axes[0].set_xlim(times.min(), times.max())
def process_scatcluster_yyyy_mm_dd(self, day_start: str, day_end: str) -> None:
    """Compute and store the scattering coefficients of a single day.

    Nothing is computed when the output file already exists, when ``day_start`` is not part of
    ``self.data_day_list``, or when fewer than three traces could be loaded; each case prints a message.
    Otherwise both coefficient orders and the segment start times are written to
    ``{data_savepath}scatterings/{data_network}_{data_station}_{data_location}_{network_name}_scatterings_{day_start}.npz``.

    Args:
        day_start (str): Start day of format "YYYY-MM-DD"
        day_end (str): End day of format "YYYY-MM-DD"
    """
    print(f'Processing {day_start} - {day_end}')
    scatterings_path = (f'{self.data_savepath}scatterings/{self.data_network}_{self.data_station}_'
                        f'{self.data_location}_{self.network_name}_scatterings_{day_start}.npz')

    if os.path.exists(scatterings_path):
        print('> Scatterings already exist')
        return

    # Days missing from data_day_list were excluded by configuration
    if day_start not in self.data_day_list:
        print(f'> Processing of {day_start} has been excluded as it is part of `data_exclude_days` parameter.')
        return

    stream = self.load_data(starttime=UTCDateTime(day_start), endtime=UTCDateTime(day_end), channel=self.data_channel)
    n_traces = len(stream.traces)
    if n_traces == 0:
        print(f'>> Skipping {day_start} as there is no traces')
        return
    if n_traces < 3:
        print(f'>> Skipping {day_start} as there is not all 3 channels')
        return

    # Numpyification
    times = stream[0].times('matplotlib')
    data = np.array([trace.data for trace in stream])

    # Segmentization; the first time stamp of each segment labels that segment
    data_segments = segmentize(data, self.network_samples_per_segment, self.network_samples_per_step)
    times_scat = segmentize(times, self.network_samples_per_segment, self.network_samples_per_step)[:, 0]

    # Scattering transform
    scattering_coefficients = self.net.transform(data_segments, self.reduce_type())

    # Save scattering coefficients in NPZ file
    np.savez(scatterings_path,
             scat_coef_0=scattering_coefficients[0],
             scat_coef_1=scattering_coefficients[1],
             times=times_scat)

    # Print stats
    print(f'>>> min {data.min()} : max {data.max()} : mean {data.mean()}')
def process_scatcluster_for_range(self) -> None:
    """Run ``process_scatcluster_yyyy_mm_dd`` for every day between ``data_starttime`` and ``data_endtime``.

    The day list is rebuilt first. Each call receives a day and the following day as ``(day_start, day_end)``,
    so the last processed window ends on the calendar day of ``data_endtime``.
    """
    self.build_day_list()

    # Announce the exclusions only when there are any. The check used to look at the list of days to
    # process, so the message was printed even when nothing was excluded.
    if len(self.data_exclude_days) > 0:
        print(f'The following days will be excluded from the analysis: {self.data_exclude_days}')

    one_day = 60 * 60 * 24
    start = UTCDateTime(self.data_starttime)
    end = UTCDateTime(self.data_endtime)
    day_starts = pd.date_range(start.strftime('%Y%m%d'),
                               (end - one_day).strftime('%Y%m%d')).strftime('%Y-%m-%d').tolist()
    day_ends = pd.date_range((start + one_day).strftime('%Y%m%d'),
                             end.strftime('%Y%m%d')).strftime('%Y-%m-%d').tolist()

    for day_start, day_end in zip(day_starts, day_ends):
        self.process_scatcluster_yyyy_mm_dd(day_start, day_end)
def filters_per_layer(self, model):
    """Return how many filters (centre frequencies) each bank of ``model`` holds.

    Args:
        model: Object exposing ``banks``, each bank exposing a sized ``centers`` attribute.

    Returns:
        list: One count per bank, in bank order.
    """
    return [len(bank.centers) for bank in model.banks]
def layer_shape(self, model, order):
    """Return the filter counts of the first ``order + 1`` layers of ``model``.

    Args:
        model: Network passed on to ``filters_per_layer``.
        order (int): Zero-based index of the last layer to include.

    Returns:
        list: The leading ``order + 1`` entries of ``filters_per_layer(model)``.
    """
    counts = self.filters_per_layer(model)
    last = order + 1
    return counts[:last]
def log(self, dataset, waterlevel=1e-10):
    """Take the base-10 logarithm of the scattering coefficients.

    Time steps are kept only when every first-order coefficient (over all channels and ``f1``) exceeds the
    waterlevel. The waterlevel is then added to both orders before the logarithm is taken.

    Parameters
    ----------
    dataset : xarray.Dataset
        The scattering coefficients in the xarray.Dataset format.
    waterlevel : float
        Threshold used for the selection and offset added before the logarithm.

    Returns
    -------
    xarray.Dataset
        The selected time steps with log-scaled ``order_1`` and ``order_2``.
    """
    # Time steps where order 1 is above the waterlevel for all channels and frequencies
    above_waterlevel = dataset.order_1 > waterlevel
    keep = above_waterlevel.all(dim=['channel', 'f1'])
    dataset = dataset.sel(time=keep)

    # Log of both orders
    for name in ('order_1', 'order_2'):
        dataset[name].values = np.log10(dataset[name].values + waterlevel)
    return dataset
def nyquist_mask(self, dataset):
    """Mask the second-order coefficients where ``f2`` exceeds ``f1``.

    Second-order values are kept where ``f1 >= f2`` and replaced by NaN elsewhere. Afterwards
    ``dropna(dim='time', how='all')`` is applied along the time dimension.

    Parameters
    ----------
    dataset : xarray.Dataset
        The scattering coefficients in the xarray.Dataset format.

    Returns
    -------
    xarray.Dataset
        The masked scattering coefficients.
    """
    # Valid (f1, f2) pairs: the second-order frequency must not exceed the first-order one
    valid_pairs = dataset.f1 >= dataset.f2
    dataset.order_2.data = dataset.order_2.where(valid_pairs, np.nan)

    # Drop NaN values
    return dataset.dropna(dim='time', how='all')
def normalize(self, dataset):
    """Standardize each order of the scattering coefficients in place.

    For each order, the mean taken over all of its dimensions is subtracted and the result is divided by
    the standard deviation taken over the same dimensions.

    Parameters
    ----------
    dataset : xarray.Dataset
        The scattering coefficients in the xarray.Dataset format.

    Returns
    -------
    xarray.Dataset
        The same dataset, normalized.
    """
    # Dimensions each order is reduced over
    working_dims = {
        'order_1': ['time', 'f1', 'channel'],
        'order_2': ['time', 'f1', 'f2', 'channel'],
    }
    for name, dims in working_dims.items():
        dataset[name].data -= dataset[name].mean(dim=dims).data
        # The standard deviation is evaluated after the mean has been removed
        dataset[name].data /= dataset[name].std(dim=dims).data
    return dataset
def min_max_scaling(self, dataset):
    """Min-max scale each order of the scattering coefficients in place.

    For each order, the minimum taken over all of its dimensions is subtracted, then the result is divided
    by the (max - min) range of the already shifted data.

    Parameters
    ----------
    dataset : xarray.Dataset
        The scattering coefficients in the xarray.Dataset format.

    Returns
    -------
    xarray.Dataset
        The same dataset, scaled.
    """
    # Dimensions each order is reduced over
    working_dims = {
        'order_1': ['time', 'f1', 'channel'],
        'order_2': ['time', 'f1', 'f2', 'channel'],
    }
    for name, dims in working_dims.items():
        dataset[name].data -= dataset[name].min(dim=dims).data
        # Range of the shifted data (its minimum is evaluated again, as before)
        dataset[name].data /= (dataset[name].max(dim=dims).data - dataset[name].min(dim=dims).data)
    return dataset
def process_vectorized_scattering_coefficients(self) -> 'xr.Dataset':
    """Assemble the daily scattering files into one dataset, post-process it and vectorize it.

    The per-day ``.npz`` files of ``self.data_day_list`` are loaded (missing files are reported and
    skipped) and stacked into an ``xarray.Dataset`` with one variable per order (``order_1``, ``order_2``).
    The dataset is then passed through ``nyquist_mask``, ``normalize``, ``log`` and ``min_max_scaling``,
    flattened with ``vectorize_scattering_coefficients_xarray``, and stored:

    * ``self.data_times`` and ``self.data_scat_coef_vectorized`` on the instance,
    * ``..._times.npy``, ``..._scat_coef_vectorized.npy`` and ``..._scat_coef_xarray.nc`` under
      ``{data_savepath}data/``.

    Returns:
        xr.Dataset: The post-processed coefficients.

    Raises:
        ValueError: If none of the expected scattering files could be loaded.
    """
    file_list = [(f'{self.data_savepath}scatterings/{self.data_network}_{self.data_station}_{self.data_location}_'
                  f'{self.network_name}_scatterings_{day_start}.npz') for day_start in self.data_day_list]

    # LOAD DATA
    times_per_day = []
    order_1_per_day = []
    order_2_per_day = []
    for file in file_list:
        try:
            # The context manager closes the archive once the arrays have been read
            with np.load(file) as scat_file:
                times_per_day.append(scat_file['times'])
                order_1_per_day.append(scat_file['scat_coef_0'])
                order_2_per_day.append(scat_file['scat_coef_1'])
        except FileNotFoundError:
            print(f'{file} is missing. This has been skipped.')

    if not times_per_day:
        raise ValueError('None of the expected scattering files could be loaded; '
                         'run the scattering transform before vectorizing.')

    times = np.hstack(times_per_day)
    del times_per_day
    scat_coef_0 = np.vstack(order_1_per_day)
    del order_1_per_day
    scat_coef_1 = np.vstack(order_2_per_day)
    del order_2_per_day

    n_samples = len(times)
    if self.channel_list is None:
        self.build_channel_list()
    n_channels = len(self.channel_list)
    attributes = {k: str(v) for k, v in self.__dict__.items()}

    # The coordinates of the xarray dataset are the center frequencies of the
    # scattering network, the starttime of the waveforms, and the channel names.
    center_frequencies = [bank.centers for bank in self.net.banks]
    coordinates = {
        'time': ('time', times),
        'channel': ('channel', self.channel_list),
        **{
            f'f{i + 1}': (f'f{i + 1}', centers) for i, centers in enumerate(center_frequencies)
        },
    }

    # Fill the data variables: one array of coefficient magnitudes per order. Only the two orders stored
    # in the daily files are available.
    stacked_orders = (scat_coef_0, scat_coef_1)
    variables = {}
    for order, stacked in enumerate(stacked_orders[:len(self.net)]):
        # Variable dimensions
        dimension = (
            'time',
            'channel',
            *[f'f{j + 1}' for j in range(order + 1)],
        )
        # The stacked array is already laid out as (time, channel, f1[, f2]), so it is copied in one step.
        # The earlier per-sample loop indexed it with `[order][channel]`, treating the channel axis as an
        # order axis and writing the wrong values. The assignment raises if the shapes disagree.
        variable = np.zeros((n_samples, n_channels, *self.layer_shape(self.net, order)))
        variable[...] = np.abs(stacked)
        # Assign scattering matrix to data variable
        variables[f'order_{order + 1}'] = (dimension, variable)

    # Assign attributes and data variables to dataset
    coefficients = xr.Dataset(
        coords=coordinates,
        data_vars=variables,
        attrs=attributes,
    )

    # Drop the time steps whose first-order coefficients sum to zero
    coefficients = coefficients.where(
        coefficients.order_1.sum(dim=('f1', 'channel')) > 0,
        drop=True,
    )
    coefficients = self.nyquist_mask(coefficients)
    # NOTE(review): normalize() centres the data around zero before log() selects strictly positive
    # first-order values - confirm this ordering is intended.
    coefficients = self.normalize(coefficients)
    coefficients = self.log(coefficients, waterlevel=1e-5)
    coefficients = self.min_max_scaling(coefficients)
    print(coefficients)

    self.data_times = coefficients.time.values
    self.data_scat_coef_vectorized = self.vectorize_scattering_coefficients_xarray(coefficients)

    # Display statistics from the vectorization
    print(f'Number of valid time windows of size {self.network_segment}s: {int(self.data_times.shape[0])}')
    print(f'Number of days investigated: {int((self.network_segment * self.data_times.shape[0])/86400)}')
    print(f'Number of Scat Coefficients: {int(self.data_scat_coef_vectorized.shape[1])}')
    print(f'Vectorized Scat Coefficients: {self.data_scat_coef_vectorized.shape}')

    # Store Data
    np.save(
        f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
        f'{self.network_name}_times.npy', self.data_times)
    np.save(
        f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
        f'{self.network_name}_scat_coef_vectorized.npy', self.data_scat_coef_vectorized)
    coefficients.to_netcdf(f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
                           f'{self.network_name}_scat_coef_xarray.nc')

    return coefficients
def vectorize_scattering_coefficients_xarray(self, coefficients):
    """Flatten both coefficient orders into one feature row per time step.

    Args:
        coefficients: Dataset-like object exposing ``time``, ``order_1.data`` and ``order_2.data``, with
            time as the leading axis of both arrays.

    Returns:
        numpy.ndarray: Array with one row per time step, the flattened first-order values followed by the
        flattened second-order values, with NaN replaced by 0.
    """
    n_rows = coefficients.time.shape[0]
    flattened = [order.data.reshape(n_rows, -1) for order in (coefficients.order_1, coefficients.order_2)]
    features = np.concatenate(flattened, axis=1)
    # NaN entries (e.g. masked second-order pairs) are stored as zeros
    features[np.isnan(features)] = 0
    return features
def load_scattering_coefficients_xarray(self):
    """Open the stored coefficient dataset and keep it on the instance.

    The dataset is read from
    ``{data_savepath}data/{data_network}_{data_station}_{data_location}_{network_name}_scat_coef_xarray.nc``
    and stored in ``self.scattering_coefficients_xarray``.

    Returns:
        xr.Dataset: The loaded scattering coefficients dataset.
    """
    dataset_path = (f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
                    f'{self.network_name}_scat_coef_xarray.nc')
    loaded = xr.open_dataset(dataset_path)
    self.scattering_coefficients_xarray = loaded
    return loaded
def plot_scattering_coefficients_normalisation(self, **kwargs):
    """Plot every vectorized scattering coefficient to inspect its normalization.

    The stored xarray dataset is loaded and vectorized, and each column of the resulting matrix (one
    coefficient across all time steps) is drawn as a faint blue line. The figure is saved as a PNG and shown.

    Args:
        **kwargs: Forwarded to ``plt.subplots``. ``figsize`` defaults to ``(10, 7)``.
    """
    if kwargs.get('figsize') is None:
        kwargs['figsize'] = (10, 7)

    scat_vec = self.vectorize_scattering_coefficients_xarray(self.load_scattering_coefficients_xarray())

    _, axs = plt.subplots(1, 1, **kwargs)
    # A 2-D input is drawn column by column. The former loop counted the columns but indexed the rows,
    # which plotted the wrong axis and could run past the number of rows.
    axs.plot(scat_vec, 'b', alpha=0.1)
    plt.title(f'{self.data_network}_{self.data_station}_{self.data_location}_{self.network_name}\n'
              'Scattering Coefficients Normalization')
    # NOTE(review): unlike the other figures, this one is written to the working directory rather than
    # under `{data_savepath}figures/` - confirm whether that is intended.
    plt.savefig(f'{self.data_network}_{self.data_station}_{self.data_location}_{self.network_name}_' +
                'Scattering_Coefficients_Normalization.png')
    plt.show()
def preload_times(self):
    """Load the stored time vector into ``self.data_times``.

    The file read is
    ``{data_savepath}data/{data_network}_{data_station}_{data_location}_{network_name}_times.npy``.
    Errors raised by ``np.load`` (for example a missing file) propagate to the caller.
    """
    times_file = (f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
                  f'{self.network_name}_times.npy')
    self.data_times = np.load(times_file)
def load_scat_coef_vectorized(self):
    """Return the vectorized scattering coefficients, loading them from disk on first use.

    When the instance has no ``scat_coef_vectorized`` attribute yet, the array is read from
    ``{data_savepath}data/{data_network}_{data_station}_{data_location}_{network_name}_scat_coef_vectorized.npy``
    and cached on the instance. Otherwise a notice is printed and the cached value is returned.

    Returns:
        The value of ``self.scat_coef_vectorized``.
    """
    if hasattr(self, 'scat_coef_vectorized'):
        print('self.scat_coef_vectorized already exist')
    else:
        vectorized_file = (f'{self.data_savepath}data/{self.data_network}_{self.data_station}_'
                           f'{self.data_location}_{self.network_name}_scat_coef_vectorized.npy')
        self.scat_coef_vectorized = np.load(vectorized_file)
    return self.scat_coef_vectorized