Source code for scatcluster.analysis.waveform_correlations
"""Waveform Correlations Analysis module."""
import datetime
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from obspy import UTCDateTime
from obspy.signal.cross_correlation import correlate, xcorr_max
from tqdm import tqdm
# EXAMPLE
# calculate_waveform_correlations(sort_type ='distance_filter', sort_filter=100, envelope=True)
# plot_correlation_waveforms(correlations=correlations_waveforms_distance_100, sort_type ='distance_filter',
# sort_filter=100, envelope=False)
[docs]class WaveformCorrelations:
[docs] def process_cluster_trace_correlation(self,
df_preds: pd.DataFrame,
cluster: int = 1,
sort_type: str = 'all',
sort_filter: int = None,
time_second: int = 60,
channel: str = 'HHZ',
envelope: bool = False):
"""
Calculate the waveform trace correlation for a given cluster.
Args:
df_preds (pd.DataFrame): The DataFrame containing the predictions.
cluster (int, optional): The cluster number. Defaults to 1.
sort_type (str, optional): The type of sorting to apply. Options are 'all', 'xcorr_filter', and
'distance_filter'. Defaults to 'all'.
sort_filter (int, optional): The filter value for sorting. Defaults to None.
time_second (int, optional): The time window in seconds. Defaults to 60.
channel (str, optional): The channel to use. Defaults to 'HHZ'.
envelope (bool, optional): Whether to use the envelope. Defaults to False.
Returns:
dict: A dictionary containing the centre waveform time, centre waveform, and correlations.
"""
# sort_type = ['all', 'xcorr_filter', 'distance_filter']
df_times = df_preds.loc[df_preds.predictions == cluster, ['cluster_rank', 'times_YYYYMMDD']]
time_list = df_times.sort_values(by='cluster_rank')['times_YYYYMMDD'].to_list()
if sort_type == 'distance_filter':
if sort_filter is None:
msg = ("If using \'sort_type\' => distance_filter, you need to supply \'sort_filter\' of type integer. "
"\n This will be used to take the waveforms up to the n\'th based on the Euclidean distance \n "
'from the cluster centroid.')
raise ValueError(msg)
time_list = time_list[:sort_filter + 1]
wvf0_time = time_list.pop(0)
wvf0 = self.get_waveform(start_time=str(wvf0_time), time_second=time_second, channel=channel, envelope=envelope)
if len(wvf0) < 1:
print(f' Waveform cross-correlations cannot be computed for Cluster {cluster}.\n' +
' The waveform in the centre of the cluster contains no traces. This \n' +
' is most likely a cluster of damaged/missing data.')
# return None
else:
time_per_sample = 1 / wvf0[0].stats.sampling_rate
wvf0 = wvf0.select(component='Z')[0].data
corr = []
for st_enum, st in tqdm(enumerate(time_list)):
wvf1 = self.get_waveform(start_time=str(st),
time_second=time_second,
channel=channel,
envelope=envelope)
if len(wvf1) > 0:
wvf1 = wvf1.select(component='Z')[0].data
cc = correlate(wvf0, wvf1, len(wvf0))
shift, value = xcorr_max(cc)
_sort_filter = 0
if sort_type == 'xcorr_filter':
if sort_filter > 1:
msg = ('If using \'sort_type\' => xcorr_filter, you need to \'supply sort_filter\' \n'
'of type float corresponding to the absolute maximum cross-correlation \n'
'coefficient to include in analysis.')
raise ValueError(msg)
_sort_filter = sort_filter
if abs(value) >= _sort_filter:
time_start2 = UTCDateTime(str(st)) + datetime.timedelta(seconds=-1 * shift * time_per_sample)
shifted_correlated_waveform = [-1 if value < 0 else 1] * self.get_waveform(
start_time=time_start2, envelope=envelope)[0].data
corr.append({
'cluster_rank': st_enum + 1,
'correlation_value': value,
'correlation_shift': shift,
'waveform_time': str(st),
'waveform': wvf1,
'shifted_correlated_waveform_time': str(time_start2),
'shifted_correlated_waveform': shifted_correlated_waveform
})
return {'centre_waveform_time': str(wvf0_time), 'centre_waveform': wvf0, 'correlations': corr}
[docs] def calculate_waveform_correlations(self,
df_preds: pd.DataFrame,
sort_type='distance_filter',
sort_filter=100,
time_second=60,
channel='HHZ',
envelope=False):
"""
Calculate the waveform correlations for each cluster based on the input DataFrame of predictions.
Args:
df_preds (pd.DataFrame): The DataFrame containing the predictions.
sort_type (str, optional): The type of sorting to apply. Defaults to 'distance_filter'.
sort_filter (int, optional): The filter value for sorting. Defaults to 100.
time_second (int, optional): The time window in seconds. Defaults to 60.
channel (str, optional): The channel to use. Defaults to 'HHZ'.
envelope (bool, optional): Whether to use the envelope. Defaults to False.
Returns:
dict: A dictionary containing the waveform correlations for each cluster.
"""
correlations = {}
for cluster in set(df_preds.predictions):
corr = self.process_cluster_trace_correlation(df_preds, cluster, sort_type, sort_filter, time_second,
channel, envelope)
correlations[cluster] = corr
_wvf_type = 'waveform'
if envelope:
_wvf_type = 'envelope'
save_file = (f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
f'{self.network_name}_ICA_{self.ica.n_components}_clustering_{self.ica.n_components}_{_wvf_type}_'
f'correlations_{sort_type}_{sort_filter}.pkl')
with open(save_file, 'wb') as handle:
pickle.dump(correlations, handle, protocol=pickle.HIGHEST_PROTOCOL)
return correlations
[docs] def load_correlations(self, sort_type='distance_filter', sort_filter=100, envelope=False):
"""
Load the waveform correlations from a pickle file.
Args:
sort_type (str, optional): The type of sorting to be applied to the correlations.
Defaults to 'distance_filter'.
sort_filter (int, optional): The filter to be applied to the sorted correlations. Defaults to 100.
envelope (bool, optional): Whether to use the envelope of the waveform. Defaults to False.
Returns:
dict: A dictionary containing the waveform correlations.
"""
_wvf_type = 'waveform'
if envelope:
_wvf_type = 'envelope'
save_file = (f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
f'{self.network_name}_ICA_{self.ica.n_components}_clustering_{self.ica.n_components}_{_wvf_type}_'
f'correlations_{sort_type}_{sort_filter}.pkl')
with open(save_file, 'rb') as handle:
correlations = pickle.load(handle)
return correlations
[docs] def stack_correlations(self, correlations, cluster):
"""
Calculate the stacked correlations for a given cluster.
Parameters:
correlations (dict): A dictionary containing the correlations for different clusters.
cluster (int): The cluster number.
Returns:
numpy.ndarray or None: The stacked correlations for the given cluster, or None if the cluster is not
present in the correlations' dictionary.
"""
if correlations[cluster] is None:
return None
else:
waveforms = [x['shifted_correlated_waveform'] for x in correlations[cluster]['correlations']]
return np.mean(waveforms, axis=0)
[docs] def process_waveform_correlations_stacked_waveform(self,
df_preds,
correlations,
sort_type,
sort_filter,
envelope=False):
"""
Process the waveform correlations and stack the correlated waveforms for each cluster.
Args:
df_preds (pandas.DataFrame): The DataFrame containing the predictions.
correlations (dict): A dictionary containing the correlations for different clusters.
sort_type (str): The type of sorting to apply.
sort_filter (int): The filter value for sorting.
envelope (bool, optional): Whether to use the envelope. Defaults to False.
Returns:
dict: A dictionary containing the stacked correlated waveforms for each cluster.
This function iterates over the unique clusters in the predictions DataFrame and calculates the stacked
correlated waveforms for each cluster using the `stack_correlations` method. The resulting stacked correlated
waveforms are stored in the `correlation_waveform` dictionary.
The `_wvf_type` variable is set to 'waveform' by default. If the `envelope` parameter is True, `_wvf_type` is
set to 'envelope'.
The `correlation_waveform` dictionary is then saved as a NumPy binary file using the `np.save` function. The
file name is constructed using various attributes of the instance (`self`) and the input parameters.
Finally, the `correlation_waveform` dictionary is returned.
"""
correlation_waveform = {}
for cluster in set(df_preds.predictions):
correlation_waveform[cluster] = self.stack_correlations(correlations, cluster)
_wvf_type = 'waveform'
if envelope:
_wvf_type = 'envelope'
np.save(
f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
f'{self.network_name}_ICA_{self.ica.n_components}_clustering_{self.ica.n_components}_{_wvf_type}_'
f'correlations_stacked_waveform_{sort_type}_{sort_filter}.npy', correlation_waveform)
return correlation_waveform
[docs] def plot_correlation_waveforms(self, df_preds, correlations, sort_type, sort_filter, envelope=False):
"""
Plot the correlation waveforms for each cluster.
Parameters:
df_preds (pandas.DataFrame): The DataFrame containing the predictions.
correlations (dict): A dictionary containing the correlations for different clusters.
sort_type (str): The type of sorting to apply.
sort_filter (int): The filter value for sorting.
envelope (bool, optional): Whether to use the envelope. Defaults to False.
This function plots the correlation waveforms for each cluster. It first calculates the correlation waveform
using the `process_waveform_correlations_stacked_waveform` method. Then, it creates a figure with subplots for
each cluster. If the cluster has no correlations, the subplot title is set to 'Cluster {cluster_number} -
Empty Traces'. Otherwise, it plots the shifted and corrected waveforms, the centroid waveform, and the
correlation waveform for that cluster. The subplot title includes the number of traces, the average
cross-correlation coefficient, and the type of waveform. The figure legend includes labels for the shifted and
corrected waveforms, the centroid waveform, and the correlation waveform. The figure is saved as a PNG image
with a unique file name based on the data network, station, location, network name, ICA components, clustering
method, and waveform type. The plot is displayed using `plt.show()`.
"""
correlation_waveform = self.process_waveform_correlations_stacked_waveform(df_preds, correlations, sort_type,
sort_filter, envelope)
fig, ax = plt.subplots(len(set(df_preds.predictions)), 1, figsize=(20, 10), sharex=True, sharey=True)
for clust_enum, cluster in enumerate(correlations.keys()):
if correlations[cluster] is None:
ax[clust_enum].set_title(f'Cluster {clust_enum + 1} - Empty Traces')
else:
lines = []
num_traces = len(correlations[cluster]['correlations'])
avg_xcorr = np.mean([x['correlation_value'] for x in correlations[cluster]['correlations']])
ax[clust_enum].set_title(
f'Cluster {clust_enum + 1}: Number of Traces {num_traces}: Avg. Cross-correlation Coeff. '
f'{avg_xcorr:.3}')
for x in correlations[cluster]['correlations'][:100]:
lines += ax[clust_enum].plot(x['shifted_correlated_waveform'],
color='b',
alpha=0.2,
linewidth=1,
label='shifted_correlated_waveforms')
centre_waveform = correlations[cluster]['centre_waveform']
centre_waveform_shift = np.mean(np.abs(centre_waveform)) * .5
lines += ax[clust_enum].plot(centre_waveform,
color='k',
alpha=0.3,
linewidth=1,
label='centre_waveform')
ax[clust_enum].set_ylim(
[min(centre_waveform) - centre_waveform_shift,
max(centre_waveform) + centre_waveform_shift])
lines += ax[clust_enum].plot(correlation_waveform[cluster],
color='r',
linewidth=0.5,
label='correlation_waveform')
_wvf_type = 'Waveform'
if envelope:
_wvf_type = 'Envelope'
fig.legend(lines[-3:],
[f'Shifted and corrected {_wvf_type}s', f'Centroid {_wvf_type}', f'Correlation {_wvf_type}'],
loc='upper center',
ncol=3,
bbox_to_anchor=(0.5, -0.01))
ax[0].margins(x=0)
file_name = (f'{self.data_network}_{self.data_station}_{self.data_location}_{self.network_name}_ICA_'
f'{self.ica.n_components}_clustering_{self.ica.n_components}_{_wvf_type}_Correlations_{sort_type}'
f'_{sort_filter}')
plt.suptitle(file_name)
plt.savefig(f'{self.data_savepath}figures/{file_name}.png', bbox_inches='tight')
plt.show()
[docs] def plot_correlation_frequency(self, df_preds, correlations_waveforms_distance_all):
"""
Plot the correlation frequency for each cluster.
Parameters:
df_preds (pandas.DataFrame): The DataFrame containing the predictions.
correlations_waveforms_distance_all (dict): A dictionary containing the correlations and waveforms distance
for different clusters.
This function plots the correlation frequency for each cluster. It creates a figure with subplots for each
cluster. If the cluster has no correlations, the subplot title is set to 'Cluster {cluster_number} - Empty
Traces'. Otherwise, it calculates the absolute value of the correlation_value for each correlation in the
cluster and plots a histogram of the data using seaborn. The subplot title includes the cluster number.
The figure is displayed using `plt.show()`.
"""
_, ax = plt.subplots(len(set(df_preds.predictions)), 1, figsize=(20, 10), sharex=True, sharey=True)
for clust_enum, cluster in enumerate(correlations_waveforms_distance_all.keys()):
if correlations_waveforms_distance_all[cluster] is None:
ax[clust_enum].set_title(f'Cluster {clust_enum + 1} - Empty Traces')
else:
data = [
np.abs(xcor['correlation_value'])
for xcor in correlations_waveforms_distance_all[cluster]['correlations']
]
sns.histplot(data, ax=ax[clust_enum])
ax[clust_enum].set_title(f'Cluster {clust_enum + 1} ')
plt.show()
[docs] def plot_correlation_shift(self, correlations: dict, clusters: list, within_cluster_number: int = None):
"""
Plot the correlation shift for each cluster.
Parameters:
correlations (dict): A dictionary containing the correlations for each cluster.
clusters (list): A list of clusters.
within_cluster_number (int, optional): The number of correlations to plot within each cluster.
Defaults to None.
"""
_within_cluster_number = 100 if within_cluster_number is None else within_cluster_number
_, ax = plt.subplots(2, len(clusters), figsize=(20, 10), sharex=True, sharey=True)
for cluster_enum, cluster in enumerate(clusters):
ax[0, cluster_enum].plot(correlations[cluster]['centre_waveform'], 'k')
for x_enum, x in enumerate(correlations[cluster]['correlations'][:_within_cluster_number]):
ax[1, cluster_enum].plot(x['waveform'] + ((x_enum + 1) * 0.0000001), 'b', alpha=0.2)
plt.title(f'Cluster {cluster}')
file_name = (f'{self.data_network}_{self.data_station}_{self.data_location}_{self.network_name}_ICA_'
f'{self.ica.n_components}_clustering_{self.ica.n_components}_{self._wvf_type}_Correlations')
plt.suptitle(f'{file_name}\nWaveform Correlations')
plt.savefig(f'{self.data_savepath}figures/{file_name}_Cluster_waveform_correlations.png', bbox_inches='tight')
plt.show()