From 2878076d4862d3d7b7437978881b6345766d21dc Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 12 Apr 2025 16:31:18 -0400 Subject: [PATCH 01/40] move model components to sub-object --- specparam/algorithms/spectral_fit.py | 24 ++-- specparam/measures/pointwise.py | 2 +- specparam/models/base.py | 8 +- specparam/models/utils.py | 2 +- specparam/objs/metrics.py | 2 +- specparam/objs/results.py | 149 +++++++++++++++---------- specparam/plts/annotate.py | 2 +- specparam/plts/model.py | 15 +-- specparam/tests/measures/test_error.py | 13 ++- specparam/tests/measures/test_gof.py | 7 +- specparam/tests/models/test_event.py | 2 +- specparam/tests/models/test_model.py | 12 +- specparam/tests/objs/test_results.py | 7 ++ 13 files changed, 138 insertions(+), 107 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 4d76918f..a8babfa1 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -78,14 +78,6 @@ class SpectralFitAlgorithm(Algorithm): _gauss_std_limits : list of [float, float] Settings attribute: peak width limits, to use for gaussian standard deviation parameter. This attribute is computed based on `peak_width_limits` and should not be updated directly. - _spectrum_flat : 1d array - Data attribute: flattened power spectrum, with the aperiodic component removed. - _spectrum_peak_rm : 1d array - Data attribute: power spectrum, with peaks removed. - _ap_fit : 1d array - Model attribute: values of the isolated aperiodic fit. - _peak_fit : 1d array - Model attribute: values of the isolated peak fit. """ # pylint: disable=attribute-defined-outside-init @@ -153,21 +145,21 @@ def _fit(self): # Calculate the peak fit # Note: if no peaks are found, this creates a flat (all zero) peak fit - self.results._peak_fit = self.modes.periodic.func(\ + self.results.model._peak_fit = self.modes.periodic.func(\ self.data.freqs, *np.ndarray.flatten(self.results.gaussian_params_)) # Create peak-removed (but not flattened) power spectrum - self.results._spectrum_peak_rm = self.data.power_spectrum - self.results._peak_fit + self.results.model._spectrum_peak_rm = self.data.power_spectrum - self.results.model._peak_fit # Run final aperiodic fit on peak-removed power spectrum self.results.aperiodic_params_ = self._simple_ap_fit(\ - self.data.freqs, self.results._spectrum_peak_rm) - self.results._ap_fit = self.modes.aperiodic.func(\ + self.data.freqs, self.results.model._spectrum_peak_rm) + self.results.model._ap_fit = self.modes.aperiodic.func(\ self.data.freqs, *self.results.aperiodic_params_) # Create remaining model components: flatspec & full power_spectrum model fit - self.results._spectrum_flat = self.data.power_spectrum - self.results._ap_fit - self.results.modeled_spectrum_ = self.results._peak_fit + self.results._ap_fit + self.results.model._spectrum_flat = self.data.power_spectrum - self.results.model._ap_fit + self.results.model.modeled_spectrum_ = self.results.model._peak_fit + self.results.model._ap_fit ## PARAMETER UPDATES @@ -632,13 +624,13 @@ def _create_peak_params(self, gaus_params): # Collect peak parameter data if self.modes.periodic.name == 'gaussian': ## TEMP peak_params[ii] = [peak[0], - self.results.modeled_spectrum_[ind] - self.results._ap_fit[ind], + self.results.model.modeled_spectrum_[ind] - self.results.model._ap_fit[ind], peak[2] * 2] ## TEMP: if self.modes.periodic.name == 'skewnorm': peak_params[ii] = [peak[0], - self.results.modeled_spectrum_[ind] - self.results._ap_fit[ind], + self.results.model.modeled_spectrum_[ind] - self.results.model._ap_fit[ind], peak[2] * 2, peak[3]] return peak_params diff --git a/specparam/measures/pointwise.py b/specparam/measures/pointwise.py index b5553636..9b27751e 100644 --- a/specparam/measures/pointwise.py +++ b/specparam/measures/pointwise.py @@ -43,7 +43,7 @@ def compute_pointwise_error(model, plot_errors=True, return_errors=False, **plt_ raise NoModelError("No model is available to use, can not proceed.") errors = compute_pointwise_error_arr(\ - model.results.modeled_spectrum_, model.data.power_spectrum) + model.results.model.modeled_spectrum_, model.data.power_spectrum) if plot_errors: plot_spectral_error(model.data.freqs, errors, **plt_kwargs) diff --git a/specparam/models/base.py b/specparam/models/base.py index fe630d11..1b2e1626 100644 --- a/specparam/models/base.py +++ b/specparam/models/base.py @@ -84,11 +84,11 @@ def get_data(self, component='full', space='log'): output = self.data.power_spectrum if space == 'log' \ else unlog(self.data.power_spectrum) elif component == 'aperiodic': - output = self.results._spectrum_peak_rm if space == 'log' else \ - unlog(self.data.power_spectrum) / unlog(self.results._peak_fit) + output = self.results.model._spectrum_peak_rm if space == 'log' else \ + unlog(self.data.power_spectrum) / unlog(self.results.model._peak_fit) elif component == 'peak': - output = self.results._spectrum_flat if space == 'log' else \ - unlog(self.data.power_spectrum) - unlog(self.results._ap_fit) + output = self.results.model._spectrum_flat if space == 'log' else \ + unlog(self.data.power_spectrum) - unlog(self.results.model._ap_fit) else: raise ValueError('Input for component invalid.') diff --git a/specparam/models/utils.py b/specparam/models/utils.py index b3b39731..e63fd034 100644 --- a/specparam/models/utils.py +++ b/specparam/models/utils.py @@ -188,7 +188,7 @@ def average_reconstructions(group, avg_method='mean'): models = np.zeros(shape=group.data.power_spectra.shape) for ind in range(len(group.results)): - models[ind, :] = group.get_model(ind, regenerate=True).results.modeled_spectrum_ + models[ind, :] = group.get_model(ind, regenerate=True).results.model.modeled_spectrum_ avg_model = avg_funcs[avg_method](models, 0) diff --git a/specparam/objs/metrics.py b/specparam/objs/metrics.py index 9170f041..9a79f445 100644 --- a/specparam/objs/metrics.py +++ b/specparam/objs/metrics.py @@ -75,7 +75,7 @@ def compute_metric(self, data, results): for key, lfunc in self.kwargs.items(): kwargs[key] = lfunc(data, results) - self.result = self.func(data.power_spectrum, results.modeled_spectrum_, **kwargs) + self.result = self.func(data.power_spectrum, results.model.modeled_spectrum_, **kwargs) class Metrics(): diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 66ea6194..df49a302 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -26,6 +26,90 @@ DEFAULT_METRICS = ['error_mae', 'gof_rsquared'] +class ModelComponents(): + """Object for managing model components. + + Attributes + ---------- + modeled_spectrum : 1d array + Modeled spectrum. + _spectrum_flat : 1d array + Data attribute: flattened power spectrum, with the aperiodic component removed. + _spectrum_peak_rm : 1d array + Data attribute: power spectrum, with peaks removed. + _ap_fit : 1d array + Model attribute: values of the isolated aperiodic fit. + _peak_fit : 1d array + Model attribute: values of the isolated peak fit. + """ + + def __init__(self): + """Initialize ModelComponents object.""" + + self.reset() + + + def reset(self): + """Reset model components attributes.""" + + # Data components + self._spectrum_flat = None + self._spectrum_peak_rm = None + + # Full model + self.modeled_spectrum_ = None + + # Model components + self._ap_fit = None + self._peak_fit = None + + + def get_component(self, component='full', space='log'): + """Get a model component. + + Parameters + ---------- + component : {'full', 'aperiodic', 'peak'} + Which model component to return. + 'full' - full model + 'aperiodic' - isolated aperiodic model component + 'peak' - isolated peak model component + space : {'log', 'linear'} + Which space to return the model component in. + 'log' - returns in log10 space. + 'linear' - returns in linear space. + + Returns + ------- + output : 1d array + Specified model component, in specified spacing. + + Notes + ----- + The 'space' parameter doesn't just define the spacing of the model component + values, but rather defines the space of the additive model such that + `model = aperiodic_component + peak_component`. + With space set as 'log', this combination holds in log space. + With space set as 'linear', this combination holds in linear space. + """ + + if self.modeled_spectrum_ is None: + raise NoModelError("No model fit results are available, can not proceed.") + assert space in ['linear', 'log'], "Input for 'space' invalid." + + if component == 'full': + output = self.modeled_spectrum_ if space == 'log' else unlog(self.modeled_spectrum_) + elif component == 'aperiodic': + output = self._ap_fit if space == 'log' else unlog(self._ap_fit) + elif component == 'peak': + output = self._peak_fit if space == 'log' else \ + unlog(self.modeled_spectrum_) - unlog(self._ap_fit) + else: + raise ValueError('Input for component invalid.') + + return output + + class Results(): """Object for managing results - base / 1D version. @@ -48,6 +132,8 @@ def __init__(self, modes=None, metrics=None, bands=None): self.add_bands(bands) self.add_metrics(metrics) + self.model = ModelComponents() + # Initialize results attributes self._reset_results(True) self._fields = RESULTS_FIELDS @@ -158,52 +244,6 @@ def get_results(self): return results - def get_component(self, component='full', space='log'): - """Get a model component. - - Parameters - ---------- - component : {'full', 'aperiodic', 'peak'} - Which model component to return. - 'full' - full model - 'aperiodic' - isolated aperiodic model component - 'peak' - isolated peak model component - space : {'log', 'linear'} - Which space to return the model component in. - 'log' - returns in log10 space. - 'linear' - returns in linear space. - - Returns - ------- - output : 1d array - Specified model component, in specified spacing. - - Notes - ----- - The 'space' parameter doesn't just define the spacing of the model component - values, but rather defines the space of the additive model such that - `model = aperiodic_component + peak_component`. - With space set as 'log', this combination holds in log space. - With space set as 'linear', this combination holds in linear space. - """ - - if not self.has_model: - raise NoModelError("No model fit results are available, can not proceed.") - assert space in ['linear', 'log'], "Input for 'space' invalid." - - if component == 'full': - output = self.modeled_spectrum_ if space == 'log' else unlog(self.modeled_spectrum_) - elif component == 'aperiodic': - output = self._ap_fit if space == 'log' else unlog(self._ap_fit) - elif component == 'peak': - output = self._peak_fit if space == 'log' else \ - unlog(self.modeled_spectrum_) - unlog(self._ap_fit) - else: - raise ValueError('Input for component invalid.') - - return output - - def get_params(self, name, field=None): """Return model fit parameters for specified feature(s). @@ -277,14 +317,8 @@ def _reset_results(self, clear_results=False): self.gaussian_params_ = np.nan self.peak_params_ = np.nan - # Data components - self._spectrum_flat = None - self._spectrum_peak_rm = None - - # Modeled spectrum components - self.modeled_spectrum_ = None - self._ap_fit = None - self._peak_fit = None + # Reset model components + self.model.reset() def _regenerate_model(self, freqs): @@ -296,10 +330,9 @@ def _regenerate_model(self, freqs): Frequency values for the power_spectrum, in linear scale. """ - self.modeled_spectrum_, self._peak_fit, self._ap_fit = gen_model(freqs, \ - self.modes.aperiodic, self.aperiodic_params_, - self.modes.periodic, self.gaussian_params_, - return_components=True) + self.model.modeled_spectrum_, self.model._peak_fit, self.model._ap_fit = \ + gen_model(freqs, self.modes.aperiodic, self.aperiodic_params_, + self.modes.periodic, self.gaussian_params_, return_components=True) @replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) diff --git a/specparam/plts/annotate.py b/specparam/plts/annotate.py index 5a8df135..fbb92a59 100644 --- a/specparam/plts/annotate.py +++ b/specparam/plts/annotate.py @@ -183,7 +183,7 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # Annotate Aperiodic Offset # Add a line to indicate offset, without adjusting plot limits below it ax.set_autoscaley_on(False) - ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.results.modeled_spectrum_[0]], + ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.results.model.modeled_spectrum_[0]], color=PLT_COLORS['aperiodic'], linewidth=lw2, alpha=0.5) ax.annotate('Offset', xy=(freqs[0]+bug_buff, model.data.power_spectrum[0]-y_buff1), diff --git a/specparam/plts/model.py b/specparam/plts/model.py index 6a2ca633..44695f92 100644 --- a/specparam/plts/model.py +++ b/specparam/plts/model.py @@ -8,7 +8,6 @@ import numpy as np from specparam.modutils.dependencies import safe_import, check_dependency -from specparam.sim.gen import gen_periodic from specparam.utils.select import nearest_ind from specparam.utils.spectral import trim_spectrum from specparam.measures.params import compute_fwhm @@ -87,7 +86,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, 'label' : 'Full Model Fit' if add_legend else None} model_kwargs = check_plot_kwargs(model_kwargs, model_defaults) - plot_spectra(model.data.freqs, model.results.modeled_spectrum_, + plot_spectra(model.data.freqs, model.results.model.modeled_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit @@ -96,7 +95,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp 'alpha' : 0.5, 'linestyle' : 'dashed', 'label' : 'Aperiodic Fit' if add_legend else None} aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, aperiodic_defaults) - plot_spectra(model.data.freqs, model.results._ap_fit, + plot_spectra(model.data.freqs, model.results.model._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs) # Plot the periodic components of the model fit @@ -172,10 +171,9 @@ def _add_peaks_shade(model, plt_log, ax, **plot_kwargs): for peak in model.results.gaussian_params_: peak_freqs = np.log10(model.data.freqs) if plt_log else model.data.freqs - #peak_line = model.results._ap_fit + gen_periodic(model.data.freqs, peak) - peak_line = model.results._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) + peak_line = model.results.model._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) - ax.fill_between(peak_freqs, peak_line, model.results._ap_fit, **plot_kwargs) + ax.fill_between(peak_freqs, peak_line, model.results.model._ap_fit, **plot_kwargs) def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): @@ -198,7 +196,7 @@ def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): for peak in model.results.peak_params_: - ap_point = np.interp(peak[0], model.data.freqs, model.results._ap_fit) + ap_point = np.interp(peak[0], model.data.freqs, model.results.model._ap_fit) freq_point = np.log10(peak[0]) if plt_log else peak[0] # Add the line from the aperiodic fit up the tip of the peak @@ -232,8 +230,7 @@ def _add_peaks_outline(model, plt_log, ax, **plot_kwargs): peak_range = [peak[0] - peak[2]*3, peak[0] + peak[2]*3] # Generate a peak reconstruction for each peak, and trim to desired range - #peak_line = model.results._ap_fit + gen_periodic(model.data.freqs, peak) - peak_line = model.results._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) + peak_line = model.results.model._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) peak_freqs, peak_line = trim_spectrum(model.data.freqs, peak_line, peak_range) # Plot the peak outline diff --git a/specparam/tests/measures/test_error.py b/specparam/tests/measures/test_error.py index 7b3c0650..6864e1bd 100644 --- a/specparam/tests/measures/test_error.py +++ b/specparam/tests/measures/test_error.py @@ -7,26 +7,29 @@ def test_compute_mean_abs_error(tfm): - error = compute_mean_abs_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_mean_abs_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum_) assert isinstance(error, float) def test_compute_mean_squared_error(tfm): - error = compute_mean_squared_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_mean_squared_error(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum_) assert isinstance(error, float) def test_compute_root_mean_squared_error(tfm): - error = compute_root_mean_squared_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_root_mean_squared_error(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum_) assert isinstance(error, float) def test_compute_median_abs_error(tfm): - error = compute_median_abs_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_median_abs_error(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum_) assert isinstance(error, float) def test_compute_error(tfm): for metric in ['mae', 'mse', 'rmse', 'medae']: - error = compute_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum_) assert isinstance(error, float) diff --git a/specparam/tests/measures/test_gof.py b/specparam/tests/measures/test_gof.py index a1990749..678c3ce0 100644 --- a/specparam/tests/measures/test_gof.py +++ b/specparam/tests/measures/test_gof.py @@ -7,16 +7,17 @@ def test_compute_r_squared(tfm): - r_squared = compute_r_squared(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + r_squared = compute_r_squared(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum_) assert isinstance(r_squared, float) def test_compute_adj_r_squared(tfm): - r_squared = compute_adj_r_squared(tfm.data.power_spectrum, tfm.results.modeled_spectrum_, 5) + r_squared = compute_adj_r_squared(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum_, 5) assert isinstance(r_squared, float) def test_compute_gof(tfm): for metric in ['r_squared', 'adj_r_squared']: - gof = compute_gof(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + gof = compute_gof(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum_) assert isinstance(gof, float) diff --git a/specparam/tests/models/test_event.py b/specparam/tests/models/test_event.py index 149827be..d65f087c 100644 --- a/specparam/tests/models/test_event.py +++ b/specparam/tests/models/test_event.py @@ -138,7 +138,7 @@ def test_event_get_model(tfe): assert tfm1 assert tfm1.data.has_data assert tfm1.results.has_model - assert np.all(tfm1.results.modeled_spectrum_) + assert np.all(tfm1.results.model.modeled_spectrum_) def test_event_get_params(tfe): diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 23210a6e..6d2b5ec9 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -115,7 +115,7 @@ def test_fit_default_metrics(): # Hack fake data with known properties: total error magnitude 2 tfm.data.power_spectrum = np.array([1, 2, 3, 4, 5]) - tfm.results.modeled_spectrum_ = np.array([1, 2, 5, 4, 5]) + tfm.results.model.modeled_spectrum_ = np.array([1, 2, 5, 4, 5]) # Check default goodness of fit and error measures tfm.results.metrics.compute_metrics(tfm.data, tfm.results) @@ -296,7 +296,7 @@ def test_get_component(tfm): for comp in ['full', 'aperiodic', 'peak']: for space in ['log', 'linear']: - assert isinstance(tfm.results.get_component(comp, space), np.ndarray) + assert isinstance(tfm.results.model.get_component(comp, space), np.ndarray) def test_prints(tfm): """Test methods that print (alias and pass through methods). @@ -326,13 +326,11 @@ def test_resets(): for field in tfm.data._fields: assert getattr(tfm.data, field) is None - model_components = ['modeled_spectrum_', '_spectrum_flat', - '_spectrum_peak_rm', '_ap_fit', '_peak_fit'] - for field in model_components: - assert getattr(tfm.results, field) is None + for key, value in tfm.results.model.__dict__.items(): + assert value is None for field in tfm.results._fields: assert np.all(np.isnan(getattr(tfm.results, field))) - assert tfm.data.freqs is None and tfm.results.modeled_spectrum_ is None + assert tfm.data.freqs is None and tfm.results.model.modeled_spectrum_ is None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py index d2317590..13230f41 100644 --- a/specparam/tests/objs/test_results.py +++ b/specparam/tests/objs/test_results.py @@ -5,6 +5,13 @@ ################################################################################################### ################################################################################################### +## ModelComponents object + +def test_model_components(): + + mc = ModelComponents() + assert mc + ## 1D results object def test_results(): From 005718803eb88a960bc4cdb453dba2f28cbc8d13 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 12 Apr 2025 16:35:49 -0400 Subject: [PATCH 02/40] update BaseModel docs --- specparam/models/base.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/specparam/models/base.py b/specparam/models/base.py index 1b2e1626..9af02e6c 100644 --- a/specparam/models/base.py +++ b/specparam/models/base.py @@ -11,7 +11,24 @@ ################################################################################################### class BaseModel(): - """Define BaseModel object.""" + """Define BaseModel object. + + Parameters + ---------- + aperiodic_mode : Mode or str + Mode for aperiodic component, or string specifying which mode to use. + periodic_mode : Mode or str + Mode for periodic component, or string specifying which mode to use. + verbose : bool + Whether to print out updates from the object. + + Attributes + ---------- + modes : Modes + Fit modes definitions. + verbose : bool + Verbosity status. + """ def __init__(self, aperiodic_mode, periodic_mode, verbose): """Initialize object.""" From 38b6734b27ea398e193b1fd1933657838422eb56 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 12 Apr 2025 16:44:50 -0400 Subject: [PATCH 03/40] modeled_spectrum_ -> modeled_spectrum --- specparam/algorithms/spectral_fit.py | 8 ++++---- specparam/measures/pointwise.py | 2 +- specparam/models/group.py | 2 +- specparam/models/utils.py | 2 +- specparam/objs/metrics.py | 2 +- specparam/objs/results.py | 18 +++++++++--------- specparam/plts/annotate.py | 2 +- specparam/plts/model.py | 2 +- specparam/tests/measures/test_error.py | 10 +++++----- specparam/tests/measures/test_gof.py | 6 +++--- specparam/tests/models/test_event.py | 2 +- specparam/tests/models/test_model.py | 4 ++-- 12 files changed, 30 insertions(+), 30 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index a8babfa1..cc86dae5 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -159,7 +159,7 @@ def _fit(self): # Create remaining model components: flatspec & full power_spectrum model fit self.results.model._spectrum_flat = self.data.power_spectrum - self.results.model._ap_fit - self.results.model.modeled_spectrum_ = self.results.model._peak_fit + self.results.model._ap_fit + self.results.model.modeled_spectrum = self.results.model._peak_fit + self.results.model._ap_fit ## PARAMETER UPDATES @@ -611,7 +611,7 @@ def _create_peak_params(self, gaus_params): 'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided. Performing this conversion requires that the model has been run, - with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available. + with `freqs`, `modeled_spectrum` and `_ap_fit` all required to be available. """ peak_params = np.empty((len(gaus_params), self.modes.periodic.n_params)) @@ -624,13 +624,13 @@ def _create_peak_params(self, gaus_params): # Collect peak parameter data if self.modes.periodic.name == 'gaussian': ## TEMP peak_params[ii] = [peak[0], - self.results.model.modeled_spectrum_[ind] - self.results.model._ap_fit[ind], + self.results.model.modeled_spectrum[ind] - self.results.model._ap_fit[ind], peak[2] * 2] ## TEMP: if self.modes.periodic.name == 'skewnorm': peak_params[ii] = [peak[0], - self.results.model.modeled_spectrum_[ind] - self.results.model._ap_fit[ind], + self.results.model.modeled_spectrum[ind] - self.results.model._ap_fit[ind], peak[2] * 2, peak[3]] return peak_params diff --git a/specparam/measures/pointwise.py b/specparam/measures/pointwise.py index 9b27751e..75be4d33 100644 --- a/specparam/measures/pointwise.py +++ b/specparam/measures/pointwise.py @@ -43,7 +43,7 @@ def compute_pointwise_error(model, plot_errors=True, return_errors=False, **plt_ raise NoModelError("No model is available to use, can not proceed.") errors = compute_pointwise_error_arr(\ - model.results.model.modeled_spectrum_, model.data.power_spectrum) + model.results.model.modeled_spectrum, model.data.power_spectrum) if plot_errors: plot_spectral_error(model.data.freqs, errors, **plt_kwargs) diff --git a/specparam/models/group.py b/specparam/models/group.py index 40e62172..cecaa95d 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -48,7 +48,7 @@ class SpectralGroupModel(SpectralModel): ----- % copied in from SpectralModel object - The group object inherits from the model object. As such it also has data - attributes (`power_spectrum` & `modeled_spectrum_`), and parameter attributes + attributes (`power_spectrum` & `modeled_spectrum`), and parameter attributes (`aperiodic_params_`, `peak_params_`, `gaussian_params_`, `r_squared_`, `error_`) which are defined in the context of individual model fits. These attributes are used during the fitting process, but in the group context do not store results diff --git a/specparam/models/utils.py b/specparam/models/utils.py index e63fd034..dfe4bd33 100644 --- a/specparam/models/utils.py +++ b/specparam/models/utils.py @@ -188,7 +188,7 @@ def average_reconstructions(group, avg_method='mean'): models = np.zeros(shape=group.data.power_spectra.shape) for ind in range(len(group.results)): - models[ind, :] = group.get_model(ind, regenerate=True).results.model.modeled_spectrum_ + models[ind, :] = group.get_model(ind, regenerate=True).results.model.modeled_spectrum avg_model = avg_funcs[avg_method](models, 0) diff --git a/specparam/objs/metrics.py b/specparam/objs/metrics.py index 9a79f445..f2cc6295 100644 --- a/specparam/objs/metrics.py +++ b/specparam/objs/metrics.py @@ -75,7 +75,7 @@ def compute_metric(self, data, results): for key, lfunc in self.kwargs.items(): kwargs[key] = lfunc(data, results) - self.result = self.func(data.power_spectrum, results.model.modeled_spectrum_, **kwargs) + self.result = self.func(data.power_spectrum, results.model.modeled_spectrum, **kwargs) class Metrics(): diff --git a/specparam/objs/results.py b/specparam/objs/results.py index df49a302..32d095d5 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -52,17 +52,17 @@ def __init__(self): def reset(self): """Reset model components attributes.""" - # Data components - self._spectrum_flat = None - self._spectrum_peak_rm = None - # Full model - self.modeled_spectrum_ = None + self.modeled_spectrum = None # Model components self._ap_fit = None self._peak_fit = None + # Data components + self._spectrum_flat = None + self._spectrum_peak_rm = None + def get_component(self, component='full', space='log'): """Get a model component. @@ -93,17 +93,17 @@ def get_component(self, component='full', space='log'): With space set as 'linear', this combination holds in linear space. """ - if self.modeled_spectrum_ is None: + if self.modeled_spectrum is None: raise NoModelError("No model fit results are available, can not proceed.") assert space in ['linear', 'log'], "Input for 'space' invalid." if component == 'full': - output = self.modeled_spectrum_ if space == 'log' else unlog(self.modeled_spectrum_) + output = self.modeled_spectrum if space == 'log' else unlog(self.modeled_spectrum) elif component == 'aperiodic': output = self._ap_fit if space == 'log' else unlog(self._ap_fit) elif component == 'peak': output = self._peak_fit if space == 'log' else \ - unlog(self.modeled_spectrum_) - unlog(self._ap_fit) + unlog(self.modeled_spectrum) - unlog(self._ap_fit) else: raise ValueError('Input for component invalid.') @@ -330,7 +330,7 @@ def _regenerate_model(self, freqs): Frequency values for the power_spectrum, in linear scale. """ - self.model.modeled_spectrum_, self.model._peak_fit, self.model._ap_fit = \ + self.model.modeled_spectrum, self.model._peak_fit, self.model._ap_fit = \ gen_model(freqs, self.modes.aperiodic, self.aperiodic_params_, self.modes.periodic, self.gaussian_params_, return_components=True) diff --git a/specparam/plts/annotate.py b/specparam/plts/annotate.py index fbb92a59..bbc889de 100644 --- a/specparam/plts/annotate.py +++ b/specparam/plts/annotate.py @@ -183,7 +183,7 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # Annotate Aperiodic Offset # Add a line to indicate offset, without adjusting plot limits below it ax.set_autoscaley_on(False) - ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.results.model.modeled_spectrum_[0]], + ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.results.model.modeled_spectrum[0]], color=PLT_COLORS['aperiodic'], linewidth=lw2, alpha=0.5) ax.annotate('Offset', xy=(freqs[0]+bug_buff, model.data.power_spectrum[0]-y_buff1), diff --git a/specparam/plts/model.py b/specparam/plts/model.py index 44695f92..e19e03ae 100644 --- a/specparam/plts/model.py +++ b/specparam/plts/model.py @@ -86,7 +86,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, 'label' : 'Full Model Fit' if add_legend else None} model_kwargs = check_plot_kwargs(model_kwargs, model_defaults) - plot_spectra(model.data.freqs, model.results.model.modeled_spectrum_, + plot_spectra(model.data.freqs, model.results.model.modeled_spectrum, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit diff --git a/specparam/tests/measures/test_error.py b/specparam/tests/measures/test_error.py index 6864e1bd..34bed15a 100644 --- a/specparam/tests/measures/test_error.py +++ b/specparam/tests/measures/test_error.py @@ -7,29 +7,29 @@ def test_compute_mean_abs_error(tfm): - error = compute_mean_abs_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum_) + error = compute_mean_abs_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_mean_squared_error(tfm): error = compute_mean_squared_error(tfm.data.power_spectrum, - tfm.results.model.modeled_spectrum_) + tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_root_mean_squared_error(tfm): error = compute_root_mean_squared_error(tfm.data.power_spectrum, - tfm.results.model.modeled_spectrum_) + tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_median_abs_error(tfm): error = compute_median_abs_error(tfm.data.power_spectrum, - tfm.results.model.modeled_spectrum_) + tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_error(tfm): for metric in ['mae', 'mse', 'rmse', 'medae']: - error = compute_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum_) + error = compute_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(error, float) diff --git a/specparam/tests/measures/test_gof.py b/specparam/tests/measures/test_gof.py index 678c3ce0..6b54a4ee 100644 --- a/specparam/tests/measures/test_gof.py +++ b/specparam/tests/measures/test_gof.py @@ -7,17 +7,17 @@ def test_compute_r_squared(tfm): - r_squared = compute_r_squared(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum_) + r_squared = compute_r_squared(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(r_squared, float) def test_compute_adj_r_squared(tfm): r_squared = compute_adj_r_squared(tfm.data.power_spectrum, - tfm.results.model.modeled_spectrum_, 5) + tfm.results.model.modeled_spectrum, 5) assert isinstance(r_squared, float) def test_compute_gof(tfm): for metric in ['r_squared', 'adj_r_squared']: - gof = compute_gof(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum_) + gof = compute_gof(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(gof, float) diff --git a/specparam/tests/models/test_event.py b/specparam/tests/models/test_event.py index d65f087c..83ba2b39 100644 --- a/specparam/tests/models/test_event.py +++ b/specparam/tests/models/test_event.py @@ -138,7 +138,7 @@ def test_event_get_model(tfe): assert tfm1 assert tfm1.data.has_data assert tfm1.results.has_model - assert np.all(tfm1.results.model.modeled_spectrum_) + assert np.all(tfm1.results.model.modeled_spectrum) def test_event_get_params(tfe): diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 6d2b5ec9..e0bb7523 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -115,7 +115,7 @@ def test_fit_default_metrics(): # Hack fake data with known properties: total error magnitude 2 tfm.data.power_spectrum = np.array([1, 2, 3, 4, 5]) - tfm.results.model.modeled_spectrum_ = np.array([1, 2, 5, 4, 5]) + tfm.results.model.modeled_spectrum = np.array([1, 2, 5, 4, 5]) # Check default goodness of fit and error measures tfm.results.metrics.compute_metrics(tfm.data, tfm.results) @@ -330,7 +330,7 @@ def test_resets(): assert value is None for field in tfm.results._fields: assert np.all(np.isnan(getattr(tfm.results, field))) - assert tfm.data.freqs is None and tfm.results.model.modeled_spectrum_ is None + assert tfm.data.freqs is None and tfm.results.model.modeled_spectrum is None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" From 9bb23311d4cf9d8f8fbbc5947548eecec3726dae Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 12 Apr 2025 16:50:01 -0400 Subject: [PATCH 04/40] move ModelComponents --- specparam/objs/components.py | 89 +++++++++++++++++++++++++ specparam/objs/data.py | 2 +- specparam/objs/results.py | 88 +----------------------- specparam/tests/objs/test_components.py | 13 ++++ specparam/tests/objs/test_data.py | 2 +- specparam/tests/objs/test_results.py | 9 +-- 6 files changed, 107 insertions(+), 96 deletions(-) create mode 100644 specparam/objs/components.py create mode 100644 specparam/tests/objs/test_components.py diff --git a/specparam/objs/components.py b/specparam/objs/components.py new file mode 100644 index 00000000..dca5cc2e --- /dev/null +++ b/specparam/objs/components.py @@ -0,0 +1,89 @@ +"""Define model components object.""" + +from specparam.utils.array import unlog + +################################################################################################### +################################################################################################### + +class ModelComponents(): + """Object for managing model components. + + Attributes + ---------- + modeled_spectrum : 1d array + Modeled spectrum. + _spectrum_flat : 1d array + Data attribute: flattened power spectrum, with the aperiodic component removed. + _spectrum_peak_rm : 1d array + Data attribute: power spectrum, with peaks removed. + _ap_fit : 1d array + Model attribute: values of the isolated aperiodic fit. + _peak_fit : 1d array + Model attribute: values of the isolated peak fit. + """ + + def __init__(self): + """Initialize ModelComponents object.""" + + self.reset() + + + def reset(self): + """Reset model components attributes.""" + + # Full model + self.modeled_spectrum = None + + # Model components + self._ap_fit = None + self._peak_fit = None + + # Data components + self._spectrum_flat = None + self._spectrum_peak_rm = None + + + def get_component(self, component='full', space='log'): + """Get a model component. + + Parameters + ---------- + component : {'full', 'aperiodic', 'peak'} + Which model component to return. + 'full' - full model + 'aperiodic' - isolated aperiodic model component + 'peak' - isolated peak model component + space : {'log', 'linear'} + Which space to return the model component in. + 'log' - returns in log10 space. + 'linear' - returns in linear space. + + Returns + ------- + output : 1d array + Specified model component, in specified spacing. + + Notes + ----- + The 'space' parameter doesn't just define the spacing of the model component + values, but rather defines the space of the additive model such that + `model = aperiodic_component + peak_component`. + With space set as 'log', this combination holds in log space. + With space set as 'linear', this combination holds in linear space. + """ + + if self.modeled_spectrum is None: + raise NoModelError("No model fit results are available, can not proceed.") + assert space in ['linear', 'log'], "Input for 'space' invalid." + + if component == 'full': + output = self.modeled_spectrum if space == 'log' else unlog(self.modeled_spectrum) + elif component == 'aperiodic': + output = self._ap_fit if space == 'log' else unlog(self._ap_fit) + elif component == 'peak': + output = self._peak_fit if space == 'log' else \ + unlog(self.modeled_spectrum) - unlog(self._ap_fit) + else: + raise ValueError('Input for component invalid.') + + return output diff --git a/specparam/objs/data.py b/specparam/objs/data.py index e7d593dc..094fe653 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -1,4 +1,4 @@ -"""Define base data objects.""" +"""Define data objects.""" from functools import wraps diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 32d095d5..1528946d 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -1,4 +1,4 @@ -"""Define base results objects.""" +"""Define results objects.""" from copy import deepcopy from itertools import repeat @@ -7,8 +7,8 @@ from specparam.bands.bands import check_bands from specparam.objs.metrics import Metrics +from specparam.objs.components import ModelComponents from specparam.measures.metrics import METRICS -from specparam.utils.array import unlog from specparam.utils.checks import check_inds, check_array_dim from specparam.modutils.errors import NoModelError from specparam.modutils.docs import docs_get_section, replace_docstring_sections @@ -26,90 +26,6 @@ DEFAULT_METRICS = ['error_mae', 'gof_rsquared'] -class ModelComponents(): - """Object for managing model components. - - Attributes - ---------- - modeled_spectrum : 1d array - Modeled spectrum. - _spectrum_flat : 1d array - Data attribute: flattened power spectrum, with the aperiodic component removed. - _spectrum_peak_rm : 1d array - Data attribute: power spectrum, with peaks removed. - _ap_fit : 1d array - Model attribute: values of the isolated aperiodic fit. - _peak_fit : 1d array - Model attribute: values of the isolated peak fit. - """ - - def __init__(self): - """Initialize ModelComponents object.""" - - self.reset() - - - def reset(self): - """Reset model components attributes.""" - - # Full model - self.modeled_spectrum = None - - # Model components - self._ap_fit = None - self._peak_fit = None - - # Data components - self._spectrum_flat = None - self._spectrum_peak_rm = None - - - def get_component(self, component='full', space='log'): - """Get a model component. - - Parameters - ---------- - component : {'full', 'aperiodic', 'peak'} - Which model component to return. - 'full' - full model - 'aperiodic' - isolated aperiodic model component - 'peak' - isolated peak model component - space : {'log', 'linear'} - Which space to return the model component in. - 'log' - returns in log10 space. - 'linear' - returns in linear space. - - Returns - ------- - output : 1d array - Specified model component, in specified spacing. - - Notes - ----- - The 'space' parameter doesn't just define the spacing of the model component - values, but rather defines the space of the additive model such that - `model = aperiodic_component + peak_component`. - With space set as 'log', this combination holds in log space. - With space set as 'linear', this combination holds in linear space. - """ - - if self.modeled_spectrum is None: - raise NoModelError("No model fit results are available, can not proceed.") - assert space in ['linear', 'log'], "Input for 'space' invalid." - - if component == 'full': - output = self.modeled_spectrum if space == 'log' else unlog(self.modeled_spectrum) - elif component == 'aperiodic': - output = self._ap_fit if space == 'log' else unlog(self._ap_fit) - elif component == 'peak': - output = self._peak_fit if space == 'log' else \ - unlog(self.modeled_spectrum) - unlog(self._ap_fit) - else: - raise ValueError('Input for component invalid.') - - return output - - class Results(): """Object for managing results - base / 1D version. diff --git a/specparam/tests/objs/test_components.py b/specparam/tests/objs/test_components.py new file mode 100644 index 00000000..b9fa3d32 --- /dev/null +++ b/specparam/tests/objs/test_components.py @@ -0,0 +1,13 @@ +"""Tests for specparam.objs.components.""" + +from specparam.objs.components import * + +################################################################################################### +################################################################################################### + +## ModelComponents object + +def test_model_components(): + + mc = ModelComponents() + assert mc diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 87aeab46..42982fb8 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -1,4 +1,4 @@ -"""Tests for specparam.objs.data, including the data object and it's methods.""" +"""Tests for specparam.objs.data.""" from specparam.data import SpectrumMetaData, ModelChecks diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py index 13230f41..e5ecb03e 100644 --- a/specparam/tests/objs/test_results.py +++ b/specparam/tests/objs/test_results.py @@ -1,17 +1,10 @@ -"""Tests for specparam.objs.results, including the data object and it's methods.""" +"""Tests for specparam.objs.results.""" from specparam.objs.results import * ################################################################################################### ################################################################################################### -## ModelComponents object - -def test_model_components(): - - mc = ModelComponents() - assert mc - ## 1D results object def test_results(): From 3b2c2c0631423051c2ac46f7d83c80ac6318c448 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 12 Apr 2025 16:56:59 -0400 Subject: [PATCH 05/40] add ModelParameters object --- specparam/objs/params.py | 51 +++++++++++++++++++++++++++++ specparam/tests/objs/test_params.py | 13 ++++++++ 2 files changed, 64 insertions(+) create mode 100644 specparam/objs/params.py create mode 100644 specparam/tests/objs/test_params.py diff --git a/specparam/objs/params.py b/specparam/objs/params.py new file mode 100644 index 00000000..aeb4ddbf --- /dev/null +++ b/specparam/objs/params.py @@ -0,0 +1,51 @@ +"""Define model parameters object.""" + +import numpy as np + +################################################################################################### +################################################################################################### + +class ModelParameters(): + """Object to manage model fit parameters. + + Parameters + ---------- + modes : Modes + Fit modes defintion. + If provided, used to initialize parameter arrays to correct sizes. + + Attributes + ---------- + aperiodic : 1d array + Aperiodic parameters of the model fit. + peak : 1d array + Peak parameters of the model fit. + gaussian : 1d array + Gaussian parameters of the model fit. + """ + + def __init__(self, modes=None): + """Initialize ModelParameters object.""" + + self.aperiodic = np.nan + self.peak = np.nan + self.gaussian = np.nan + + self.reset(modes) + + def reset(self, modes=None): + """Reset parameters.""" + + # Aperiodic parameters + if modes: + self.aperiodic = np.array([np.nan] * modes.aperiodic.n_params) + else: + self.aperiodic = np.nan + + # Periodic parameters + if modes: + self.gaussian = np.empty([0, modes.periodic.n_params]) + self.peak = np.empty([0, modes.periodic.n_params]) + else: + self.gaussian = np.nan + self.peak = np.nan diff --git a/specparam/tests/objs/test_params.py b/specparam/tests/objs/test_params.py new file mode 100644 index 00000000..5a19426e --- /dev/null +++ b/specparam/tests/objs/test_params.py @@ -0,0 +1,13 @@ +"""Tests for specparam.objs.params.""" + +from specparam.objs.params import * + +################################################################################################### +################################################################################################### + +## ModelParameters object + +def test_model_parameters(): + + mp = ModelParameters() + assert mp From 114d20ae522b7578b43531c65e8773953103c0ee Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 00:36:19 -0400 Subject: [PATCH 06/40] update to use results.params --- specparam/algorithms/spectral_fit.py | 14 ++++---- specparam/data/periodic.py | 14 ++++---- specparam/data/utils.py | 2 +- specparam/io/files.py | 2 +- specparam/io/models.py | 26 +++++++------- specparam/measures/metrics.py | 2 +- specparam/models/base.py | 5 +++ specparam/models/group.py | 2 +- specparam/models/model.py | 2 +- specparam/objs/results.py | 53 +++++++++++++++------------- specparam/plts/annotate.py | 8 ++--- specparam/plts/model.py | 10 +++--- specparam/reports/strings.py | 4 +-- specparam/tests/io/test_models.py | 4 +-- specparam/tests/models/test_group.py | 6 ++-- specparam/tests/models/test_model.py | 22 ++++++------ specparam/tests/objs/test_metrics.py | 4 +-- specparam/tests/objs/test_results.py | 2 +- 18 files changed, 97 insertions(+), 85 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index cc86dae5..dd66daf0 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -136,26 +136,26 @@ def _fit(self): ## FIT PROCEDURES # Take an initial fit of the aperiodic component - temp_aperiodic_params_ = self._robust_ap_fit(self.data.freqs, self.data.power_spectrum) - temp_ap_fit = self.modes.aperiodic.func(self.data.freqs, *temp_aperiodic_params_) + temp_aperiodic_params = self._robust_ap_fit(self.data.freqs, self.data.power_spectrum) + temp_ap_fit = self.modes.aperiodic.func(self.data.freqs, *temp_aperiodic_params) # Find peaks from the flattened power spectrum, and fit them with gaussians temp_spectrum_flat = self.data.power_spectrum - temp_ap_fit - self.results.gaussian_params_ = self._fit_peaks(temp_spectrum_flat) + self.results.params.gaussian = self._fit_peaks(temp_spectrum_flat) # Calculate the peak fit # Note: if no peaks are found, this creates a flat (all zero) peak fit self.results.model._peak_fit = self.modes.periodic.func(\ - self.data.freqs, *np.ndarray.flatten(self.results.gaussian_params_)) + self.data.freqs, *np.ndarray.flatten(self.results.params.gaussian)) # Create peak-removed (but not flattened) power spectrum self.results.model._spectrum_peak_rm = self.data.power_spectrum - self.results.model._peak_fit # Run final aperiodic fit on peak-removed power spectrum - self.results.aperiodic_params_ = self._simple_ap_fit(\ + self.results.params.aperiodic = self._simple_ap_fit(\ self.data.freqs, self.results.model._spectrum_peak_rm) self.results.model._ap_fit = self.modes.aperiodic.func(\ - self.data.freqs, *self.results.aperiodic_params_) + self.data.freqs, *self.results.params.aperiodic) # Create remaining model components: flatspec & full power_spectrum model fit self.results.model._spectrum_flat = self.data.power_spectrum - self.results.model._ap_fit @@ -164,7 +164,7 @@ def _fit(self): ## PARAMETER UPDATES # Convert gaussian definitions to peak parameters - self.results.peak_params_ = self._create_peak_params(self.results.gaussian_params_) + self.results.params.peak = self._create_peak_params(self.results.params.gaussian) def _reset_internal_settings(self): diff --git a/specparam/data/periodic.py b/specparam/data/periodic.py index 4be5ae22..753e7386 100644 --- a/specparam/data/periodic.py +++ b/specparam/data/periodic.py @@ -6,7 +6,7 @@ ################################################################################################### def get_band_peak(model, band, select_highest=True, threshold=None, - thresh_param='PW', attribute='peak_params'): + thresh_param='PW', attribute='peak'): """Extract peaks from a band of interest from a model object. Parameters @@ -23,7 +23,7 @@ def get_band_peak(model, band, select_highest=True, threshold=None, A minimum threshold value to apply. thresh_param : {'PW', 'BW'} Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. - attribute : {'peak_params', 'gaussian_params'} + attribute : {'peak', 'gaussian'} Which attribute of peak data to extract data from. Returns @@ -42,11 +42,11 @@ def get_band_peak(model, band, select_highest=True, threshold=None, >>> betas = get_band_peak(model, [13, 30], select_highest=False) # doctest:+SKIP """ - return get_band_peak_arr(getattr(model.results, attribute + '_'), band, + return get_band_peak_arr(getattr(model.results.params, attribute), band, select_highest, threshold, thresh_param) -def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribute='peak_params'): +def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribute='peak'): """Extract peaks from a band of interest from a group model object. Parameters @@ -60,7 +60,7 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut A minimum threshold value to apply. thresh_param : {'PW', 'BW'} Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. - attribute : {'peak_params', 'gaussian_params'} + attribute : {'peak', 'gaussian'} Which attribute of peak data to extract data from. Returns @@ -99,7 +99,7 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut threshold, thresh_param) -def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak_params'): +def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak'): """Extract peaks from a band of interest from an event model object. Parameters @@ -116,7 +116,7 @@ def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribut A minimum threshold value to apply. thresh_param : {'PW', 'BW'} Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. - attribute : {'peak_params', 'gaussian_params'} + attribute : {'peak', 'gaussian'} Which attribute of peak data to extract data from. Returns diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 30efa160..3c872926 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -77,7 +77,7 @@ def get_group_params(group_results, modes, name, field=None): List of FitResults objects, reflecting model results across a group of power spectra. modes : Modes Model modes definition. - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'metrics'} Name of the data field to extract across the group. field : str or int, optional Column name / index to extract from selected data, if requested. diff --git a/specparam/io/files.py b/specparam/io/files.py index dde0166b..6b0b6f88 100644 --- a/specparam/io/files.py +++ b/specparam/io/files.py @@ -69,7 +69,7 @@ def load_json(file_name, file_path): # Get dictionary of available attributes, and convert specified lists back into arrays arrays_to_convert = ['freqs', 'power_spectrum', - 'aperiodic_params_', 'peak_params_', 'gaussian_params_'] + 'aperiodic_params', 'peak_params', 'gaussian_params'] data = dict_lst_to_array(data, arrays_to_convert) return data diff --git a/specparam/io/models.py b/specparam/io/models.py index f0ccb75d..f14dd826 100644 --- a/specparam/io/models.py +++ b/specparam/io/models.py @@ -48,17 +48,12 @@ def save_model(model, file_name, file_path=None, append=False, If the save file is not understood. """ - # Convert object to dictionary & convert all arrays to lists, for JSON serializing - # This 'flattens' the object, getting all relevant attributes in the same dictionary - obj_dict = dict_array_to_lst(model.__dict__) - data_dict = dict_array_to_lst(model.data.__dict__) - results_dict = dict_array_to_lst(model.results.__dict__) - algo_dict = dict_array_to_lst(model.algorithm.__dict__) - obj_dict = {**obj_dict, **data_dict, **results_dict, **algo_dict} + # 'Flatten' the model object by extracting relevant attributes to a dictionary + obj_dict = {**model.data.__dict__, **model.algorithm.__dict__} # Convert modes object to their saveable string name - obj_dict['aperiodic_mode'] = obj_dict['modes'].aperiodic.name - obj_dict['periodic_mode'] = obj_dict['modes'].periodic.name + obj_dict['aperiodic_mode'] = model.modes.aperiodic.name + obj_dict['periodic_mode'] = model.modes.periodic.name mode_labels = ['aperiodic_mode', 'periodic_mode'] # Add bands information to saveable information @@ -66,8 +61,15 @@ def save_model(model, file_name, file_path=None, append=False, if not model.results.bands._n_bands else model.results.bands._n_bands bands_label = ['bands'] if model.results.bands else [] - # Convert metrics results to saveable information - obj_dict['metrics'] = obj_dict['metrics'].results + # Convert results & metrics to saveable information + results_labels = [] + for rfield in model.results._fields: + results_labels.append(rfield + '_params') + obj_dict[rfield + '_params'] = getattr(model.results.params, rfield) + obj_dict['metrics'] = model.results.metrics.results + + # Convert all arrays to list for JSON serialization + obj_dict = dict_array_to_lst(obj_dict) # Check for saving out base information / check if base only if save_base is None: @@ -79,7 +81,7 @@ def save_model(model, file_name, file_path=None, append=False, keep = set(\ (mode_labels + bands_label if save_base else []) + \ (model.data._meta_fields if save_base or base_only else []) + \ - (model.results._fields + ['metrics'] if save_results else []) + \ + (results_labels + ['metrics'] if save_results else []) + \ (model.algorithm.settings.names if save_settings else []) + \ (model.data._fields if save_data else [])) diff --git a/specparam/measures/metrics.py b/specparam/measures/metrics.py index 9d1f5106..1352be7c 100644 --- a/specparam/measures/metrics.py +++ b/specparam/measures/metrics.py @@ -20,5 +20,5 @@ 'gof_rsquared' : Metric('gof', 'rsquared', compute_r_squared), 'gof_adjrsquared' : Metric('gof', 'adjrsquared', compute_adj_r_squared, \ {'n_params' : lambda data, results: \ - results.peak_params_.size + results.aperiodic_params_.size}) + results.params.peak.size + results.params.aperiodic.size}) } diff --git a/specparam/models/base.py b/specparam/models/base.py index 9af02e6c..6657d9d0 100644 --- a/specparam/models/base.py +++ b/specparam/models/base.py @@ -3,6 +3,7 @@ from copy import deepcopy from specparam.utils.array import unlog +from specparam.utils.checks import check_array_dim from specparam.modes.modes import Modes from specparam.modutils.errors import NoDataError from specparam.reports.strings import gen_modes_str, gen_settings_str, gen_issue_str @@ -172,6 +173,10 @@ def _add_from_dict(self, data): tmetrics = data.pop('metrics') self.results.add_metrics(list(tmetrics.keys())) self.results.metrics.add_results(tmetrics) + for label, params in {key : vals for key, vals in data.items() if 'params' in key}.items(): + if 'peak' in label or 'gaussian' in label: + params = check_array_dim(params) + setattr(self.results.params, label.split('_')[0], params) # Add additional attributes directly to object for key in data.keys(): diff --git a/specparam/models/group.py b/specparam/models/group.py index cecaa95d..d553cc29 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -232,7 +232,7 @@ def load(self, file_name, file_path=None): self.algorithm._check_loaded_settings(data) # If results part of current data added, check and update object results - if set(self.results._fields).issubset(set(data.keys())): + if set([el + '_params' for el in self.results._fields]).issubset(set(data.keys())): self.results._check_loaded_results(data) self.results.group_results.append(self.results._get_results()) diff --git a/specparam/models/model.py b/specparam/models/model.py index 85813839..fe08928e 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -283,7 +283,7 @@ def load(self, file_name, file_path=None, regenerate=True): if regenerate: if self.data.freq_res: self.data._regenerate_freqs() - if np.all(self.data.freqs) and np.all(self.results.aperiodic_params_): + if np.all(self.data.freqs) and np.all(self.results.params.aperiodic): self.results._regenerate_model(self.data.freqs) diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 1528946d..82c5a20a 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -7,6 +7,7 @@ from specparam.bands.bands import check_bands from specparam.objs.metrics import Metrics +from specparam.objs.params import ModelParameters from specparam.objs.components import ModelComponents from specparam.measures.metrics import METRICS from specparam.utils.checks import check_inds, check_array_dim @@ -22,7 +23,8 @@ ################################################################################################### # Define set of results fields & default metrics to use -RESULTS_FIELDS = ['aperiodic_params_', 'gaussian_params_', 'peak_params_'] +#RESULTS_FIELDS = ['aperiodic_params_', 'gaussian_params_', 'peak_params_'] +RESULTS_FIELDS = ['aperiodic', 'gaussian', 'peak'] DEFAULT_METRICS = ['error_mae', 'gof_rsquared'] @@ -49,6 +51,7 @@ def __init__(self, modes=None, metrics=None, bands=None): self.add_metrics(metrics) self.model = ModelComponents() + self.params = ModelParameters() # Initialize results attributes self._reset_results(True) @@ -67,7 +70,7 @@ def has_model(self): - necessarily defined, as floats, if model has been fit """ - return not np.all(np.isnan(self.aperiodic_params_)) + return not np.all(np.isnan(self.params.aperiodic)) @property @@ -76,7 +79,7 @@ def n_peaks_(self): n_peaks = None if self.has_model: - n_peaks = self.peak_params_.shape[0] + n_peaks = self.params.peak.shape[0] return n_peaks @@ -137,7 +140,7 @@ def add_results(self, results): # Add parameter fields and then select and add metrics results for pfield in self._fields: - setattr(self, pfield, getattr(results, pfield.strip('_'))) + setattr(self.params, pfield, getattr(results, pfield + '_params')) self.metrics.add_results(results.metrics) @@ -154,7 +157,7 @@ def get_results(self): """ results = FitResults( - **{key.strip('_') : getattr(self, key) for key in self._fields}, + **{key + '_params' : getattr(self.params, key) for key in self._fields}, metrics=self.metrics.results) return results @@ -192,6 +195,7 @@ def get_params(self, name, field=None): return get_model_params(self.get_results(), self.modes, name, field) + # TODO: check / move to ModelParameters? def _check_loaded_results(self, data): """Check if results have been added and check data. @@ -204,8 +208,8 @@ def _check_loaded_results(self, data): # If results loaded, check dimensions of peak parameters # This fixes an issue where they end up the wrong shape if they are empty (no peaks) if set(self._fields).issubset(set(data.keys())): - self.peak_params_ = check_array_dim(self.peak_params_) - self.gaussian_params_ = check_array_dim(self.gaussian_params_) + self.params.peak = check_array_dim(self.params.peak) + self.params.gaussian = check_array_dim(self.params.gaussian) def _reset_results(self, clear_results=False): @@ -219,21 +223,22 @@ def _reset_results(self, clear_results=False): if clear_results: - # Aperiodic parameters - if self.modes: - self.aperiodic_params_ = np.array([np.nan] * self.modes.aperiodic.n_params) - else: - self.aperiodic_params_ = np.nan - - # Periodic parameters - if self.modes: - self.gaussian_params_ = np.empty([0, self.modes.periodic.n_params]) - self.peak_params_ = np.empty([0, self.modes.periodic.n_params]) - else: - self.gaussian_params_ = np.nan - self.peak_params_ = np.nan - - # Reset model components + # # Aperiodic parameters + # if self.modes: + # self.aperiodic_params_ = np.array([np.nan] * self.modes.aperiodic.n_params) + # else: + # self.aperiodic_params_ = np.nan + + # # Periodic parameters + # if self.modes: + # self.gaussian_params_ = np.empty([0, self.modes.periodic.n_params]) + # self.peak_params_ = np.empty([0, self.modes.periodic.n_params]) + # else: + # self.gaussian_params_ = np.nan + # self.peak_params_ = np.nan + + # Reset model parameters & components + self.params.reset(self.modes) self.model.reset() @@ -247,8 +252,8 @@ def _regenerate_model(self, freqs): """ self.model.modeled_spectrum, self.model._peak_fit, self.model._ap_fit = \ - gen_model(freqs, self.modes.aperiodic, self.aperiodic_params_, - self.modes.periodic, self.gaussian_params_, return_components=True) + gen_model(freqs, self.modes.aperiodic, self.params.aperiodic, + self.modes.periodic, self.params.gaussian, return_components=True) @replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) diff --git a/specparam/plts/annotate.py b/specparam/plts/annotate.py index bbc889de..e2775fb7 100644 --- a/specparam/plts/annotate.py +++ b/specparam/plts/annotate.py @@ -33,14 +33,14 @@ def plot_annotated_peak_search(model): # is the same as the one that is used in the peak fitting procedure flatspec = model.data.power_spectrum - \ model.modes.aperiodic.func(model.data.freqs, \ - *model.algorithm._robust_ap_fit(model.data.freqs, model.data.power_spectrum),) + *model.algorithm._robust_ap_fit(model.data.freqs, model.data.power_spectrum)) # Calculate ylims of the plot that are scaled to the range of the data ylims = [min(flatspec) - 0.1 * np.abs(min(flatspec)), max(flatspec) + 0.1 * max(flatspec)] # Sort parameters by peak height - gaussian_params = model.results.gaussian_params_[\ - model.results.gaussian_params_[:, 1].argsort()][::-1] + gaussian_params = model.results.params.gaussian[\ + model.results.params.gaussian[:, 1].argsort()][::-1] # Loop through the iterative search for each peak for ind in range(model.results.n_peaks_ + 1): @@ -139,7 +139,7 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, if annotate_peaks and model.results.n_peaks_: # Extract largest peak, to annotate, grabbing gaussian params - gauss = get_band_peak(model, model.data.freq_range, attribute='gaussian_params') + gauss = get_band_peak(model, model.data.freq_range, attribute='gaussian') peak_ctr, peak_hgt, peak_wid = gauss bw_freqs = [peak_ctr - 0.5 * compute_fwhm(peak_wid), diff --git a/specparam/plts/model.py b/specparam/plts/model.py index e19e03ae..2499b1c1 100644 --- a/specparam/plts/model.py +++ b/specparam/plts/model.py @@ -168,7 +168,7 @@ def _add_peaks_shade(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.gaussian_params_: + for peak in model.results.params.gaussian: peak_freqs = np.log10(model.data.freqs) if plt_log else model.data.freqs peak_line = model.results.model._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) @@ -194,7 +194,7 @@ def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.peak_params_: + for peak in model.results.params.peak: ap_point = np.interp(peak[0], model.data.freqs, model.results.model._ap_fit) freq_point = np.log10(peak[0]) if plt_log else peak[0] @@ -224,7 +224,7 @@ def _add_peaks_outline(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.5} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.gaussian_params_: + for peak in model.results.params.gaussian: # Define the frequency range around each peak to plot - peak bandwidth +/- 3 peak_range = [peak[0] - peak[2]*3, peak[0] + peak[2]*3] @@ -258,7 +258,7 @@ def _add_peaks_line(model, plt_log, ax, **plot_kwargs): ylims = ax.get_ylim() - for peak in model.results.peak_params_: + for peak in model.results.params.peak: freq_point = np.log10(peak[0]) if plt_log else peak[0] ax.plot([freq_point, freq_point], ylims, '-', **plot_kwargs) @@ -288,7 +288,7 @@ def _add_peaks_width(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.gaussian_params_: + for peak in model.results.params.gaussian: peak_top = model.data.power_spectrum[nearest_ind(model.data.freqs, peak[0])] bw_freqs = [peak[0] - 0.5 * compute_fwhm(peak[2]), diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index cea9261b..8f0a0e0c 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -388,13 +388,13 @@ def gen_model_results_str(model, concise=False): 'Aperiodic Parameters (\'{}\' mode)'.format(model.modes.aperiodic.name), '(' + ', '.join(model.modes.aperiodic.params.labels) + ')', ', '.join(['{:2.4f}'] * \ - len(model.results.aperiodic_params_)).format(*model.results.aperiodic_params_), + len(model.results.params.aperiodic)).format(*model.results.params.aperiodic), '', # Peak parameters 'Peak Parameters (\'{}\' mode) {} peaks found'.format(\ model.modes.periodic.name, model.results.n_peaks_), - *[peak_str.format(*op) for op in model.results.peak_params_], + *[peak_str.format(*op) for op in model.results.params.peak], '', # Metrics diff --git a/specparam/tests/io/test_models.py b/specparam/tests/io/test_models.py index a22c7594..a0c05799 100644 --- a/specparam/tests/io/test_models.py +++ b/specparam/tests/io/test_models.py @@ -161,7 +161,7 @@ def test_load_file_contents(tfm): for setting in tfm.algorithm.settings.names: assert setting in loaded_data.keys() for result in tfm.results._fields: - assert result in loaded_data.keys() + assert result + '_params' in loaded_data.keys() assert 'metrics' in loaded_data.keys() for datum in tfm.data._fields: assert datum in loaded_data.keys() @@ -181,7 +181,7 @@ def test_load_model(tfm): for setting in ntfm.algorithm.settings.names: assert getattr(ntfm.algorithm, setting) is not None for result in tfm.results._fields: - assert not np.all(np.isnan(getattr(ntfm.results, result))) + assert not np.all(np.isnan(getattr(ntfm.results.params, result))) assert tfm.results.metrics.results == ntfm.results.metrics.results for data in tfm.data._fields: assert getattr(ntfm.data, data) is not None diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index 00634f98..3f2d2e59 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -293,7 +293,7 @@ def test_load(tfg): assert getattr(ntfg.algorithm, setting) is not None # Test that results and data are None for result in tfg.results._fields: - assert np.all(np.isnan(getattr(ntfg.results, result))) + assert np.all(np.isnan(getattr(ntfg.results.params, result))) assert ntfg.data.power_spectra is None # Test loading just data @@ -304,7 +304,7 @@ def test_load(tfg): for setting in tfg.algorithm.settings.names: assert getattr(ntfg.algorithm, setting) is None for result in tfg.results._fields: - assert np.all(np.isnan(getattr(ntfg.results, result))) + assert np.all(np.isnan(getattr(ntfg.results.params, result))) # Test loading all elements ntfg = SpectralGroupModel(verbose=False) @@ -352,7 +352,7 @@ def test_get_model(tfg): assert tfm1 # Check that regenerated model is created for result in tfg.results._fields: - assert np.all(getattr(tfm1.results, result)) + assert np.all(getattr(tfm1.results.params, result)) # Test when object has no data (clear a copy of tfg) new_tfg = tfg.copy() diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index e0bb7523..aa89874e 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -71,11 +71,11 @@ def test_fit_nk(): tfm.fit(xs, ys) # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.results.aperiodic_params_, [0.5, 0.1]) + assert np.allclose(ap_params, tfm.results.params.aperiodic, [0.5, 0.1]) # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.results.gaussian_params_[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, tfm.results.params.gaussian[ii], [2.0, 0.5, 1.0]) def test_fit_nk_noise(): """Test fit on noisy data, to make sure nothing breaks.""" @@ -102,11 +102,11 @@ def test_fit_knee(): tfm.fit(xs, ys) # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.results.aperiodic_params_, [1, 2, 0.2]) + assert np.allclose(ap_params, tfm.results.params.aperiodic, [1, 2, 0.2]) # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.results.gaussian_params_[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, tfm.results.params.gaussian[ii], [2.0, 0.5, 1.0]) def test_fit_default_metrics(): """Test goodness of fit & error metrics, post model fitting.""" @@ -198,7 +198,7 @@ def test_load(tfm): ntfm.load(file_name_res, TEST_DATA_PATH) # Check that result attributes get filled for result in tfm.results._fields: - assert not np.all(np.isnan(getattr(ntfm.results, result))) + assert not np.all(np.isnan(getattr(ntfm.results.params, result))) # Test that settings and data are None for setting in tfm.algorithm.settings.names: assert getattr(ntfm.algorithm, setting) is None @@ -212,7 +212,7 @@ def test_load(tfm): assert getattr(ntfm.algorithm, setting) is not None # Test that results and data are None for result in tfm.results._fields: - assert np.all(np.isnan(getattr(ntfm.results, result))) + assert np.all(np.isnan(getattr(ntfm.results.params, result))) assert ntfm.data.power_spectrum is None # Test loading just data @@ -224,14 +224,14 @@ def test_load(tfm): for setting in tfm.algorithm.settings.names: assert getattr(ntfm.algorithm, setting) is None for result in tfm.results._fields: - assert np.all(np.isnan(getattr(ntfm.results, result))) + assert np.all(np.isnan(getattr(ntfm.results.params, result))) # Test loading all elements ntfm = SpectralModel(verbose=False) file_name_all = 'test_model_all' ntfm.load(file_name_all, TEST_DATA_PATH) for result in tfm.results._fields: - assert not np.all(np.isnan(getattr(ntfm.results, result))) + assert not np.all(np.isnan(getattr(ntfm.results.params, result))) for setting in tfm.algorithm.settings.names: assert getattr(ntfm.algorithm, setting) is not None for data in tfm.data._fields: @@ -329,7 +329,7 @@ def test_resets(): for key, value in tfm.results.model.__dict__.items(): assert value is None for field in tfm.results._fields: - assert np.all(np.isnan(getattr(tfm.results, field))) + assert np.all(np.isnan(getattr(tfm.results.params, field))) assert tfm.data.freqs is None and tfm.results.model.modeled_spectrum is None def test_report(skip_if_no_mpl): @@ -351,7 +351,7 @@ def test_fit_failure(): # Check after failing out of fit, all results are reset for result in tfm.results._fields: - assert np.all(np.isnan(getattr(tfm.results, result))) + assert np.all(np.isnan(getattr(tfm.results.params, result))) ## Monkey patch to check errors in general # This mimics the main fit-failure, without requiring bad data / waiting for it to fail. @@ -365,7 +365,7 @@ def raise_runtime_error(*args, **kwargs): # Check after failing out of fit, all results are reset for result in tfm.results._fields: - assert np.all(np.isnan(getattr(tfm.results, result))) + assert np.all(np.isnan(getattr(tfm.results.params, result))) def test_debug(): """Test model object in debug state, including with fit failures.""" diff --git a/specparam/tests/objs/test_metrics.py b/specparam/tests/objs/test_metrics.py index 0885453a..07028b9f 100644 --- a/specparam/tests/objs/test_metrics.py +++ b/specparam/tests/objs/test_metrics.py @@ -23,7 +23,7 @@ def test_metric_kwargs(tfm): metric = Metric('gof', 'ar2', compute_adj_r_squared, {'n_params' : lambda data, results: \ - results.peak_params_.size + results.aperiodic_params_.size}) + results.params.peak.size + results.params.aperiodic.size}) assert isinstance(metric, Metric) assert isinstance(metric.label, str) @@ -77,7 +77,7 @@ def test_metrics_kwargs(tfm): ar2_met_def = {'type' : 'gof', 'measure' : 'arsquared', 'func' : compute_adj_r_squared, 'kwargs' : {'n_params' : lambda data, results: \ - results.peak_params_.size + results.aperiodic_params_.size}} + results.params.peak.size + results.params.aperiodic.size}} metrics = Metrics([er_met_def, ar2_met_def]) assert isinstance(metrics, Metrics) diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py index e5ecb03e..c5ac847e 100644 --- a/specparam/tests/objs/test_results.py +++ b/specparam/tests/objs/test_results.py @@ -19,7 +19,7 @@ def test_results_results(tresults): tres.add_results(tresults) assert tres.has_model for result in tres._fields: - assert np.array_equal(getattr(tres, result), getattr(tresults, result.strip('_'))) + assert np.array_equal(getattr(tres.params, result), getattr(tresults, result + '_params')) results_out = tres.get_results() assert results_out == tresults From c5b8f96dcfd249ef3ecba00fe08e22ba5e318bb9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 00:41:07 -0400 Subject: [PATCH 07/40] add non-default settings to test objects --- specparam/tests/tdata.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/specparam/tests/tdata.py b/specparam/tests/tdata.py index 6791d2d0..2b585d3a 100644 --- a/specparam/tests/tdata.py +++ b/specparam/tests/tdata.py @@ -53,7 +53,8 @@ def get_tdata2d(): def get_tfm(): """Get a model object, with a fit power spectrum, for testing.""" - tfm = SpectralModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tfm = SpectralModel(bands=Bands({'alpha' : (7, 14)}), + min_peak_height=0.05, peak_width_limits=[1, 8]) tfm.fit(*sim_power_spectrum(*default_spectrum_params())) return tfm @@ -62,6 +63,7 @@ def get_tfm2(): """Get a model object, with a fit power spectrum, for testing - custom metrics & modes.""" tfm2 = SpectralModel(bands=Bands({'alpha' : (7, 14), 'beta' : [15, 30]}), + min_peak_height=0.05, peak_width_limits=[1, 8], metrics=['error_mse', 'gof_adjrsquared'], aperiodic_mode='knee', periodic_mode='gaussian') tfm2.fit(*sim_power_spectrum(*default_spectrum_params())) @@ -72,7 +74,8 @@ def get_tfg(): """Get a group object, with some fit power spectra, for testing.""" n_spectra = 3 - tfg = SpectralGroupModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tfg = SpectralGroupModel(bands=Bands({'alpha' : (7, 14)}), + min_peak_height=0.05, peak_width_limits=[1, 8]) tfg.fit(*sim_group_power_spectra(n_spectra, *default_group_params())) return tfg @@ -82,6 +85,7 @@ def get_tfg2(): n_spectra = 3 tfg2 = SpectralGroupModel(bands=Bands({'alpha' : (7, 14), 'beta' : [15, 30]}), + min_peak_height=0.05, peak_width_limits=[1, 8], metrics=['error_mse', 'gof_adjrsquared'], aperiodic_mode='knee', periodic_mode='gaussian') tfg2.fit(*sim_group_power_spectra(n_spectra, *default_group_params())) @@ -94,7 +98,8 @@ def get_tft(): n_spectra = 3 xs, ys = sim_spectrogram(n_spectra, *default_group_params()) - tft = SpectralTimeModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tft = SpectralTimeModel(bands=Bands({'alpha' : (7, 14)}), \ + min_peak_height=0.05, peak_width_limits=[1, 8],) tft.fit(xs, ys) return tft @@ -106,7 +111,8 @@ def get_tfe(): xs, ys = sim_spectrogram(n_spectra, *default_group_params()) ys = [ys, ys] - tfe = SpectralTimeEventModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tfe = SpectralTimeEventModel(bands=Bands({'alpha' : (7, 14)}), + min_peak_height=0.05, peak_width_limits=[1, 8],) tfe.fit(xs, ys) return tfe From bb3cee9852c6fb60f1b0c8f65a8f006180bce39a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 01:15:47 -0400 Subject: [PATCH 08/40] extend compare_model_objs --- specparam/models/utils.py | 7 ++++++- specparam/tests/models/test_utils.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/specparam/models/utils.py b/specparam/models/utils.py index dfe4bd33..dd7bb64d 100644 --- a/specparam/models/utils.py +++ b/specparam/models/utils.py @@ -72,14 +72,19 @@ def compare_model_objs(model_objs, aspect): outputs.append(compare_model_objs(model_objs, caspect)) return np.all(outputs) - check_input_options(aspect, ['settings', 'meta_data', 'metrics'], 'aspect') + aspects = ['modes', 'settings', 'meta_data', 'bands', 'metrics'] + check_input_options(aspect, aspects, 'aspect') # Check specified aspect of the objects are the same across instances for m_obj_1, m_obj_2 in zip(model_objs[:-1], model_objs[1:]): + if aspect == 'modes': + consistent = m_obj_1.modes.get_modes() == m_obj_2.modes.get_modes() if aspect == 'settings': consistent = m_obj_1.algorithm.get_settings() == m_obj_2.algorithm.get_settings() if aspect == 'meta_data': consistent = m_obj_1.data.get_meta_data() == m_obj_2.data.get_meta_data() + if aspect == 'bands': + consistent = m_obj_1.results.bands == m_obj_2.results.bands if aspect == 'metrics': consistent = m_obj_1.results.metrics.labels == m_obj_2.results.metrics.labels diff --git a/specparam/tests/models/test_utils.py b/specparam/tests/models/test_utils.py index b34e0399..85a28b39 100644 --- a/specparam/tests/models/test_utils.py +++ b/specparam/tests/models/test_utils.py @@ -33,7 +33,12 @@ def test_compare_model_objs(tfm, tfg): f_obj2 = f_obj.copy() - assert compare_model_objs([f_obj, f_obj2], ['settings', 'meta_data', 'metrics']) + assert compare_model_objs([f_obj, f_obj2], + ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + + assert compare_model_objs([f_obj, f_obj2], 'modes') + f_obj2.add_modes('knee', 'cauchy') + assert not compare_model_objs([f_obj, f_obj2], 'modes') assert compare_model_objs([f_obj, f_obj2], 'settings') f_obj2.algorithm.peak_width_limits = [2, 4] @@ -44,6 +49,10 @@ def test_compare_model_objs(tfm, tfg): f_obj2.data.freq_range = [5, 25] assert not compare_model_objs([f_obj, f_obj2], 'meta_data') + assert compare_model_objs([f_obj, f_obj2], 'bands') + f_obj2.results.add_bands({'new' : [1, 4]}) + assert not compare_model_objs([f_obj, f_obj2], 'bands') + assert compare_model_objs([f_obj, f_obj2], 'metrics') f_obj2.results.metrics.add_metric(METRICS['error_rmse']) assert not compare_model_objs([f_obj, f_obj2], 'metrics') From 9804529e5c7cd56120c8345e87c1dd93f65beeef Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 01:44:15 -0400 Subject: [PATCH 09/40] udpate IO tests to be more stringent --- specparam/tests/io/test_models.py | 100 ++++++++++++--------------- specparam/tests/models/test_event.py | 30 ++++---- specparam/tests/models/test_group.py | 22 ++---- specparam/tests/models/test_model.py | 27 +++----- specparam/tests/models/test_time.py | 30 ++++---- 5 files changed, 96 insertions(+), 113 deletions(-) diff --git a/specparam/tests/io/test_models.py b/specparam/tests/io/test_models.py index a0c05799..90e5e303 100644 --- a/specparam/tests/io/test_models.py +++ b/specparam/tests/io/test_models.py @@ -151,10 +151,7 @@ def test_load_file_contents(tfm): """Check that loaded model files contain the contents they should.""" # Loads file saved from `test_save_model_str` - file_name = 'test_model_all' - - loaded_data = load_json(file_name, TEST_DATA_PATH) - + loaded_data = load_json('test_model_all', TEST_DATA_PATH) for mode in tfm.modes.get_modes()._fields: assert mode in loaded_data.keys() assert 'bands' in loaded_data.keys() @@ -169,99 +166,94 @@ def test_load_file_contents(tfm): def test_load_model(tfm): # Loads file saved from `test_save_model_str` - file_name = 'test_model_all' - ntfm = load_model(file_name, TEST_DATA_PATH) + ntfm = load_model('test_model_all', TEST_DATA_PATH) assert isinstance(ntfm, SpectralModel) - - # Check that all elements get loaded - assert tfm.modes.get_modes() == ntfm.modes.get_modes() - assert tfm.results.bands == ntfm.results.bands - for meta_dat in tfm.data._meta_fields: - assert getattr(ntfm.data, meta_dat) is not None - for setting in ntfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is not None + compare_model_objs([tfm, ntfm], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tfm.data._fields: + assert np.array_equal(getattr(tfm.data, data), getattr(ntfm.data, data)) for result in tfm.results._fields: assert not np.all(np.isnan(getattr(ntfm.results.params, result))) assert tfm.results.metrics.results == ntfm.results.metrics.results - for data in tfm.data._fields: - assert getattr(ntfm.data, data) is not None # Check directory matches (loading didn't add any unexpected attributes) cfm = SpectralModel() assert dir(cfm) == dir(ntfm) + assert dir(cfm.algorithm) == dir(ntfm.algorithm) assert dir(cfm.data) == dir(ntfm.data) assert dir(cfm.results) == dir(ntfm.results) + assert dir(cfm.results.params) == dir(ntfm.results.params) def test_load_model2(tfm2): # Loads file saved from `test_save_model_str2` - file_name = 'test_model_all2' - ntfm2 = load_model(file_name, TEST_DATA_PATH) - assert tfm2.modes.get_modes() == ntfm2.modes.get_modes() - compare_model_objs([tfm2, ntfm2], ['settings', 'meta_data', 'metrics']) + ntfm2 = load_model('test_model_all2', TEST_DATA_PATH) + compare_model_objs([tfm2, ntfm2], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) def test_load_group(tfg): # Loads file saved from `test_save_group` - file_name = 'test_group_all' - ntfg = load_group(file_name, TEST_DATA_PATH) + ntfg = load_group('test_group_all', TEST_DATA_PATH) assert isinstance(ntfg, SpectralGroupModel) - - # Check that all elements get loaded - assert tfg.modes.get_modes() == ntfg.modes.get_modes() - assert tfg.results.bands == ntfg.results.bands - for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is not None + compare_model_objs([tfg, ntfg], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tfg.data._fields: + assert np.array_equal(getattr(tfg.data, data), getattr(ntfg.data, data)) assert len(ntfg.results.group_results) > 0 for metric in tfg.results.metrics.labels: - assert tfg.results.metrics.results[metric] is not None - assert ntfg.data.power_spectra is not None - for meta_dat in tfg.data._meta_fields: - assert getattr(ntfg.data, meta_dat) is not None + assert tfg.results.metrics.results[metric] == ntfg.results.metrics.results[metric] # Check directory matches (loading didn't add any unexpected attributes) cfg = SpectralGroupModel() assert dir(cfg) == dir(ntfg) + assert dir(cfg.algorithm) == dir(ntfg.algorithm) assert dir(cfg.data) == dir(ntfg.data) assert dir(cfg.results) == dir(ntfg.results) + assert dir(cfg.results.params) == dir(ntfg.results.params) def test_load_group2(tfg2): # Loads file saved from `test_save_group_str2` - file_name = 'test_group_all2' - ntfg2 = load_group(file_name, TEST_DATA_PATH) - assert tfg2.modes.get_modes() == ntfg2.modes.get_modes() - compare_model_objs([tfg2, ntfg2], ['settings', 'meta_data', 'metrics']) + ntfg2 = load_group('test_group_all2', TEST_DATA_PATH) + compare_model_objs([tfg2, ntfg2], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) -def test_load_time(): +def test_load_time(tft): # Loads file saved from `test_save_time` - file_name = 'test_time_all' - - # Load without bands definition - tft = load_time(file_name, TEST_DATA_PATH) + ntft = load_time('test_time_all', TEST_DATA_PATH) assert isinstance(tft, SpectralTimeModel) - assert tft.results.time_results + compare_model_objs([tft, ntft], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tft.data._fields: + assert np.array_equal(getattr(tft.data, data), getattr(ntft.data, data)) + assert tft.results.time_results.keys() == ntft.results.time_results.keys() + for key in tft.results.time_results: + assert np.array_equal(\ + tft.results.time_results[key], ntft.results.time_results[key], equal_nan=True) # Check directory matches (loading didn't add any unexpected attributes) cft = SpectralTimeModel() - assert dir(cft) == dir(tft) - assert dir(cft.data) == dir(tft.data) - assert dir(cft.results) == dir(tft.results) + assert dir(cft) == dir(ntft) + assert dir(cft.algorithm) == dir(ntft.algorithm) + assert dir(cft.data) == dir(ntft.data) + assert dir(cft.results) == dir(ntft.results) + assert dir(cft.results.params) == dir(ntft.results.params) -def test_load_event(): +def test_load_event(tfe): # Loads file saved from `test_save_event` - file_name = 'test_event_all' - - # Load without bands definition - tfe = load_event(file_name, TEST_DATA_PATH) + ntfe = load_event('test_event_all', TEST_DATA_PATH) assert isinstance(tfe, SpectralTimeEventModel) + compare_model_objs([tfe, ntfe], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tfe.data._fields: + assert np.array_equal(getattr(tfe.data, data), getattr(ntfe.data, data)) assert len(tfe.results) > 1 - assert tfe.results.event_time_results + assert tfe.results.time_results.keys() == ntfe.results.time_results.keys() + for key in tfe.results.time_results: + assert np.array_equal(\ + tfe.results.time_results[key], ntfe.results.time_results[key], equal_nan=True) # Check directory matches (loading didn't add any unexpected attributes) cfe = SpectralTimeEventModel() - assert dir(cfe) == dir(tfe) - assert dir(cfe.data) == dir(tfe.data) - assert dir(cfe.results) == dir(tfe.results) + assert dir(cfe) == dir(ntfe) + assert dir(cfe.algorithm) == dir(ntfe.algorithm) + assert dir(cfe.data) == dir(ntfe.data) + assert dir(cfe.results) == dir(ntfe.results) + assert dir(cfe.results.params) == dir(ntfe.results.params) diff --git a/specparam/tests/models/test_event.py b/specparam/tests/models/test_event.py index 83ba2b39..dba44547 100644 --- a/specparam/tests/models/test_event.py +++ b/specparam/tests/models/test_event.py @@ -9,6 +9,7 @@ import numpy as np from specparam.models import SpectralGroupModel, SpectralTimeModel +from specparam.models.utils import compare_model_objs from specparam.sim import sim_spectrogram from specparam.modutils.dependencies import safe_import @@ -95,26 +96,27 @@ def test_event_report(skip_if_no_mpl): assert tfe -def test_event_load(): - - file_name_res = 'test_event_res' - file_name_set = 'test_event_set' - file_name_dat = 'test_event_dat' +def test_event_load(tfe): # Test loading results - tfe = SpectralTimeEventModel(verbose=False) - tfe.load(file_name_res, TEST_DATA_PATH) - assert tfe.results.event_time_results + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_res', TEST_DATA_PATH) + assert ntfe.results.event_time_results # Test loading settings - tfe = SpectralTimeEventModel(verbose=False) - tfe.load(file_name_set, TEST_DATA_PATH) - assert tfe.algorithm.get_settings() + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_set', TEST_DATA_PATH) + assert ntfe.algorithm.get_settings() # Test loading data - tfe = SpectralTimeEventModel(verbose=False) - tfe.load(file_name_dat, TEST_DATA_PATH) - assert np.all(tfe.data.spectrograms) + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_dat', TEST_DATA_PATH) + assert np.all(ntfe.data.spectrograms) + + # Test loading all elements + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_all', TEST_DATA_PATH) + assert compare_model_objs([tfe, ntfe], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) def test_event_get_model(tfe): diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index 3f2d2e59..586501df 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -12,6 +12,7 @@ from numpy.testing import assert_equal from specparam.measures.metrics import METRICS +from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import from specparam.sim import sim_group_power_spectra @@ -273,13 +274,9 @@ def test_load(tfg): """Test load into group object. Note: loads files from test_save_group in specparam/tests/io/test_models.py.""" - file_name_res = 'test_group_res' - file_name_set = 'test_group_set' - file_name_dat = 'test_group_dat' - # Test loading just results ntfg = SpectralGroupModel(verbose=False) - ntfg.load(file_name_res, TEST_DATA_PATH) + ntfg.load('test_group_res', TEST_DATA_PATH) assert len(ntfg.results.group_results) > 0 # Test that settings and data are None for setting in tfg.algorithm.settings.names: @@ -288,9 +285,9 @@ def test_load(tfg): # Test loading just settings ntfg = SpectralGroupModel(verbose=False) - ntfg.load(file_name_set, TEST_DATA_PATH) + ntfg.load('test_group_set', TEST_DATA_PATH) for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is not None + assert getattr(tfg.algorithm, setting) == getattr(ntfg.algorithm, setting) # Test that results and data are None for result in tfg.results._fields: assert np.all(np.isnan(getattr(ntfg.results.params, result))) @@ -298,7 +295,7 @@ def test_load(tfg): # Test loading just data ntfg = SpectralGroupModel(verbose=False) - ntfg.load(file_name_dat, TEST_DATA_PATH) + ntfg.load('test_group_dat', TEST_DATA_PATH) assert ntfg.data.has_data # Test that settings and results are None for setting in tfg.algorithm.settings.names: @@ -308,14 +305,9 @@ def test_load(tfg): # Test loading all elements ntfg = SpectralGroupModel(verbose=False) - file_name_all = 'test_group_all' - ntfg.load(file_name_all, TEST_DATA_PATH) + ntfg.load('test_group_all', TEST_DATA_PATH) + assert compare_model_objs([tfg, ntfg], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) assert len(ntfg.results.group_results) > 0 - for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is not None - assert ntfg.data.has_data - for meta_dat in tfg.data._meta_fields: - assert getattr(ntfg.data, meta_dat) is not None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index aa89874e..98c5d00e 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -14,6 +14,7 @@ from specparam.measures.metrics import METRICS from specparam.sim import gen_freqs, sim_power_spectrum from specparam.modes.definitions import AP_MODES, PE_MODES +from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import from specparam.modutils.errors import DataError, NoDataError, InconsistentDataError @@ -194,8 +195,7 @@ def test_load(tfm): # Test loading just results ntfm = SpectralModel(verbose=False) - file_name_res = 'test_model_res' - ntfm.load(file_name_res, TEST_DATA_PATH) + ntfm.load('test_model_res', TEST_DATA_PATH) # Check that result attributes get filled for result in tfm.results._fields: assert not np.all(np.isnan(getattr(ntfm.results.params, result))) @@ -206,10 +206,9 @@ def test_load(tfm): # Test loading just settings ntfm = SpectralModel(verbose=False) - file_name_set = 'test_model_set' - ntfm.load(file_name_set, TEST_DATA_PATH) + ntfm.load('test_model_set', TEST_DATA_PATH) for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is not None + assert getattr(tfm.algorithm, setting) == getattr(ntfm.algorithm, setting) # Test that results and data are None for result in tfm.results._fields: assert np.all(np.isnan(getattr(ntfm.results.params, result))) @@ -217,9 +216,9 @@ def test_load(tfm): # Test loading just data ntfm = SpectralModel(verbose=False) - file_name_dat = 'test_model_dat' - ntfm.load(file_name_dat, TEST_DATA_PATH) - assert ntfm.data.power_spectrum is not None + ntfm.load('test_model_dat', TEST_DATA_PATH) + assert ntfm.data.has_data + assert np.array_equal(tfm.data.power_spectrum, ntfm.data.power_spectrum) # Test that settings and results are None for setting in tfm.algorithm.settings.names: assert getattr(ntfm.algorithm, setting) is None @@ -228,16 +227,12 @@ def test_load(tfm): # Test loading all elements ntfm = SpectralModel(verbose=False) - file_name_all = 'test_model_all' - ntfm.load(file_name_all, TEST_DATA_PATH) + ntfm.load('test_model_all', TEST_DATA_PATH) + assert compare_model_objs([tfm, ntfm], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tfm.data._fields: + assert np.array_equal(getattr(tfm.data, data), getattr(ntfm.data, data)) for result in tfm.results._fields: assert not np.all(np.isnan(getattr(ntfm.results.params, result))) - for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is not None - for data in tfm.data._fields: - assert getattr(ntfm.data, data) is not None - for meta_dat in tfm.data._meta_fields: - assert getattr(ntfm.data, meta_dat) is not None def test_add_data(tresults): """Tests method to add data to model objects.""" diff --git a/specparam/tests/models/test_time.py b/specparam/tests/models/test_time.py index 8a5e7cf3..b286f2ed 100644 --- a/specparam/tests/models/test_time.py +++ b/specparam/tests/models/test_time.py @@ -9,6 +9,7 @@ import numpy as np from specparam.sim import sim_spectrogram +from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import pd = safe_import('pandas') @@ -78,26 +79,27 @@ def test_time_report(skip_if_no_mpl): assert tft -def test_time_load(): - - file_name_res = 'test_time_res' - file_name_set = 'test_time_set' - file_name_dat = 'test_time_dat' +def test_time_load(tft): # Test loading results - tft = SpectralTimeModel(verbose=False) - tft.load(file_name_res, TEST_DATA_PATH) - assert tft.results.time_results + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_res', TEST_DATA_PATH) + assert ntft.results.time_results # Test loading settings - tft = SpectralTimeModel(verbose=False) - tft.load(file_name_set, TEST_DATA_PATH) - assert tft.algorithm.get_settings() + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_set', TEST_DATA_PATH) + assert ntft.algorithm.get_settings() # Test loading data - tft = SpectralTimeModel(verbose=False) - tft.load(file_name_dat, TEST_DATA_PATH) - assert np.all(tft.data.power_spectra) + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_dat', TEST_DATA_PATH) + assert np.all(ntft.data.spectrogram) + + # Test loading all elements + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_all', TEST_DATA_PATH) + assert compare_model_objs([tft, ntft], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) def test_time_drop(): From d9101d487a3be7c0c4d3abf5d7e3cd2d0695579e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 01:47:46 -0400 Subject: [PATCH 10/40] update add_from_dict to clean up & fix algo settings --- specparam/models/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/specparam/models/base.py b/specparam/models/base.py index 6657d9d0..5ce949bf 100644 --- a/specparam/models/base.py +++ b/specparam/models/base.py @@ -180,9 +180,7 @@ def _add_from_dict(self, data): # Add additional attributes directly to object for key in data.keys(): - if getattr(self, key, False) is not False: - setattr(self, key, data[key]) + if getattr(self.algorithm, key, False) is not False: + setattr(self.algorithm, key, data[key]) elif getattr(self.data, key, False) is not False: setattr(self.data, key, data[key]) - elif getattr(self.results, key, False) is not False: - setattr(self.results, key, data[key]) From e0cb02686f9dcd884685b47916ada7f31cfdacce Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 01:59:33 -0400 Subject: [PATCH 11/40] finish removing _check_loaded_results --- specparam/models/group.py | 1 - specparam/models/model.py | 1 - specparam/objs/results.py | 47 ++++++--------------------------------- 3 files changed, 7 insertions(+), 42 deletions(-) diff --git a/specparam/models/group.py b/specparam/models/group.py index d553cc29..4b71b111 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -233,7 +233,6 @@ def load(self, file_name, file_path=None): # If results part of current data added, check and update object results if set([el + '_params' for el in self.results._fields]).issubset(set(data.keys())): - self.results._check_loaded_results(data) self.results.group_results.append(self.results._get_results()) # Reconstruct frequency vector, if information is available to do so diff --git a/specparam/models/model.py b/specparam/models/model.py index fe08928e..4e7610ec 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -277,7 +277,6 @@ def load(self, file_name, file_path=None, regenerate=True): # Add loaded data to object and check loaded data self._add_from_dict(data) self.algorithm._check_loaded_settings(data) - self.results._check_loaded_results(data) # Regenerate model components, based on what is available if regenerate: diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 82c5a20a..904c52ce 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -23,8 +23,7 @@ ################################################################################################### # Define set of results fields & default metrics to use -#RESULTS_FIELDS = ['aperiodic_params_', 'gaussian_params_', 'peak_params_'] -RESULTS_FIELDS = ['aperiodic', 'gaussian', 'peak'] +#RESULTS_FIELDS = ['aperiodic', 'gaussian', 'peak'] DEFAULT_METRICS = ['error_mae', 'gof_rsquared'] @@ -55,7 +54,8 @@ def __init__(self, modes=None, metrics=None, bands=None): # Initialize results attributes self._reset_results(True) - self._fields = RESULTS_FIELDS + #self._fields = RESULTS_FIELDS + self._fields = self.params._fields @property @@ -138,14 +138,14 @@ def add_results(self, results): A data object containing the results from fitting a power spectrum model. """ - # Add parameter fields and then select and add metrics results for pfield in self._fields: - setattr(self.params, pfield, getattr(results, pfield + '_params')) + params = getattr(results, pfield + '_params') + if 'peak' in pfield or 'gaussian' in pfield: + params = check_array_dim(params) + setattr(self.params, pfield, params) self.metrics.add_results(results.metrics) - self._check_loaded_results(results._asdict()) - def get_results(self): """Return model fit parameters and goodness of fit metrics. @@ -195,23 +195,6 @@ def get_params(self, name, field=None): return get_model_params(self.get_results(), self.modes, name, field) - # TODO: check / move to ModelParameters? - def _check_loaded_results(self, data): - """Check if results have been added and check data. - - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ - - # If results loaded, check dimensions of peak parameters - # This fixes an issue where they end up the wrong shape if they are empty (no peaks) - if set(self._fields).issubset(set(data.keys())): - self.params.peak = check_array_dim(self.params.peak) - self.params.gaussian = check_array_dim(self.params.gaussian) - - def _reset_results(self, clear_results=False): """Set, or reset, results attributes to empty. @@ -222,22 +205,6 @@ def _reset_results(self, clear_results=False): """ if clear_results: - - # # Aperiodic parameters - # if self.modes: - # self.aperiodic_params_ = np.array([np.nan] * self.modes.aperiodic.n_params) - # else: - # self.aperiodic_params_ = np.nan - - # # Periodic parameters - # if self.modes: - # self.gaussian_params_ = np.empty([0, self.modes.periodic.n_params]) - # self.peak_params_ = np.empty([0, self.modes.periodic.n_params]) - # else: - # self.gaussian_params_ = np.nan - # self.peak_params_ = np.nan - - # Reset model parameters & components self.params.reset(self.modes) self.model.reset() From ab9d6968cf51c00e3bb363b914b641adda7ef9d5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 02:05:55 -0400 Subject: [PATCH 12/40] deprecate results fields / add as property to params --- specparam/io/models.py | 2 +- specparam/models/group.py | 2 +- specparam/objs/params.py | 7 +++++++ specparam/objs/results.py | 7 ++----- specparam/tests/io/test_models.py | 4 ++-- specparam/tests/models/test_group.py | 6 +++--- specparam/tests/models/test_model.py | 14 +++++++------- specparam/tests/objs/test_results.py | 2 +- 8 files changed, 24 insertions(+), 20 deletions(-) diff --git a/specparam/io/models.py b/specparam/io/models.py index f14dd826..8aa33955 100644 --- a/specparam/io/models.py +++ b/specparam/io/models.py @@ -63,7 +63,7 @@ def save_model(model, file_name, file_path=None, append=False, # Convert results & metrics to saveable information results_labels = [] - for rfield in model.results._fields: + for rfield in model.results.params.fields: results_labels.append(rfield + '_params') obj_dict[rfield + '_params'] = getattr(model.results.params, rfield) obj_dict['metrics'] = model.results.metrics.results diff --git a/specparam/models/group.py b/specparam/models/group.py index 4b71b111..f58c2cc0 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -232,7 +232,7 @@ def load(self, file_name, file_path=None): self.algorithm._check_loaded_settings(data) # If results part of current data added, check and update object results - if set([el + '_params' for el in self.results._fields]).issubset(set(data.keys())): + if set([el + '_params' for el in self.results.params.fields]).issubset(set(data.keys())): self.results.group_results.append(self.results._get_results()) # Reconstruct frequency vector, if information is available to do so diff --git a/specparam/objs/params.py b/specparam/objs/params.py index aeb4ddbf..185a11f4 100644 --- a/specparam/objs/params.py +++ b/specparam/objs/params.py @@ -49,3 +49,10 @@ def reset(self, modes=None): else: self.gaussian = np.nan self.peak = np.nan + + + @property + def fields(self): + """Alias as a property attribute the list of fields.""" + + return list(vars(self).keys()) diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 904c52ce..4dd88746 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -23,7 +23,6 @@ ################################################################################################### # Define set of results fields & default metrics to use -#RESULTS_FIELDS = ['aperiodic', 'gaussian', 'peak'] DEFAULT_METRICS = ['error_mae', 'gof_rsquared'] @@ -54,8 +53,6 @@ def __init__(self, modes=None, metrics=None, bands=None): # Initialize results attributes self._reset_results(True) - #self._fields = RESULTS_FIELDS - self._fields = self.params._fields @property @@ -138,7 +135,7 @@ def add_results(self, results): A data object containing the results from fitting a power spectrum model. """ - for pfield in self._fields: + for pfield in self.params.fields: params = getattr(results, pfield + '_params') if 'peak' in pfield or 'gaussian' in pfield: params = check_array_dim(params) @@ -157,7 +154,7 @@ def get_results(self): """ results = FitResults( - **{key + '_params' : getattr(self.params, key) for key in self._fields}, + **{key + '_params' : getattr(self.params, key) for key in self.params.fields}, metrics=self.metrics.results) return results diff --git a/specparam/tests/io/test_models.py b/specparam/tests/io/test_models.py index 90e5e303..0160fe40 100644 --- a/specparam/tests/io/test_models.py +++ b/specparam/tests/io/test_models.py @@ -157,7 +157,7 @@ def test_load_file_contents(tfm): assert 'bands' in loaded_data.keys() for setting in tfm.algorithm.settings.names: assert setting in loaded_data.keys() - for result in tfm.results._fields: + for result in tfm.results.params.fields: assert result + '_params' in loaded_data.keys() assert 'metrics' in loaded_data.keys() for datum in tfm.data._fields: @@ -171,7 +171,7 @@ def test_load_model(tfm): compare_model_objs([tfm, ntfm], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) for data in tfm.data._fields: assert np.array_equal(getattr(tfm.data, data), getattr(ntfm.data, data)) - for result in tfm.results._fields: + for result in tfm.results.params.fields: assert not np.all(np.isnan(getattr(ntfm.results.params, result))) assert tfm.results.metrics.results == ntfm.results.metrics.results diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index 586501df..a8373c68 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -289,7 +289,7 @@ def test_load(tfg): for setting in tfg.algorithm.settings.names: assert getattr(tfg.algorithm, setting) == getattr(ntfg.algorithm, setting) # Test that results and data are None - for result in tfg.results._fields: + for result in tfg.results.params.fields: assert np.all(np.isnan(getattr(ntfg.results.params, result))) assert ntfg.data.power_spectra is None @@ -300,7 +300,7 @@ def test_load(tfg): # Test that settings and results are None for setting in tfg.algorithm.settings.names: assert getattr(ntfg.algorithm, setting) is None - for result in tfg.results._fields: + for result in tfg.results.params.fields: assert np.all(np.isnan(getattr(ntfg.results.params, result))) # Test loading all elements @@ -343,7 +343,7 @@ def test_get_model(tfg): tfm1 = tfg.get_model(1, True) assert tfm1 # Check that regenerated model is created - for result in tfg.results._fields: + for result in tfg.results.params.fields: assert np.all(getattr(tfm1.results.params, result)) # Test when object has no data (clear a copy of tfg) diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 98c5d00e..2fc736b7 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -197,7 +197,7 @@ def test_load(tfm): ntfm = SpectralModel(verbose=False) ntfm.load('test_model_res', TEST_DATA_PATH) # Check that result attributes get filled - for result in tfm.results._fields: + for result in tfm.results.params.fields: assert not np.all(np.isnan(getattr(ntfm.results.params, result))) # Test that settings and data are None for setting in tfm.algorithm.settings.names: @@ -210,7 +210,7 @@ def test_load(tfm): for setting in tfm.algorithm.settings.names: assert getattr(tfm.algorithm, setting) == getattr(ntfm.algorithm, setting) # Test that results and data are None - for result in tfm.results._fields: + for result in tfm.results.params.fields: assert np.all(np.isnan(getattr(ntfm.results.params, result))) assert ntfm.data.power_spectrum is None @@ -222,7 +222,7 @@ def test_load(tfm): # Test that settings and results are None for setting in tfm.algorithm.settings.names: assert getattr(ntfm.algorithm, setting) is None - for result in tfm.results._fields: + for result in tfm.results.params.fields: assert np.all(np.isnan(getattr(ntfm.results.params, result))) # Test loading all elements @@ -231,7 +231,7 @@ def test_load(tfm): assert compare_model_objs([tfm, ntfm], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) for data in tfm.data._fields: assert np.array_equal(getattr(tfm.data, data), getattr(ntfm.data, data)) - for result in tfm.results._fields: + for result in tfm.results.params.fields: assert not np.all(np.isnan(getattr(ntfm.results.params, result))) def test_add_data(tresults): @@ -323,7 +323,7 @@ def test_resets(): assert getattr(tfm.data, field) is None for key, value in tfm.results.model.__dict__.items(): assert value is None - for field in tfm.results._fields: + for field in tfm.results.params.fields: assert np.all(np.isnan(getattr(tfm.results.params, field))) assert tfm.data.freqs is None and tfm.results.model.modeled_spectrum is None @@ -345,7 +345,7 @@ def test_fit_failure(): tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset - for result in tfm.results._fields: + for result in tfm.results.params.fields: assert np.all(np.isnan(getattr(tfm.results.params, result))) ## Monkey patch to check errors in general @@ -359,7 +359,7 @@ def raise_runtime_error(*args, **kwargs): tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset - for result in tfm.results._fields: + for result in tfm.results.params.fields: assert np.all(np.isnan(getattr(tfm.results.params, result))) def test_debug(): diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py index c5ac847e..c34bee29 100644 --- a/specparam/tests/objs/test_results.py +++ b/specparam/tests/objs/test_results.py @@ -18,7 +18,7 @@ def test_results_results(tresults): tres.add_results(tresults) assert tres.has_model - for result in tres._fields: + for result in tres.params.fields: assert np.array_equal(getattr(tres.params, result), getattr(tresults, result + '_params')) results_out = tres.get_results() From e4656d9ca605bc8e3b84d076abc90329c6f3de2f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 02:09:54 -0400 Subject: [PATCH 13/40] bands repr -> str --- specparam/bands/bands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specparam/bands/bands.py b/specparam/bands/bands.py index 1d3ba29a..6e04c5f6 100644 --- a/specparam/bands/bands.py +++ b/specparam/bands/bands.py @@ -60,7 +60,7 @@ def __getitem__(self, label): raise ValueError(message) from None - def __repr__(self): + def __str__(self): """Define the string representation as a printout of the band information.""" return '\n'.join(['{:8} : {:2} - {:2} Hz'.format(key, *val) \ From d0b883142f196b7f9a1949c5ab1065c3c744a27a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 02:14:42 -0400 Subject: [PATCH 14/40] drop trailing underscores --- specparam/objs/results.py | 10 +++++----- specparam/plts/annotate.py | 6 +++--- specparam/reports/strings.py | 4 ++-- specparam/tests/models/test_event.py | 4 ++-- specparam/tests/models/test_group.py | 4 ++-- specparam/tests/models/test_model.py | 4 ++-- specparam/tests/models/test_time.py | 4 ++-- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 4dd88746..81655434 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -71,7 +71,7 @@ def has_model(self): @property - def n_peaks_(self): + def n_peaks(self): """How many peaks were fit in the model.""" n_peaks = None @@ -82,12 +82,12 @@ def n_peaks_(self): @property - def n_params_(self): + def n_params(self): """The total number of parameters fit in the model.""" n_params = None if self.has_model: - n_peak_params = self.modes.periodic.n_params * self.n_peaks_ + n_peak_params = self.modes.periodic.n_params * self.n_peaks n_params = n_peak_params + self.modes.aperiodic.n_params return n_params @@ -282,7 +282,7 @@ def has_model(self): @property - def n_peaks_(self): + def n_peaks(self): """How many peaks were fit for each model.""" n_peaks = None @@ -489,7 +489,7 @@ def has_model(self): @property - def n_peaks_(self): + def n_peaks(self): """How many peaks were fit for each model, for each event.""" n_peaks = None diff --git a/specparam/plts/annotate.py b/specparam/plts/annotate.py index e2775fb7..05670214 100644 --- a/specparam/plts/annotate.py +++ b/specparam/plts/annotate.py @@ -43,7 +43,7 @@ def plot_annotated_peak_search(model): model.results.params.gaussian[:, 1].argsort()][::-1] # Loop through the iterative search for each peak - for ind in range(model.results.n_peaks_ + 1): + for ind in range(model.results.n_peaks + 1): # This forces the creation of a new plotting axes per iteration ax = check_ax(None, PLT_FIGSIZES['spectral']) @@ -65,7 +65,7 @@ def plot_annotated_peak_search(model): ax.set_ylim(ylims) ax.set_title('Iteration #' + str(ind+1), fontsize=16) - if ind < model.results.n_peaks_: + if ind < model.results.n_peaks: gauss = model.modes.periodic.func(model.data.freqs, *gaussian_params[ind, :]) plot_spectra(model.data.freqs, gauss, ax=ax, label='Gaussian Fit', @@ -136,7 +136,7 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # See: https://github.com/matplotlib/matplotlib/issues/12820. Fixed in 3.2.1. bug_buff = 0.000001 - if annotate_peaks and model.results.n_peaks_: + if annotate_peaks and model.results.n_peaks: # Extract largest peak, to annotate, grabbing gaussian params gauss = get_band_peak(model, model.data.freq_range, attribute='gaussian') diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index 8f0a0e0c..6c663d62 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -393,7 +393,7 @@ def gen_model_results_str(model, concise=False): # Peak parameters 'Peak Parameters (\'{}\' mode) {} peaks found'.format(\ - model.modes.periodic.name, model.results.n_peaks_), + model.modes.periodic.name, model.results.n_peaks), *[peak_str.format(*op) for op in model.results.params.peak], '', @@ -455,7 +455,7 @@ def gen_group_results_str(group, concise=False): # Peak Parameters 'Peak Parameters (\'{}\' mode) {} total peaks found'.format(\ - group.modes.periodic.name, sum(group.results.n_peaks_)), + group.modes.periodic.name, sum(group.results.n_peaks)), '', # Metrics diff --git a/specparam/tests/models/test_event.py b/specparam/tests/models/test_event.py index dba44547..71b8c5dc 100644 --- a/specparam/tests/models/test_event.py +++ b/specparam/tests/models/test_event.py @@ -42,8 +42,8 @@ def test_event_iter(tfe): def test_event_n_properties(tfe): - assert np.all(tfe.results.n_peaks_) - assert np.all(tfe.results.n_params_) + assert np.all(tfe.results.n_peaks) + assert np.all(tfe.results.n_params) def test_event_fit(): diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index a8373c68..0f7c0b44 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -65,8 +65,8 @@ def test_has_model(tfg): def test_n_properties(tfg): """Test the n_peaks & n_params property attributes.""" - assert np.all(tfg.results.n_peaks_) - assert np.all(tfg.results.n_params_) + assert np.all(tfg.results.n_peaks) + assert np.all(tfg.results.n_params) def test_n_null(tfg): """Test the n_null_ property attribute.""" diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 2fc736b7..85299bdc 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -57,8 +57,8 @@ def test_has_model(tfm): def test_n_properties(tfm): - assert tfm.results.n_peaks_ - assert tfm.results.n_params_ + assert tfm.results.n_peaks + assert tfm.results.n_params def test_fit_nk(): """Test fit, no knee.""" diff --git a/specparam/tests/models/test_time.py b/specparam/tests/models/test_time.py index b286f2ed..14d10bd7 100644 --- a/specparam/tests/models/test_time.py +++ b/specparam/tests/models/test_time.py @@ -41,8 +41,8 @@ def test_time_iter(tft): def test_time_n_properties(tft): - assert np.all(tft.results.n_peaks_) - assert np.all(tft.results.n_params_) + assert np.all(tft.results.n_peaks) + assert np.all(tft.results.n_params) def test_time_fit(): From 23c134612b2d009c8bcd4d653f9798d7cb35733d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 02:17:25 -0400 Subject: [PATCH 15/40] remove trailing underscores - null --- specparam/objs/results.py | 4 ++-- specparam/reports/strings.py | 2 +- specparam/tests/models/test_group.py | 12 ++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 81655434..e4966229 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -293,7 +293,7 @@ def n_peaks(self): @property - def n_null_(self): + def n_null(self): """How many model fits are null.""" n_null = None @@ -304,7 +304,7 @@ def n_null_(self): @property - def null_inds_(self): + def null_inds(self): """The indices for model fits that are null.""" null_inds = None diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index 6c663d62..c774e0ab 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -651,7 +651,7 @@ def _report_str_n_null(model): output = \ [el for el in ['{} power spectra failed to fit'.format(\ - model.results.n_null_)] if model.results.n_null_] + model.results.n_null)] if model.results.n_null] return output diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index 0f7c0b44..aec3bea1 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -69,16 +69,16 @@ def test_n_properties(tfg): assert np.all(tfg.results.n_params) def test_n_null(tfg): - """Test the n_null_ property attribute.""" + """Test the n_null property attribute.""" # Since there should have been no failed fits, this should return 0 - assert tfg.results.n_null_ == 0 + assert tfg.results.n_null == 0 def test_null_inds(tfg): - """Test the null_inds_ property attribute.""" + """Test the null_inds property attribute.""" # Since there should be no failed fits, this should return an empty list - assert tfg.results.null_inds_ == [] + assert tfg.results.null_inds == [] def test_fit_nk(): """Test group fit, no knee.""" @@ -175,8 +175,8 @@ def test_fg_fail(): # Test the property attributes related to null model fits # This checks that they do the right thing when there are null fits (failed fits) - assert ntfg.results.n_null_ > 0 - assert ntfg.results.null_inds_ + assert ntfg.results.n_null > 0 + assert ntfg.results.null_inds def test_drop(): """Test function to drop results from group object.""" From d1619924eaf277099674a4bdaac45457002c64ec Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 11:53:11 -0400 Subject: [PATCH 16/40] clean up data checks --- specparam/models/model.py | 2 +- specparam/models/utils.py | 4 ++-- specparam/objs/data.py | 19 +++++++++++-------- specparam/tests/models/test_model.py | 4 ++-- specparam/tests/objs/test_data.py | 9 ++++----- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/specparam/models/model.py b/specparam/models/model.py index 4e7610ec..fbda655e 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -160,7 +160,7 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, prechecks=True): # If not set to fail on NaN or Inf data at add time, check data here # This serves as a catch all for curve_fits which will fail given NaN or Inf # Because FitError's are by default caught, this allows fitting to continue - if not self.data._check_data: + if not self.data.checks['data']: if np.any(np.isinf(self.data.power_spectrum)) or \ np.any(np.isnan(self.data.power_spectrum)): raise FitError("Model fitting was skipped because there are NaN or Inf " diff --git a/specparam/models/utils.py b/specparam/models/utils.py index dd7bb64d..e4b6cbef 100644 --- a/specparam/models/utils.py +++ b/specparam/models/utils.py @@ -267,8 +267,8 @@ def combine_model_objs(model_objs): # Set the status for freqs & data checking # Check states gets set as True if any of the inputs have it on, False otherwise group.data.set_checks(\ - check_freqs=any(getattr(m_obj.data, '_check_freqs') for m_obj in model_objs), - check_data=any(getattr(m_obj.data, '_check_data') for m_obj in model_objs)) + check_freqs=any(m_obj.data.checks['freqs'] for m_obj in model_objs), + check_data=any(m_obj.data.checks['data'] for m_obj in model_objs)) # Add data information information group.data.add_meta_data(model_objs[0].data.get_meta_data()) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 094fe653..e556f480 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -46,6 +46,8 @@ class Data(): Frequency range of the power spectrum, as [lowest_freq, highest_freq]. freq_res : float Frequency resolution of the power spectrum. + checks : dict + Specifiers for which aspects of the data to run checks on. """ def __init__(self, check_freqs=True, check_data=True, format='power'): @@ -55,9 +57,10 @@ def __init__(self, check_freqs=True, check_data=True, format='power'): self._fields = DATA_FIELDS self._meta_fields = META_DATA_FIELDS - # Define data check run statuses - self._check_freqs = check_freqs - self._check_data = check_data + self.checks = { + 'freqs' : check_freqs, + 'data' : check_data, + } check_input_options(format, FORMATS, 'format') self.format = format @@ -120,7 +123,7 @@ def get_checks(self): Object containing the check statuses from the current object. """ - return ModelChecks(**{key : getattr(self, '_' + key) for key in ModelChecks._fields}) + return ModelChecks(**{'check_' + key : value for key, value in self.checks.items()}) def get_meta_data(self): @@ -156,9 +159,9 @@ def set_checks(self, check_freqs=None, check_data=None): """ if check_freqs is not None: - self._check_freqs = check_freqs + self.checks['freqs'] = check_freqs if check_data is not None: - self._check_data = check_data + self.checks['data'] = check_data def _reset_data(self, clear_freqs=False, clear_spectrum=False): @@ -270,13 +273,13 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): ## Data checks - run checks on inputs based on check statuses - if self._check_freqs: + if self.checks['freqs']: # Check if the frequency data is unevenly spaced, and raise an error if so freq_diffs = np.diff(freqs) if not np.all(np.isclose(freq_diffs, freq_res)): raise DataError("The input frequency values are not evenly spaced. " "The model expects equidistant frequency values in linear space.") - if self._check_data: + if self.checks['data']: # Check if there are any infs / nans, and raise an error if so if np.any(np.isinf(powers)) or np.any(np.isnan(powers)): error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 85299bdc..fdabcd77 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -399,8 +399,8 @@ def test_set_checks(): # Reset checks to true tfm.data.set_checks(True, True) - assert tfm.data._check_freqs is True - assert tfm.data._check_data is True + assert tfm.data.checks['freqs'] is True + assert tfm.data.checks['data'] is True def test_to_df(tfm, tbands, skip_if_no_pandas): diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 42982fb8..9b8593c0 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -44,15 +44,14 @@ def test_data_get_set_checks(tdata): tdata.set_checks(False, False) tchecks1 = tdata.get_checks() assert isinstance(tchecks1, ModelChecks) - assert tdata._check_freqs == tchecks1.check_freqs == False - assert tdata._check_data == tchecks1.check_data == False + assert tdata.checks['freqs'] == tchecks1.check_freqs == False + assert tdata.checks['data'] == tchecks1.check_data == False tdata.set_checks(True, True) tchecks2 = tdata.get_checks() assert isinstance(tchecks2, ModelChecks) - assert tdata._check_freqs == tchecks2.check_freqs == True - assert tdata._check_data == tchecks2.check_data == True - + assert tdata.checks['freqs'] == tchecks2.check_freqs == True + assert tdata.checks['data'] == tchecks2.check_data == True @plot_test def test_data_plot(tdata, skip_if_no_mpl): From 08c8961c2ea8709b83defa2156d8f9232e79b105 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Apr 2025 12:25:50 -0400 Subject: [PATCH 17/40] sweep for docstrings content --- specparam/models/event.py | 15 ++++---- specparam/models/group.py | 18 +++------ specparam/models/model.py | 12 +----- specparam/models/time.py | 20 ++++------ specparam/objs/data.py | 79 +++++++++++++++++++++++---------------- specparam/objs/results.py | 45 ++++++++++++++++++++-- 6 files changed, 112 insertions(+), 77 deletions(-) diff --git a/specparam/models/event.py b/specparam/models/event.py index aa19a5d6..319b74cb 100644 --- a/specparam/models/event.py +++ b/specparam/models/event.py @@ -27,10 +27,8 @@ class SpectralTimeEventModel(SpectralTimeModel): """Model a set of event as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -43,9 +41,12 @@ class SpectralTimeEventModel(SpectralTimeModel): Notes ----- % copied in from SpectralModel object - - The event object inherits from the time model, which in turn inherits from the - group object, etc. As such it also has data attributes defined on the underlying - objects (see notes and attribute lists in inherited objects for details). + - The event object inherits from the time model, overwriting the `data` and + `results` objects with versions for fitting models across events. + Event related, temporally organized results are collected into the + `results.event_time_results` attribute, which may include sub-selecting peaks + per band (depending on settings). Note that the `results.event_group_results` attribute + is also available, which maintains the full model results. """ def __init__(self, *args, **kwargs): diff --git a/specparam/models/group.py b/specparam/models/group.py index f58c2cc0..56ff1fbb 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -31,10 +31,8 @@ class SpectralGroupModel(SpectralModel): """Model a group of power spectra as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -47,14 +45,10 @@ class SpectralGroupModel(SpectralModel): Notes ----- % copied in from SpectralModel object - - The group object inherits from the model object. As such it also has data - attributes (`power_spectrum` & `modeled_spectrum`), and parameter attributes - (`aperiodic_params_`, `peak_params_`, `gaussian_params_`, `r_squared_`, `error_`) - which are defined in the context of individual model fits. These attributes are - used during the fitting process, but in the group context do not store results - post-fitting. Rather, all model fit results are collected and stored into the - `group_results` attribute. To access individual parameters of the fit, use - the `get_params` method. + - The group object inherits from the model object, and in doing so overwrites the + `data` and `results` objects with versions for fitting groups of power spectra. + All model fit results are collected and stored in the `results.group_results` attribute. + To access individual parameters of the fit, use the `get_params` method. """ def __init__(self, *args, **kwargs): diff --git a/specparam/models/model.py b/specparam/models/model.py index fbda655e..2324d26c 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -28,10 +28,8 @@ class SpectralModel(BaseModel): """Model a power spectrum as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -64,12 +62,6 @@ class SpectralModel(BaseModel): For example, raw FFT inputs are not appropriate. Where possible and appropriate, use longer time segments for power spectrum calculation to get smoother power spectra, as this will give better model fits. - - Commonly used abbreviations used in this module include: - CF: center frequency, PW: power, BW: Bandwidth, AP: aperiodic - - The gaussian params are those that define the gaussian of the fit, where as the peak - params are a modified version, in which the CF of the peak is the mean of the gaussian, - the PW of the peak is the height of the gaussian over and above the aperiodic component, - and the BW of the peak, is 2*std of the gaussian (as 'two sided' bandwidth). """ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, diff --git a/specparam/models/time.py b/specparam/models/time.py index fd626407..5540cc0e 100644 --- a/specparam/models/time.py +++ b/specparam/models/time.py @@ -22,10 +22,8 @@ class SpectralTimeModel(SpectralGroupModel): """Model a spectrogram as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -38,14 +36,12 @@ class SpectralTimeModel(SpectralGroupModel): Notes ----- % copied in from SpectralModel object - - The time object inherits from the group model, which in turn inherits from the - model object. As such it also has data attributes defined on the model object, - as well as additional attributes that are added to the group object (see notes - and attribute list in SpectralGroupModel). - - Notably, while this object organizes the results into the `time_results` - attribute, which may include sub-selecting peaks per band (depending on settings) - the `group_results` attribute is also available, which maintains the full - model results. + - The time object inherits from the group model, overwriting the `data` and + `results` objects with versions for fitting models across time. Temporally + organized results are collected into the `results.time_results` attribute, + which may include sub-selecting peaks per band (depending on settings). + Note that the `results.group_results` attribute is also available, which maintains + the full model results. """ def __init__(self, *args, **kwargs): diff --git a/specparam/objs/data.py b/specparam/objs/data.py index e556f480..ae96dbb7 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -9,6 +9,7 @@ from specparam.utils.spectral import trim_spectrum from specparam.utils.checks import check_input_options from specparam.modutils.errors import DataError, InconsistentDataError +from specparam.modutils.docs import docs_get_section, replace_docstring_sections from specparam.plts.settings import PLT_COLORS from specparam.plts.spectra import plot_spectra, plot_spectrogram from specparam.plts.utils import check_plot_kwargs @@ -28,26 +29,28 @@ class Data(): Parameters ---------- check_freqs : bool - Whether to check the frequency values. - If True, checks the frequency values, and raises an error for uneven spacing. + Whether to check the frequency values. If so, raises an error for uneven spacing. check_data : bool - Whether to check the power spectrum values. - If True, checks the power values and raises an error for any NaN / Inf values. + Whether to check the spectral data. If so, raises an error for any NaN / Inf values. format : {'power'} The representation format of the data. Attributes ---------- + checks : dict + Specifiers for which aspects of the data to run checks on. freqs : 1d array - Frequency values for the power spectrum. - power_spectrum : 1d array - Power values, stored internally in log10 scale. + Frequency values for the spectral data. freq_range : list of [float, float] - Frequency range of the power spectrum, as [lowest_freq, highest_freq]. + Frequency range of the spectral data, as [lowest_freq, highest_freq]. freq_res : float - Frequency resolution of the power spectrum. - checks : dict - Specifiers for which aspects of the data to run checks on. + Frequency resolution of the spectral data. + power_spectrum : 1d array + Power values. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self, check_freqs=True, check_data=True, format='power'): @@ -291,20 +294,24 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): return freqs, powers, freq_range, freq_res +@replace_docstring_sections([docs_get_section(Data.__doc__, 'Parameters'), + docs_get_section(Data.__doc__, 'Attributes')]) class Data2D(Data): """Base object for managing data for spectral parameterization - for 2D data. + Parameters + ---------- + % copied in from Data + Attributes ---------- - freqs : 1d array - Frequency values for the power spectra. + % copied in from Data power_spectra : 2d array Power values for the group of power spectra, as [n_power_spectra, n_freqs]. - Power values are stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the power spectra, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the power spectra. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self): @@ -388,20 +395,24 @@ def decorated(*args, **kwargs): return decorated +@replace_docstring_sections([docs_get_section(Data.__doc__, 'Parameters'), + docs_get_section(Data2D.__doc__, 'Attributes')]) class Data2DT(Data2D): """Base object for managing data for spectral parameterization - for 2D transposed data. + Parameters + ---------- + % copied in from Data + Attributes ---------- - freqs : 1d array - Frequency values for the spectrogram. + % copied in from Data2D spectrogram : 2d array Power values for the spectrogram, as [n_freqs, n_time_windows]. - Power values are stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the spectrogram, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the spectrogram. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self): @@ -454,20 +465,24 @@ def plot(self, **plt_kwargs): plot_spectrogram(self.freqs, self.spectrogram, **plot_kwargs) +@replace_docstring_sections([docs_get_section(Data.__doc__, 'Parameters'), + docs_get_section(Data2DT.__doc__, 'Attributes')]) class Data3D(Data2DT): """Base object for managing data for spectral parameterization - for 3D data. + Parameters + ---------- + % copied in from Data + Attributes ---------- - freqs : 1d array - Frequency values for the power spectra. + % copied in from Data2DT spectrograms : 3d array Power values for the spectrograms, organized as [n_events, n_freqs, n_time_windows]. - Power values are stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the power spectra, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the power spectra. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self): diff --git a/specparam/objs/results.py b/specparam/objs/results.py index e4966229..9d75a215 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -35,8 +35,21 @@ class Results(): Modes object with fit mode definitions. metrics : Metrics Metrics object with metric definitions. - bands : bands + bands : Bands Bands object with band definitions. + + Attributes + ---------- + modes : Modes + Modes object with fit mode definitions. + bands : Bands + Bands object with band definitions. + model : ModelComponents + Manages the model fit and components. + params : ModelParameters + Manages the model fit parameters. + metrics : Metrics + Metrics object with metric definitions. """ # pylint: disable=attribute-defined-outside-init, arguments-differ @@ -220,13 +233,20 @@ def _regenerate_model(self, freqs): self.modes.periodic, self.params.gaussian, return_components=True) -@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) +@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters'), + docs_get_section(Results.__doc__, 'Attributes')]) class Results2D(Results): """Object for managing results - 2D version. Parameters ---------- % copied in from Results + + Attributes + ---------- + % copied in from Results + group_results : list of FitResults + Results of the model fit for each power spectrum. """ def __init__(self, modes=None, metrics=None, bands=None): @@ -386,13 +406,20 @@ def get_params(self, name, field=None): return get_group_params(self.group_results, self.modes, name, field) -@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) +@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters'), + docs_get_section(Results2D.__doc__, 'Attributes')]) class Results2DT(Results2D): """Object for managing results - 2D transpose version. Parameters ---------- % copied in from Results + + Attributes + ---------- + % copied in from Results2D + time_results : dict + Results of the model fit across each time window. """ def __init__(self, modes=None, metrics=None, bands=None): @@ -445,13 +472,23 @@ def convert_results(self): self.time_results = group_to_dict(self.group_results, self.modes, self.bands) -@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) +@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters'), + docs_get_section(Results2DT.__doc__, 'Attributes')]) class Results3D(Results2DT): """Object for managing results - 3D version. Parameters ---------- % copied in from Results + + Attributes + ---------- + % copied in from Results2DT + event_group_results : list of list of FitResults + Full model results collected across all events and models. + event_time_results : dict + Results of the model fit across each time window, collected across events. + Each value in the dictionary stores a model fit parameter, as [n_events, n_time_windows]. """ def __init__(self, modes=None, metrics=None, bands=None): From eb8d39cb26bba272fd71f018d0d43bd4d43d81f6 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 00:52:37 -0400 Subject: [PATCH 18/40] tweak / update SettingsDefinitions --- specparam/algorithms/settings.py | 40 ++++++++++++++++----- specparam/tests/algorithms/test_settings.py | 11 +++--- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/specparam/algorithms/settings.py b/specparam/algorithms/settings.py index 1fa601c0..bb55a744 100644 --- a/specparam/algorithms/settings.py +++ b/specparam/algorithms/settings.py @@ -10,43 +10,60 @@ class SettingsDefinition(): Parameters ---------- - settings : dict + definitions : dict Settings definition. Each key should be a str name of a setting. Each value should be a dictionary with keys 'type' and 'description', with str values. + + Attributes + ---------- + names : list of str + Names of the settings defined in the object. + descriptions : dict of {str : str} + Description of each setting. + types : dict of {str : str} + Type for each setting. + values : dict of {str : object} + Value of each setting. """ - def __init__(self, settings): + def __init__(self, definitions): """Initialize settings definition.""" - self._settings = settings + self._definitions = definitions + + + def __len__(self): + """Define the length of the object as the number of settings.""" + return len(self._definitions) - def _get_settings_subdict(self, field): - """Helper function to select from settings dictionary.""" - return {label : self._settings[label][field] for label in self._settings.keys()} + def _get_definitions_subdict(self, field): + """Helper function to select from definitions dictionary.""" + + return {label : self._definitions[label][field] for label in self._definitions.keys()} @property def names(self): """Make property alias for setting names.""" - return list(self._settings.keys()) + return list(self._definitions.keys()) @property def types(self): """Make property alias for setting types.""" - return self._get_settings_subdict('type') + return self._get_definitions_subdict('type') @property def descriptions(self): """Make property alias for setting descriptions.""" - return self._get_settings_subdict('description') + return self._get_definitions_subdict('description') def make_setting_str(self, name): @@ -91,6 +108,11 @@ def make_model_settings(self): class ModelSettings(namedtuple('ModelSettings', self.names)): __slots__ = () + + @property + def names(self): + return list(self._fields) + ModelSettings.__doc__ = self.make_docstring() return ModelSettings diff --git a/specparam/tests/algorithms/test_settings.py b/specparam/tests/algorithms/test_settings.py index 30501a86..98e6abf7 100644 --- a/specparam/tests/algorithms/test_settings.py +++ b/specparam/tests/algorithms/test_settings.py @@ -7,16 +7,17 @@ def test_settings_definition(): - tsettings = { + tdefinitions = { 'a' : {'type' : 'a type desc', 'description' : 'a desc'}, 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, } - settings = SettingsDefinition(tsettings) - assert settings._settings == tsettings - assert settings.names == list(tsettings.keys()) + settings = SettingsDefinition(tdefinitions) + assert settings._definitions == tdefinitions + assert len(settings) == len(tdefinitions) + assert settings.names == list(tdefinitions.keys()) assert settings.types assert settings.descriptions - for label in tsettings.keys(): + for label in tdefinitions.keys(): assert settings.make_setting_str(label) assert settings.make_docstring() From e8f83d1755e346f7ce6d446205ae0b6cb88bade3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 00:53:41 -0400 Subject: [PATCH 19/40] tweak Algorithm object for settings --- specparam/algorithms/algorithm.py | 21 ++++++++++++++------ specparam/tests/algorithms/test_algorithm.py | 6 +++--- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 15c10587..1fb8940a 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -18,10 +18,18 @@ class Algorithm(): Name of the fitting algorithm. description : str Description of the fitting algorithm. - settings : dict + settings : SettingsDefinition or dict Name and description of settings for the fitting algorithm. format : {'spectrum', 'spectra', 'spectrogram', 'spectrograms'} Set base format of data model can be applied to. + modes : Modes + Modes object with fit mode definitions. + data : Data + Data object with spectral data and metadata. + results : Results + Results object with model fit results and metrics. + debug : bool + Whether to run in debug state, raising an error if encountered during fitting. """ def __init__(self, name, description, settings, format, @@ -33,7 +41,8 @@ def __init__(self, name, description, settings, format, if not isinstance(settings, SettingsDefinition): settings = SettingsDefinition(settings) - self.settings = settings + self._settings = settings + self.settings = None check_input_options(format, FORMATS, 'format') self.format = format @@ -78,8 +87,8 @@ def get_settings(self): Object containing the settings from the current object. """ - return self.settings.make_model_settings()(\ - **{key : getattr(self, key) for key in self.settings.names}) + return self._settings.make_model_settings()(\ + **{key : getattr(self, key) for key in self._settings.names}) def get_debug(self): @@ -111,10 +120,10 @@ def _check_loaded_settings(self, data): # If settings not loaded from file, clear from object, so that default # settings, which are potentially wrong for loaded data, aren't kept - if not set(self.settings.names).issubset(set(data.keys())): + if not set(self._settings.names).issubset(set(data.keys())): # Reset all public settings to None - for setting in self.settings.names: + for setting in self._settings.names: setattr(self, setting, None) # Reset internal settings so that they are consistent with what was loaded diff --git a/specparam/tests/algorithms/test_algorithm.py b/specparam/tests/algorithms/test_algorithm.py index d16f0668..5441161a 100644 --- a/specparam/tests/algorithms/test_algorithm.py +++ b/specparam/tests/algorithms/test_algorithm.py @@ -20,8 +20,8 @@ def test_algorithm(): assert algo assert algo.name == tname assert algo.description == tdescription - assert isinstance(algo.settings, SettingsDefinition) - assert algo.settings == tsettings + assert isinstance(algo._settings, SettingsDefinition) + assert algo._settings == tsettings def test_algorithm_settings(): @@ -34,7 +34,7 @@ def test_algorithm_settings(): talgo = Algorithm(name=tname, description=tdescription, settings=tsettings, format='spectrum') - model_settings = talgo.settings.make_model_settings() + model_settings = talgo._settings.make_model_settings() settings = model_settings(a=1, b=2) talgo.add_settings(settings) for setting in settings._fields: From 91720a3af3416fe7116a360841468126dc9370f8 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 02:13:51 -0400 Subject: [PATCH 20/40] add SettingsValues --- specparam/algorithms/settings.py | 60 +++++++++++++++++++++ specparam/tests/algorithms/test_settings.py | 29 +++++++--- 2 files changed, 81 insertions(+), 8 deletions(-) diff --git a/specparam/algorithms/settings.py b/specparam/algorithms/settings.py index bb55a744..a4f315a8 100644 --- a/specparam/algorithms/settings.py +++ b/specparam/algorithms/settings.py @@ -5,6 +5,66 @@ ################################################################################################### ################################################################################################### +class SettingsValues(): + """Defines a set of algorithm settings values. + + Parameters + ---------- + names : list of str + Names of the settings to hold values for. + + Attributes + ---------- + values : dict of {str : object} + Settings values. + """ + + __slots__ = 'values' + + def __init__(self, names): + """Initialize settings values.""" + + self.values = {name : None for name in names} + + + def __getattr__(self, name): + """Allow for accessing settings values as attributes.""" + + try: + return self.values[name] + except KeyError: + raise AttributeError(name) + + + def __setattr__(self, name, value): + """Allow for setting settings values as attributes.""" + + if name == 'values': + super().__setattr__(name, value) + else: + getattr(self, name) + self.values[name] = value + + + def __getstate__(self): + """Define how to get object state - for pickling.""" + + return self.values + + + def __setstate__(self, state): + """Define how to set object state - for pickling.""" + + self.values = state + + + @property + def names(self): + """Property attribute for settings names.""" + + return list(self.values.keys()) + + class SettingsDefinition(): """Defines a set of algorithm settings. diff --git a/specparam/tests/algorithms/test_settings.py b/specparam/tests/algorithms/test_settings.py index 98e6abf7..10b059c4 100644 --- a/specparam/tests/algorithms/test_settings.py +++ b/specparam/tests/algorithms/test_settings.py @@ -5,6 +5,19 @@ ################################################################################################### ################################################################################################### +def test_settings_values(): + + tsettings_names = ['a', 'b'] + settings_vals = SettingsValues(tsettings_names) + assert isinstance(settings_vals.values, dict) + assert settings_vals.names == tsettings_names + assert settings_vals.a is None + assert settings_vals.b is None + settings_vals.a = 1 + settings_vals.b = 2 + assert settings_vals.a == 1 + assert settings_vals.b == 2 + def test_settings_definition(): tdefinitions = { @@ -12,12 +25,12 @@ def test_settings_definition(): 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, } - settings = SettingsDefinition(tdefinitions) - assert settings._definitions == tdefinitions - assert len(settings) == len(tdefinitions) - assert settings.names == list(tdefinitions.keys()) - assert settings.types - assert settings.descriptions + settings_def = SettingsDefinition(tdefinitions) + assert settings_def._definitions == tdefinitions + assert len(settings_def) == len(tdefinitions) + assert settings_def.names == list(tdefinitions.keys()) + assert settings_def.types + assert settings_def.descriptions for label in tdefinitions.keys(): - assert settings.make_setting_str(label) - assert settings.make_docstring() + assert settings_def.make_setting_str(label) + assert settings_def.make_docstring() From 4418e1f96c4d94468446b1c9122c0480d391bd50 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 02:17:40 -0400 Subject: [PATCH 21/40] use SettingsValues --- specparam/algorithms/algorithm.py | 10 +++---- specparam/algorithms/spectral_fit.py | 28 ++++++++++---------- specparam/tests/algorithms/test_algorithm.py | 2 +- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 1fb8940a..96145e7b 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -1,7 +1,7 @@ """Define object to manage algorithm implementations.""" from specparam.utils.checks import check_input_options -from specparam.algorithms.settings import SettingsDefinition +from specparam.algorithms.settings import SettingsDefinition, SettingsValues ################################################################################################### ################################################################################################### @@ -42,7 +42,7 @@ def __init__(self, name, description, settings, format, if not isinstance(settings, SettingsDefinition): settings = SettingsDefinition(settings) self._settings = settings - self.settings = None + self.settings = SettingsValues(self._settings.names) check_input_options(format, FORMATS, 'format') self.format = format @@ -73,7 +73,7 @@ def add_settings(self, settings): """ for setting in settings._fields: - setattr(self, setting, getattr(settings, setting)) + setattr(self.settings, setting, getattr(settings, setting)) self._check_loaded_settings(settings._asdict()) @@ -88,7 +88,7 @@ def get_settings(self): """ return self._settings.make_model_settings()(\ - **{key : getattr(self, key) for key in self._settings.names}) + **{key : getattr(self.settings, key) for key in self._settings.names}) def get_debug(self): @@ -124,7 +124,7 @@ def _check_loaded_settings(self, data): # Reset all public settings to None for setting in self._settings.names: - setattr(self, setting, None) + setattr(self.settings, setting, None) # Reset internal settings so that they are consistent with what was loaded # Note that this will set internal settings to None, if public settings unavailable diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index dd66daf0..e72fd699 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -16,7 +16,7 @@ ################################################################################################### ################################################################################################### -SPECTRAL_FIT_SETTINGS = SettingsDefinition({ +SPECTRAL_FIT_SETTINGS_DEF = SettingsDefinition({ 'peak_width_limits' : { 'type' : 'tuple of (float, float), optional, default: (0.5, 12.0)', 'description' : 'Limits on possible peak width, in Hz, as (lower_bound, upper_bound).', @@ -91,14 +91,14 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h super().__init__( name='spectral fit', description='Original parameterizing neural power spectra algorithm.', - settings=SPECTRAL_FIT_SETTINGS, format='spectrum', + settings=SPECTRAL_FIT_SETTINGS_DEF, format='spectrum', modes=modes, data=data, results=results, debug=debug) ## Public settings - self.peak_width_limits = peak_width_limits - self.max_n_peaks = max_n_peaks - self.min_peak_height = min_peak_height - self.peak_threshold = peak_threshold + self.settings.peak_width_limits = peak_width_limits + self.settings.max_n_peaks = max_n_peaks + self.settings.min_peak_height = min_peak_height + self.settings.peak_threshold = peak_threshold ## Private settings: model parameters related settings self._ap_percentile_thresh = ap_percentile_thresh @@ -126,8 +126,8 @@ def _fit_prechecks(self, verbose=True): """ if verbose: - if 1.5 * self.data.freq_res >= self.peak_width_limits[0]: - print(gen_width_warning_str(self.data.freq_res, self.peak_width_limits[0])) + if 1.5 * self.data.freq_res >= self.settings.peak_width_limits[0]: + print(gen_width_warning_str(self.data.freq_res, self.settings.peak_width_limits[0])) def _fit(self): @@ -177,11 +177,11 @@ def _reset_internal_settings(self): """ # Only update these settings if other relevant settings are available - if self.peak_width_limits: + if self.settings.peak_width_limits: # Bandwidth limits are given in 2-sided peak bandwidth # Convert to gaussian std parameter limits - self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits) + self._gauss_std_limits = tuple(bwl / 2 for bwl in self.settings.peak_width_limits) # Otherwise, assume settings are unknown (have been cleared) and set to None else: @@ -360,14 +360,14 @@ def _fit_peaks(self, flatspec): # Find peak: loop through, finding a candidate peak, & fit with a guess peak # Stopping procedures: limit on # of peaks, or relative or absolute height thresholds - while len(guess) < self.max_n_peaks: + while len(guess) < self.settings.max_n_peaks: # Find candidate peak - the maximum point of the flattened spectrum max_ind = np.argmax(flat_iter) max_height = flat_iter[max_ind] # Stop searching for peaks once height drops below height threshold - if max_height <= self.peak_threshold * np.std(flat_iter): + if max_height <= self.settings.peak_threshold * np.std(flat_iter): break # Set the guess parameters for gaussian fitting, specifying the mean and height @@ -375,7 +375,7 @@ def _fit_peaks(self, flatspec): guess_height = max_height # Halt fitting process if candidate peak drops below minimum height - if not guess_height > self.min_peak_height: + if not guess_height > self.settings.min_peak_height: break # Data-driven first guess at standard deviation @@ -402,7 +402,7 @@ def _fit_peaks(self, flatspec): except ValueError: # This procedure can fail (very rarely), if both left & right inds end up as None # In this case, default the guess to the average of the peak width limits - guess_std = np.mean(self.peak_width_limits) + guess_std = np.mean(self.settings.peak_width_limits) # Check that guess value isn't outside preset limits - restrict if so # Note: without this, curve_fitting fails if given guess > or < bounds diff --git a/specparam/tests/algorithms/test_algorithm.py b/specparam/tests/algorithms/test_algorithm.py index 5441161a..ff3e646a 100644 --- a/specparam/tests/algorithms/test_algorithm.py +++ b/specparam/tests/algorithms/test_algorithm.py @@ -38,7 +38,7 @@ def test_algorithm_settings(): settings = model_settings(a=1, b=2) talgo.add_settings(settings) for setting in settings._fields: - assert getattr(talgo, setting) == getattr(settings, setting) + assert getattr(talgo.settings, setting) == getattr(settings, setting) settings_out = talgo.get_settings() assert isinstance(settings, model_settings) From 9803c2f4aabb92ac60fe4d8f95de0dad87ec1bf8 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 02:17:55 -0400 Subject: [PATCH 22/40] udpate across code to use algorithm.settings --- specparam/io/models.py | 2 +- specparam/models/base.py | 4 ++-- specparam/models/utils.py | 2 +- specparam/plts/annotate.py | 6 ++++-- specparam/reports/strings.py | 12 ++++++------ specparam/tests/models/test_event.py | 3 +-- specparam/tests/models/test_group.py | 18 +++++++----------- specparam/tests/models/test_model.py | 7 +++---- specparam/tests/models/test_utils.py | 4 ++-- 9 files changed, 27 insertions(+), 31 deletions(-) diff --git a/specparam/io/models.py b/specparam/io/models.py index 8aa33955..6018f3f2 100644 --- a/specparam/io/models.py +++ b/specparam/io/models.py @@ -49,7 +49,7 @@ def save_model(model, file_name, file_path=None, append=False, """ # 'Flatten' the model object by extracting relevant attributes to a dictionary - obj_dict = {**model.data.__dict__, **model.algorithm.__dict__} + obj_dict = {**model.data.__dict__, **model.algorithm.settings.values} # Convert modes object to their saveable string name obj_dict['aperiodic_mode'] = model.modes.aperiodic.name diff --git a/specparam/models/base.py b/specparam/models/base.py index 5ce949bf..aadd1981 100644 --- a/specparam/models/base.py +++ b/specparam/models/base.py @@ -180,7 +180,7 @@ def _add_from_dict(self, data): # Add additional attributes directly to object for key in data.keys(): - if getattr(self.algorithm, key, False) is not False: - setattr(self.algorithm, key, data[key]) + if getattr(self.algorithm.settings, key, False) is not False: + setattr(self.algorithm.settings, key, data[key]) elif getattr(self.data, key, False) is not False: setattr(self.data, key, data[key]) diff --git a/specparam/models/utils.py b/specparam/models/utils.py index e4b6cbef..0e8e7082 100644 --- a/specparam/models/utils.py +++ b/specparam/models/utils.py @@ -39,7 +39,7 @@ def initialize_model_from_source(source, target): """ model = MODELS[target](**source.modes.get_modes()._asdict(), - **source.algorithm.get_settings()._asdict(), + **source.algorithm.settings.values, metrics=source.results.metrics.labels, bands=source.results.bands, verbose=source.verbose) diff --git a/specparam/plts/annotate.py b/specparam/plts/annotate.py index 05670214..20ae0d70 100644 --- a/specparam/plts/annotate.py +++ b/specparam/plts/annotate.py @@ -51,10 +51,12 @@ def plot_annotated_peak_search(model): plot_spectra(model.data.freqs, flatspec, linewidth=2.5, label='Flattened Spectrum', color=PLT_COLORS['data'], ax=ax) plot_spectra(model.data.freqs, - [model.algorithm.peak_threshold * np.std(flatspec)] * len(model.data.freqs), + [model.algorithm.settings.peak_threshold * np.std(flatspec)] \ + * len(model.data.freqs), label='Relative Threshold', color='orange', linewidth=2.5, linestyle='dashed', ax=ax) - plot_spectra(model.data.freqs, [model.algorithm.min_peak_height]*len(model.data.freqs), + plot_spectra(model.data.freqs, + [model.algorithm.settings.min_peak_height] * len(model.data.freqs), label='Absolute Threshold', color='red', linewidth=2.5, linestyle='dashed', ax=ax) diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index c774e0ab..b9a8d185 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -211,9 +211,9 @@ def gen_settings_str(model, description=False, concise=False): # Loop through algorithm settings, and add information for name in model.algorithm.settings.names: - str_lst.append(name + ' : ' + str(getattr(model.algorithm, name))) + str_lst.append(name + ' : ' + str(getattr(model.algorithm.settings, name))) if description: - str_lst.append(model.algorithm.settings.descriptions[name].split('\n ')[0]) + str_lst.append(model.algorithm._settings.descriptions[name].split('\n ')[0]) # Add footer to string str_lst.extend([ @@ -337,10 +337,10 @@ def gen_methods_text_str(model=None): methods_str = template.format(MODULE_VERSION, model.modes.aperiodic.name if model else 'XX', model.modes.periodic.name if model else 'XX', - model.algorithm.peak_width_limits if model else 'XX', - model.algorithm.max_n_peaks if model else 'XX', - model.algorithm.min_peak_height if model else 'XX', - model.algorithm.peak_threshold if model else 'XX', + model.algorithm.settings.peak_width_limits if model else 'XX', + model.algorithm.settings.max_n_peaks if model else 'XX', + model.algorithm.settings.min_peak_height if model else 'XX', + model.algorithm.settings.peak_threshold if model else 'XX', *freq_range) return methods_str diff --git a/specparam/tests/models/test_event.py b/specparam/tests/models/test_event.py index 71b8c5dc..24730b9c 100644 --- a/specparam/tests/models/test_event.py +++ b/specparam/tests/models/test_event.py @@ -124,8 +124,7 @@ def test_event_get_model(tfe): tfm_null = tfe.get_model() assert tfm_null # Check that settings are copied over properly, but data and results are empty - for setting in tfe.algorithm.settings.names: - assert getattr(tfe.algorithm, setting) == getattr(tfm_null.algorithm, setting) + assert tfe.algorithm.settings.values == tfm_null.algorithm.settings.values assert not tfm_null.data.has_data assert not tfm_null.results.has_model diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index aec3bea1..31e9e0b4 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -280,14 +280,13 @@ def test_load(tfg): assert len(ntfg.results.group_results) > 0 # Test that settings and data are None for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is None + assert getattr(ntfg.algorithm.settings, setting) is None assert ntfg.data.power_spectra is None # Test loading just settings ntfg = SpectralGroupModel(verbose=False) ntfg.load('test_group_set', TEST_DATA_PATH) - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(ntfg.algorithm, setting) + assert tfg.algorithm.settings.values == ntfg.algorithm.settings.values # Test that results and data are None for result in tfg.results.params.fields: assert np.all(np.isnan(getattr(ntfg.results.params, result))) @@ -299,7 +298,7 @@ def test_load(tfg): assert ntfg.data.has_data # Test that settings and results are None for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is None + assert getattr(ntfg.algorithm.settings, setting) is None for result in tfg.results.params.fields: assert np.all(np.isnan(getattr(ntfg.results.params, result))) @@ -327,8 +326,7 @@ def test_get_model(tfg): tfm_null = tfg.get_model() assert tfm_null # Check that settings are copied over properly, but data and results are empty - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(tfm_null.algorithm, setting) + assert tfg.algorithm.settings.values == tfm_null.algorithm.settings.values assert not tfm_null.data.has_data assert not tfm_null.results.has_model @@ -336,8 +334,7 @@ def test_get_model(tfg): tfm0 = tfg.get_model(0, False) assert tfm0 # Check that settings are copied over properly - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(tfm0.algorithm, setting) + assert tfg.algorithm.settings.values == tfm0.algorithm.settings.values # Check with regenerating tfm1 = tfg.get_model(1, True) @@ -375,9 +372,8 @@ def test_get_group(tfg): assert isinstance(nfg2, SpectralGroupModel) # Check that settings are copied over properly - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(nfg1.algorithm, setting) - assert getattr(tfg.algorithm, setting) == getattr(nfg2.algorithm, setting) + assert tfg.algorithm.settings.values == nfg1.algorithm.settings.values + assert tfg.algorithm.settings.values == nfg2.algorithm.settings.values # Check that data info is copied over properly for meta_dat in tfg.data._meta_fields: diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index fdabcd77..c944bcd5 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -201,14 +201,13 @@ def test_load(tfm): assert not np.all(np.isnan(getattr(ntfm.results.params, result))) # Test that settings and data are None for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is None + assert getattr(ntfm.algorithm.settings, setting) is None assert ntfm.data.power_spectrum is None # Test loading just settings ntfm = SpectralModel(verbose=False) ntfm.load('test_model_set', TEST_DATA_PATH) - for setting in tfm.algorithm.settings.names: - assert getattr(tfm.algorithm, setting) == getattr(ntfm.algorithm, setting) + assert tfm.algorithm.settings.values == ntfm.algorithm.settings.values # Test that results and data are None for result in tfm.results.params.fields: assert np.all(np.isnan(getattr(ntfm.results.params, result))) @@ -221,7 +220,7 @@ def test_load(tfm): assert np.array_equal(tfm.data.power_spectrum, ntfm.data.power_spectrum) # Test that settings and results are None for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is None + assert getattr(ntfm.algorithm.settings, setting) is None for result in tfm.results.params.fields: assert np.all(np.isnan(getattr(ntfm.results.params, result))) diff --git a/specparam/tests/models/test_utils.py b/specparam/tests/models/test_utils.py index 85a28b39..cc5b928a 100644 --- a/specparam/tests/models/test_utils.py +++ b/specparam/tests/models/test_utils.py @@ -41,7 +41,7 @@ def test_compare_model_objs(tfm, tfg): assert not compare_model_objs([f_obj, f_obj2], 'modes') assert compare_model_objs([f_obj, f_obj2], 'settings') - f_obj2.algorithm.peak_width_limits = [2, 4] + f_obj2.algorithm.settings.peak_width_limits = [2, 4] f_obj2.algorithm._reset_internal_settings() assert not compare_model_objs([f_obj, f_obj2], 'settings') @@ -137,7 +137,7 @@ def test_combine_errors(tfm, tfg): # Incompatible settings for f_obj in [tfm, tfg]: f_obj2 = f_obj.copy() - f_obj2.algorithm.peak_width_limits = [2, 4] + f_obj2.algorithm.settings.peak_width_limits = [2, 4] f_obj2.algorithm._reset_internal_settings() with raises(IncompatibleSettingsError): From 47fef97b3fa8f849a2334d1701e83e3c07305025 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 02:22:06 -0400 Subject: [PATCH 23/40] update spectral fit settings def dict name --- specparam/models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/specparam/models/model.py b/specparam/models/model.py index 2324d26c..4517c676 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -10,7 +10,7 @@ from specparam.models.base import BaseModel from specparam.objs.data import Data from specparam.objs.results import Results -from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS +from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS_DEF from specparam.reports.save import save_model_report from specparam.reports.strings import gen_model_results_str from specparam.modutils.errors import NoDataError, FitError @@ -24,7 +24,7 @@ ################################################################################################### ################################################################################################### -@replace_docstring_sections([SPECTRAL_FIT_SETTINGS.make_docstring()]) +@replace_docstring_sections([SPECTRAL_FIT_SETTINGS_DEF.make_docstring()]) class SpectralModel(BaseModel): """Model a power spectrum as a combination of aperiodic and periodic components. From ad6c0e032fc98d7a2dc2a961229e799312913a73 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 21:47:27 -0400 Subject: [PATCH 24/40] add private_settings to Algorithm --- specparam/algorithms/algorithm.py | 43 ++++++++++++-------- specparam/reports/strings.py | 2 +- specparam/tests/algorithms/test_algorithm.py | 10 ++--- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 96145e7b..b9dd4fd0 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -6,7 +6,7 @@ ################################################################################################### ################################################################################################### -FORMATS = ['spectrum', 'spectra', 'spectrogram', 'spectrograms'] +DATA_FORMATS = ['spectrum', 'spectra', 'spectrogram', 'spectrograms'] class Algorithm(): @@ -18,10 +18,12 @@ class Algorithm(): Name of the fitting algorithm. description : str Description of the fitting algorithm. - settings : SettingsDefinition or dict - Name and description of settings for the fitting algorithm. - format : {'spectrum', 'spectra', 'spectrogram', 'spectrograms'} - Set base format of data model can be applied to. + public_settings : SettingsDefinition or dict + Name and description of public settings for the fitting algorithm. + private_settings : SettingsDefinition or dict, optional + Name and description of private settings for the fitting algorithm. + data_format : {'spectrum', 'spectra', 'spectrogram', 'spectrograms'} + Set base data format the model can be applied to. modes : Modes Modes object with fit mode definitions. data : Data @@ -32,20 +34,27 @@ class Algorithm(): Whether to run in debug state, raising an error if encountered during fitting. """ - def __init__(self, name, description, settings, format, - modes=None, data=None, results=None, debug=False): + def __init__(self, name, description, public_settings, private_settings=None, + data_format='spectrum', modes=None, data=None, results=None, debug=False): """Initialize Algorithm object.""" self.name = name self.description = description - if not isinstance(settings, SettingsDefinition): - settings = SettingsDefinition(settings) - self._settings = settings - self.settings = SettingsValues(self._settings.names) + if not isinstance(public_settings, SettingsDefinition): + public_settings = SettingsDefinition(public_settings) + self.public_settings = public_settings + self.settings = SettingsValues(self.public_settings.names) - check_input_options(format, FORMATS, 'format') - self.format = format + if private_settings is None: + private_settings = {} + if not isinstance(private_settings, SettingsDefinition): + private_settings = SettingsDefinition(private_settings) + self.private_settings = private_settings + self._settings = SettingsValues(self.private_settings.names) + + check_input_options(data_format, DATA_FORMATS, 'data_format') + self.data_format = data_format self.modes = None self.data = None @@ -87,8 +96,8 @@ def get_settings(self): Object containing the settings from the current object. """ - return self._settings.make_model_settings()(\ - **{key : getattr(self.settings, key) for key in self._settings.names}) + return self.public_settings.make_model_settings()(\ + **{key : getattr(self.settings, key) for key in self.public_settings.names}) def get_debug(self): @@ -120,10 +129,10 @@ def _check_loaded_settings(self, data): # If settings not loaded from file, clear from object, so that default # settings, which are potentially wrong for loaded data, aren't kept - if not set(self._settings.names).issubset(set(data.keys())): + if not set(self.settings.names).issubset(set(data.keys())): # Reset all public settings to None - for setting in self._settings.names: + for setting in self.settings.names: setattr(self.settings, setting, None) # Reset internal settings so that they are consistent with what was loaded diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index b9a8d185..b3c07bad 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -213,7 +213,7 @@ def gen_settings_str(model, description=False, concise=False): for name in model.algorithm.settings.names: str_lst.append(name + ' : ' + str(getattr(model.algorithm.settings, name))) if description: - str_lst.append(model.algorithm._settings.descriptions[name].split('\n ')[0]) + str_lst.append(model.algorithm.public_settings.descriptions[name].split('\n ')[0]) # Add footer to string str_lst.extend([ diff --git a/specparam/tests/algorithms/test_algorithm.py b/specparam/tests/algorithms/test_algorithm.py index ff3e646a..8b13b7a3 100644 --- a/specparam/tests/algorithms/test_algorithm.py +++ b/specparam/tests/algorithms/test_algorithm.py @@ -16,12 +16,12 @@ def test_algorithm(): 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, }) - algo = Algorithm(name=tname, description=tdescription, settings=tsettings, format='spectrum') + algo = Algorithm(name=tname, description=tdescription, public_settings=tsettings) assert algo assert algo.name == tname assert algo.description == tdescription - assert isinstance(algo._settings, SettingsDefinition) - assert algo._settings == tsettings + assert isinstance(algo.public_settings, SettingsDefinition) + assert algo.public_settings == tsettings def test_algorithm_settings(): @@ -32,9 +32,9 @@ def test_algorithm_settings(): 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, }) - talgo = Algorithm(name=tname, description=tdescription, settings=tsettings, format='spectrum') + talgo = Algorithm(name=tname, description=tdescription, public_settings=tsettings) - model_settings = talgo._settings.make_model_settings() + model_settings = talgo.public_settings.make_model_settings() settings = model_settings(a=1, b=2) talgo.add_settings(settings) for setting in settings._fields: From cabf8af04e9dead4548e7242238250ceb75ee951 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 21:53:42 -0400 Subject: [PATCH 25/40] use private settings on spectral fit algorithm object --- specparam/algorithms/spectral_fit.py | 105 ++++++++++++++++----------- 1 file changed, 64 insertions(+), 41 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index e72fd699..219cef31 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -40,32 +40,53 @@ }) +SPECTRAL_FIT_PRIVATE_SETTINGS_DEF = SettingsDefinition({ + 'ap_percentile_thresh' : { + 'type' : 'float', + 'description' : \ + 'Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit.\n ' + 'Points are selected at a low percentile value to restrict to non-peak points.' + }, + 'ap_guess' : { + 'type' : 'list of float', + 'description' : \ + 'Guess parameters for fitting the aperiodic component.\n ' + 'The length of the guess parameters should match the number & order of the aperiodic parameters.\n ' + 'If \'offset\' is a parameter & guess is None, the first value of the power spectrum is used as the guess.\n ' + 'If \'exponent\' is a parmater & guess is None, the abs(log-log slope) of first & last points is used.' + }, + 'ap_bounds' : { + 'type' : 'tuple of tuple of float', + 'description' : \ + 'Bounds for aperiodic fitting, as ((param1_low_bound, ...) (param1_high_bound, ...)).\n ' + 'By default, aperiodic fitting is unbound, but can be restricted here.' + }, + 'cf_bound' : { + 'type' : 'float', + 'description' : 'Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev.' + }, + 'bw_std_edge' : { + 'type' : 'float', + 'description' : \ + 'Threshold for how far a peak has to be from edge to keep.\n ' + 'This is defined in units of gaussian standard deviation.' + }, + 'gauss_overlap_thresh' : { + 'type' : 'float', + 'description' : \ + 'Degree of overlap between gaussian guesses for one to be dropped.\n ' + 'This is defined in units of gaussian standard deviation.' + }, +}) + + class SpectralFitAlgorithm(Algorithm): """Base object defining model & algorithm for spectral parameterization. Parameters ---------- % public settings described in Spectral Fit Algorithm Settings - _ap_percentile_thresh : float - Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit - Points are selected at a low percentile value to restrict to non-peak points. - _ap_guess : list of [float, float, float] - Guess parameters for fitting the aperiodic component, as [offset, knee, exponent]. - If offset guess is None, the first value of the power spectrum is used as offset guess - If exponent guess is None, the abs(log-log slope) of first & last points is used - _ap_bounds : tuple of tuple of float - Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, exp_low_bound), - (offset_high_bound, knee_high_bound, exp_high_bound)) - By default, aperiodic fitting is unbound, but can be restricted here. - Even if fitting without knee, leave bounds for knee (they are dropped later). - _cf_bound : float - Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev. - _bw_std_edge : float - Threshold for how far a peak has to be from edge to keep. - This is defined in units of gaussian standard deviation. - _gauss_overlap_thresh : float - Degree of overlap between gaussian guesses for one to be dropped. - This is defined in units of gaussian standard deviation. + _maxfev : int The maximum number of calls to the curve fitting function. _tol : float @@ -91,7 +112,8 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h super().__init__( name='spectral fit', description='Original parameterizing neural power spectra algorithm.', - settings=SPECTRAL_FIT_SETTINGS_DEF, format='spectrum', + public_settings=SPECTRAL_FIT_SETTINGS_DEF, + private_settings=SPECTRAL_FIT_PRIVATE_SETTINGS_DEF, modes=modes, data=data, results=results, debug=debug) ## Public settings @@ -101,12 +123,12 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h self.settings.peak_threshold = peak_threshold ## Private settings: model parameters related settings - self._ap_percentile_thresh = ap_percentile_thresh - self._ap_guess = ap_guess - self._set_ap_bounds(ap_bounds) - self._cf_bound = cf_bound - self._bw_std_edge = bw_std_edge - self._gauss_overlap_thresh = gauss_overlap_thresh + self._settings.ap_percentile_thresh = ap_percentile_thresh + self._settings.ap_guess = ap_guess + self._settings.ap_bounds = self._get_ap_bounds(ap_bounds) + self._settings.cf_bound = cf_bound + self._settings.bw_std_edge = bw_std_edge + self._settings.gauss_overlap_thresh = gauss_overlap_thresh ## Private setting: curve_fit related settings self._maxfev = maxfev @@ -198,7 +220,7 @@ def _get_ap_guess(self, freqs, power_spectrum): ToDo - Could be updated to fill in missing guesses. """ - if not self._ap_guess: + if not self._settings.ap_guess: ap_guess = [] for label in self.modes.aperiodic.params.labels: @@ -221,7 +243,7 @@ def _get_ap_guess(self, freqs, power_spectrum): return ap_guess - def _set_ap_bounds(self, ap_bounds): + def _get_ap_bounds(self, ap_bounds): """Set the default bounds for the aperiodic fit. Notes @@ -232,12 +254,13 @@ def _set_ap_bounds(self, ap_bounds): if ap_bounds: msg = 'Provided aperiodic bounds do not have right length for fit function.' - assert len(self._ap_bounds[0]) == len(self._ap_bounds[1]) == \ + assert len(ap_bounds[0]) == len(ap_bounds[1]) == \ self.modes.aperiodic.n_params, msg - self._ap_bounds = ap_bounds else: - self._ap_bounds = (tuple([-np.inf] * self.modes.aperiodic.n_params), - tuple([np.inf] * self.modes.aperiodic.n_params)) + ap_bounds = (tuple([-np.inf] * self.modes.aperiodic.n_params), + tuple([np.inf] * self.modes.aperiodic.n_params)) + + return ap_bounds def _simple_ap_fit(self, freqs, power_spectrum): @@ -267,7 +290,7 @@ def _simple_ap_fit(self, freqs, power_spectrum): with warnings.catch_warnings(): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs, power_spectrum, - p0=ap_guess, bounds=self._ap_bounds, + p0=ap_guess, bounds=self._settings.ap_bounds, maxfev=self._maxfev, check_finite=False, ftol=self._tol, xtol=self._tol, gtol=self._tol) except RuntimeError as excp: @@ -310,7 +333,7 @@ def _robust_ap_fit(self, freqs, power_spectrum): flatspec[flatspec < 0] = 0 # Use percentile threshold, in terms of # of points, to extract and re-fit - perc_thresh = np.percentile(flatspec, self._ap_percentile_thresh) + perc_thresh = np.percentile(flatspec, self._settings.ap_percentile_thresh) perc_mask = flatspec <= perc_thresh freqs_ignore = freqs[perc_mask] spectrum_ignore = power_spectrum[perc_mask] @@ -322,7 +345,7 @@ def _robust_ap_fit(self, freqs, power_spectrum): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs_ignore, spectrum_ignore, - p0=popt, bounds=self._ap_bounds, + p0=popt, bounds=self._settings.ap_bounds, maxfev=self._maxfev, check_finite=False, ftol=self._tol, xtol=self._tol, gtol=self._tol) except RuntimeError as excp: @@ -447,9 +470,9 @@ def _get_pe_bounds(self, guess): # ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*), # (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*)) # ^where each value sets the bound on the specified parameter - lo_bound = [[peak[0] - 2 * self._cf_bound * peak[2], 0, self._gauss_std_limits[0]] + lo_bound = [[peak[0] - 2 * self._settings.cf_bound * peak[2], 0, self._gauss_std_limits[0]] for peak in guess] - hi_bound = [[peak[0] + 2 * self._cf_bound * peak[2], np.inf, self._gauss_std_limits[1]] + hi_bound = [[peak[0] + 2 * self._settings.cf_bound * peak[2], np.inf, self._gauss_std_limits[1]] for peak in guess] # Check that CF bounds are within frequency range @@ -525,7 +548,7 @@ def _drop_peak_cf(self, guess): """ cf_params = guess[:, 0] - bw_params = guess[:, 2] * self._bw_std_edge + bw_params = guess[:, 2] * self._settings.bw_std_edge # Check if peaks within drop threshold from the edge of the frequency range keep_peak = \ @@ -562,8 +585,8 @@ def _drop_peak_overlap(self, guess): # Calculate standard deviation bounds for checking amount of overlap # The bounds are the gaussian frequency +/- gaussian standard deviation - bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh, - peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess] + bounds = [[peak[0] - peak[2] * self._settings.gauss_overlap_thresh, + peak[0] + peak[2] * self._settings.gauss_overlap_thresh] for peak in guess] # Loop through peak bounds, comparing current bound to that of next peak # If the left peak's upper bound extends pass the right peaks lower bound, From 71cd252a84631b89ea5f3c914a5ed62bbc9ce6a2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 22:46:08 -0400 Subject: [PATCH 26/40] update process for peak_width_limits -> gaussian std --- specparam/algorithms/spectral_fit.py | 51 ++++++---------------------- specparam/tests/models/test_model.py | 3 -- specparam/tests/models/test_utils.py | 3 -- 3 files changed, 11 insertions(+), 46 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 219cef31..80b53a17 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -53,7 +53,7 @@ 'Guess parameters for fitting the aperiodic component.\n ' 'The length of the guess parameters should match the number & order of the aperiodic parameters.\n ' 'If \'offset\' is a parameter & guess is None, the first value of the power spectrum is used as the guess.\n ' - 'If \'exponent\' is a parmater & guess is None, the abs(log-log slope) of first & last points is used.' + 'If \'exponent\' is a parameter & guess is None, the abs(log-log slope) of first & last points is used.' }, 'ap_bounds' : { 'type' : 'tuple of tuple of float', @@ -93,12 +93,6 @@ class SpectralFitAlgorithm(Algorithm): The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol). The default value reduce tolerance to speed fitting (as compared to curve_fit's default). Set value to 1e-8 to match curve_fit default. - - Attributes - ---------- - _gauss_std_limits : list of [float, float] - Settings attribute: peak width limits, to use for gaussian standard deviation parameter. - This attribute is computed based on `peak_width_limits` and should not be updated directly. """ # pylint: disable=attribute-defined-outside-init @@ -134,9 +128,6 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h self._maxfev = maxfev self._tol = tol - ## Set internal settings, based on inputs, and initialize data & results attributes - self._reset_internal_settings() - def _fit_prechecks(self, verbose=True): """Prechecks to run before the fit function. @@ -189,27 +180,6 @@ def _fit(self): self.results.params.peak = self._create_peak_params(self.results.params.gaussian) - def _reset_internal_settings(self): - """Set, or reset, internal settings, based on what is provided in init. - - Notes - ----- - These settings are for internal use, based on what is provided to, or set in `__init__`. - They should not be altered by the user. - """ - - # Only update these settings if other relevant settings are available - if self.settings.peak_width_limits: - - # Bandwidth limits are given in 2-sided peak bandwidth - # Convert to gaussian std parameter limits - self._gauss_std_limits = tuple(bwl / 2 for bwl in self.settings.peak_width_limits) - - # Otherwise, assume settings are unknown (have been cleared) and set to None - else: - self._gauss_std_limits = None - - def _get_ap_guess(self, freqs, power_spectrum): """Get the guess parameters for the aperiodic fit. @@ -428,11 +398,12 @@ def _fit_peaks(self, flatspec): guess_std = np.mean(self.settings.peak_width_limits) # Check that guess value isn't outside preset limits - restrict if so + # This also converts the peak_width_limits from 2-sided BW to 1-sided gaussian std # Note: without this, curve_fitting fails if given guess > or < bounds - if guess_std < self._gauss_std_limits[0]: - guess_std = self._gauss_std_limits[0] - if guess_std > self._gauss_std_limits[1]: - guess_std = self._gauss_std_limits[1] + if guess_std < self.settings.peak_width_limits[0] / 2: + guess_std = self.settings.peak_width_limits[0] / 2 + if guess_std > self.settings.peak_width_limits[1] / 2: + guess_std = self.settings.peak_width_limits[0] / 2 # Collect guess parameters and subtract this guess gaussian from the data current_guess_params = (guess_freq, guess_height, guess_std) @@ -465,15 +436,15 @@ def _get_pe_bounds(self, guess): """Get the bound for the peak fit.""" # Set the bounds for CF, enforce positive height value, and set bandwidth limits - # Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std + # Note that 'guess' is in terms of gaussian std, so peak_width_limit is % by 2 # This set of list comprehensions is a way to end up with bounds in the form: # ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*), # (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*)) # ^where each value sets the bound on the specified parameter - lo_bound = [[peak[0] - 2 * self._settings.cf_bound * peak[2], 0, self._gauss_std_limits[0]] - for peak in guess] - hi_bound = [[peak[0] + 2 * self._settings.cf_bound * peak[2], np.inf, self._gauss_std_limits[1]] - for peak in guess] + lo_bound = [[peak[0] - 2 * self._settings.cf_bound * peak[2], 0, \ + self.settings.peak_width_limits[0] / 2] for peak in guess] + hi_bound = [[peak[0] + 2 * self._settings.cf_bound * peak[2], np.inf, \ + self.settings.peak_width_limits[1] / 2] for peak in guess] # Check that CF bounds are within frequency range # If they are not, update them to be restricted to frequency range diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index c944bcd5..0985365d 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -314,10 +314,7 @@ def test_resets(): # Note: uses it's own tfm, to not clear the global one tfm = get_tfm() - tfm._reset_data_results(True, True, True) - tfm.algorithm._reset_internal_settings() - for field in tfm.data._fields: assert getattr(tfm.data, field) is None for key, value in tfm.results.model.__dict__.items(): diff --git a/specparam/tests/models/test_utils.py b/specparam/tests/models/test_utils.py index cc5b928a..8e2d9903 100644 --- a/specparam/tests/models/test_utils.py +++ b/specparam/tests/models/test_utils.py @@ -42,7 +42,6 @@ def test_compare_model_objs(tfm, tfg): assert compare_model_objs([f_obj, f_obj2], 'settings') f_obj2.algorithm.settings.peak_width_limits = [2, 4] - f_obj2.algorithm._reset_internal_settings() assert not compare_model_objs([f_obj, f_obj2], 'settings') assert compare_model_objs([f_obj, f_obj2], 'meta_data') @@ -138,8 +137,6 @@ def test_combine_errors(tfm, tfg): for f_obj in [tfm, tfg]: f_obj2 = f_obj.copy() f_obj2.algorithm.settings.peak_width_limits = [2, 4] - f_obj2.algorithm._reset_internal_settings() - with raises(IncompatibleSettingsError): combine_model_objs([f_obj, f_obj2]) From f7512626211e3fd69311ac97c4892a19cb897404 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 22:58:13 -0400 Subject: [PATCH 27/40] drop reset internal settings - no longer used --- specparam/algorithms/algorithm.py | 8 -------- specparam/algorithms/spectral_fit.py | 3 +-- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index b9dd4fd0..c38c0864 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -135,14 +135,6 @@ def _check_loaded_settings(self, data): for setting in self.settings.names: setattr(self.settings, setting, None) - # Reset internal settings so that they are consistent with what was loaded - # Note that this will set internal settings to None, if public settings unavailable - self._reset_internal_settings() - - - def _reset_internal_settings(self): - """"Can be overloaded if any resetting needed for internal settings.""" - def _reset_subobjects(self, modes=None, data=None, results=None): """Reset links to sub-objects (mode / data / results). diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 80b53a17..8aefdc91 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -224,8 +224,7 @@ def _get_ap_bounds(self, ap_bounds): if ap_bounds: msg = 'Provided aperiodic bounds do not have right length for fit function.' - assert len(ap_bounds[0]) == len(ap_bounds[1]) == \ - self.modes.aperiodic.n_params, msg + assert len(ap_bounds[0]) == len(ap_bounds[1]) == self.modes.aperiodic.n_params, msg else: ap_bounds = (tuple([-np.inf] * self.modes.aperiodic.n_params), tuple([np.inf] * self.modes.aperiodic.n_params)) From 7aaeb428b7d92b8618215a20eba1740c42be1297 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 23:11:36 -0400 Subject: [PATCH 28/40] add clear method to settings values --- specparam/algorithms/settings.py | 7 +++++++ specparam/tests/algorithms/test_settings.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/specparam/algorithms/settings.py b/specparam/algorithms/settings.py index a4f315a8..3e1cbbc2 100644 --- a/specparam/algorithms/settings.py +++ b/specparam/algorithms/settings.py @@ -65,6 +65,13 @@ def names(self): return list(self.values.keys()) + def clear(self): + """Clear all settings - resetting to None.""" + + for setting in self.names: + self.values[setting] = None + + class SettingsDefinition(): """Defines a set of algorithm settings. diff --git a/specparam/tests/algorithms/test_settings.py b/specparam/tests/algorithms/test_settings.py index 10b059c4..4e896fbb 100644 --- a/specparam/tests/algorithms/test_settings.py +++ b/specparam/tests/algorithms/test_settings.py @@ -18,6 +18,10 @@ def test_settings_values(): assert settings_vals.a == 1 assert settings_vals.b == 2 + settings_vals.clear() + assert settings_vals.a is None + assert settings_vals.b is None + def test_settings_definition(): tdefinitions = { From 944e81254ce2a02583d6ab26258018bce3df80b2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 14 Apr 2025 23:29:47 -0400 Subject: [PATCH 29/40] drop _check_loaded_settings --- specparam/algorithms/algorithm.py | 22 +--------------------- specparam/models/group.py | 6 +++--- specparam/models/model.py | 5 ++++- 3 files changed, 8 insertions(+), 25 deletions(-) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index c38c0864..5401faa8 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -78,14 +78,12 @@ def add_settings(self, settings): Parameters ---------- settings : ModelSettings - A data object containing the settings for a power spectrum model. + A data object containing model settings. """ for setting in settings._fields: setattr(self.settings, setting, getattr(settings, setting)) - self._check_loaded_settings(settings._asdict()) - def get_settings(self): """Return user defined settings of the current object. @@ -118,24 +116,6 @@ def set_debug(self, debug): self._debug = debug - def _check_loaded_settings(self, data): - """Check if settings added, and update the object as needed. - - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ - - # If settings not loaded from file, clear from object, so that default - # settings, which are potentially wrong for loaded data, aren't kept - if not set(self.settings.names).issubset(set(data.keys())): - - # Reset all public settings to None - for setting in self.settings.names: - setattr(self.settings, setting, None) - - def _reset_subobjects(self, modes=None, data=None, results=None): """Reset links to sub-objects (mode / data / results). diff --git a/specparam/models/group.py b/specparam/models/group.py index 56ff1fbb..ec4bd4d5 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -221,9 +221,9 @@ def load(self, file_name, file_path=None): self._add_from_dict(data) - # If settings are loaded, check and update based on the first line - if ind == 0: - self.algorithm._check_loaded_settings(data) + # For hearder line, check if settings are loaded and clear defaults if not + if ind == 0 and not set(self.algorithm.settings.names).issubset(set(data.keys())): + self.algorithm.settings.clear() # If results part of current data added, check and update object results if set([el + '_params' for el in self.results.params.fields]).issubset(set(data.keys())): diff --git a/specparam/models/model.py b/specparam/models/model.py index 4517c676..26913cda 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -268,7 +268,10 @@ def load(self, file_name, file_path=None, regenerate=True): # Add loaded data to object and check loaded data self._add_from_dict(data) - self.algorithm._check_loaded_settings(data) + + # If settings are not loaded, clear defaults to not have potentially incorrect values + if not set(self.algorithm.settings.names).issubset(set(data.keys())): + self.algorithm.settings.clear() # Regenerate model components, based on what is available if regenerate: From 456487204d26e8a82c3a1f3c14a2c39ca63c23f0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 15 Apr 2025 00:09:15 -0400 Subject: [PATCH 30/40] update algo to use cf settings --- specparam/algorithms/algorithm.py | 15 ++++++++++++++ specparam/algorithms/spectral_fit.py | 30 +++++++++++++--------------- specparam/tests/models/test_group.py | 2 +- specparam/tests/models/test_model.py | 4 ++-- 4 files changed, 32 insertions(+), 19 deletions(-) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 5401faa8..65457006 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -8,6 +8,18 @@ DATA_FORMATS = ['spectrum', 'spectra', 'spectrogram', 'spectrograms'] +CURVE_FIT_SETTINGS = SettingsDefinition({ + 'maxfev' : { + 'type' : 'int', + 'description' : 'The maximum number of calls to the curve fitting function.', + }, + 'tol' : { + 'type' : 'float', + 'description' : \ + 'The tolerance setting for curve fitting (see scipy.curve_fit: ftol / xtol / gtol).' + }, +}) + class Algorithm(): """Template object for defining a fit algorithm. @@ -53,6 +65,9 @@ def __init__(self, name, description, public_settings, private_settings=None, self.private_settings = private_settings self._settings = SettingsValues(self.private_settings.names) + self._cf_settings_desc = CURVE_FIT_SETTINGS + self._cf_settings = SettingsValues(self._cf_settings_desc.names) + check_input_options(data_format, DATA_FORMATS, 'data_format') self.data_format = data_format diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 8aefdc91..fa73b3c6 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -86,13 +86,6 @@ class SpectralFitAlgorithm(Algorithm): Parameters ---------- % public settings described in Spectral Fit Algorithm Settings - - _maxfev : int - The maximum number of calls to the curve fitting function. - _tol : float - The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol). - The default value reduce tolerance to speed fitting (as compared to curve_fit's default). - Set value to 1e-8 to match curve_fit default. """ # pylint: disable=attribute-defined-outside-init @@ -124,9 +117,11 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h self._settings.bw_std_edge = bw_std_edge self._settings.gauss_overlap_thresh = gauss_overlap_thresh - ## Private setting: curve_fit related settings - self._maxfev = maxfev - self._tol = tol + ## curve_fit settings + # Note - default reduces tolerance to speed fitting (as compared to curve_fit's default). + # Set value to 1e-8 to match curve_fit default. + self._cf_settings.maxfev = maxfev + self._cf_settings.tol = tol def _fit_prechecks(self, verbose=True): @@ -260,8 +255,9 @@ def _simple_ap_fit(self, freqs, power_spectrum): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs, power_spectrum, p0=ap_guess, bounds=self._settings.ap_bounds, - maxfev=self._maxfev, check_finite=False, - ftol=self._tol, xtol=self._tol, gtol=self._tol) + maxfev=self._cf_settings.maxfev, check_finite=False, + ftol=self._cf_settings.tol, xtol=self._cf_settings.tol, + gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding parameters in " "the simple aperiodic component fit.") @@ -315,8 +311,9 @@ def _robust_ap_fit(self, freqs, power_spectrum): aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs_ignore, spectrum_ignore, p0=popt, bounds=self._settings.ap_bounds, - maxfev=self._maxfev, check_finite=False, - ftol=self._tol, xtol=self._tol, gtol=self._tol) + maxfev=self._cf_settings.maxfev, check_finite=False, + ftol=self._cf_settings.tol, xtol=self._cf_settings.tol, + gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " "parameters in the robust aperiodic fit.") @@ -483,8 +480,9 @@ def _fit_peak_guess(self, flatspec, guess): p0=np.ndarray.flatten(guess), bounds=self._get_pe_bounds(guess), jac=self.modes.periodic.jacobian, - maxfev=self._maxfev, check_finite=False, - ftol=self._tol, xtol=self._tol, gtol=self._tol) + maxfev=self._cf_settings.maxfev, check_finite=False, + ftol=self._cf_settings.tol, xtol=self._cf_settings.tol, + gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index 31e9e0b4..5665cbbe 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -153,7 +153,7 @@ def test_fg_fail(): # Use a fg with the max iterations set so low that it will fail to converge ntfg = SpectralGroupModel() - ntfg.algorithm._maxfev = 5 + ntfg.algorithm._cf_settings.maxfev = 5 # Fit models, where some will fail, to see if it completes cleanly ntfg.fit(fs, ps) diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 0985365d..05dea871 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -336,7 +336,7 @@ def test_fit_failure(): ## Induce a runtime error, and check it runs through tfm = SpectralModel(verbose=False) - tfm.algorithm._maxfev = 2 + tfm.algorithm._cf_settings.maxfev = 2 tfm.fit(*sim_power_spectrum(*default_spectrum_params())) @@ -362,7 +362,7 @@ def test_debug(): """Test model object in debug state, including with fit failures.""" tfm = SpectralModel(verbose=False) - tfm.algorithm._maxfev = 2 + tfm.algorithm._cf_settings.maxfev = 2 tfm.algorithm.set_debug(True) assert tfm.algorithm._debug is True From 104128c7c574c87d5d283cce125c391d68676a27 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 15 Apr 2025 12:06:55 -0400 Subject: [PATCH 31/40] lints --- specparam/algorithms/settings.py | 2 +- specparam/data/conversions.py | 2 +- specparam/models/group.py | 5 +++-- specparam/modes/modes.py | 5 ++--- specparam/objs/components.py | 1 + specparam/objs/data.py | 7 ++++--- specparam/objs/metrics.py | 24 ++++++++++++------------ specparam/objs/results.py | 3 ++- specparam/plts/event.py | 2 +- specparam/plts/time.py | 4 ++-- specparam/plts/utils.py | 2 +- specparam/sim/params.py | 2 -- specparam/tests/objs/test_metrics.py | 8 ++++---- 13 files changed, 34 insertions(+), 33 deletions(-) diff --git a/specparam/algorithms/settings.py b/specparam/algorithms/settings.py index 3e1cbbc2..14743188 100644 --- a/specparam/algorithms/settings.py +++ b/specparam/algorithms/settings.py @@ -19,7 +19,7 @@ class SettingsValues(): Settings values. """ - __slots__ = 'values' + __slots__ = ('values',) def __init__(self, names): """Initialize settings values.""" diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index 98258cbf..39f6a385 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -2,7 +2,7 @@ import numpy as np -from specparam.bands.bands import Bands, check_bands +from specparam.bands.bands import check_bands from specparam.modutils.dependencies import safe_import, check_dependency from specparam.data.periodic import get_band_peak_arr from specparam.data.utils import flatten_results_dict diff --git a/specparam/models/group.py b/specparam/models/group.py index ec4bd4d5..9966c528 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -219,14 +219,15 @@ def load(self, file_name, file_path=None): if 'power_spectrum' in data.keys(): power_spectra.append(data.pop('power_spectrum')) + data_keys = set(data.keys()) self._add_from_dict(data) # For hearder line, check if settings are loaded and clear defaults if not - if ind == 0 and not set(self.algorithm.settings.names).issubset(set(data.keys())): + if ind == 0 and not set(self.algorithm.settings.names).issubset(data_keys): self.algorithm.settings.clear() # If results part of current data added, check and update object results - if set([el + '_params' for el in self.results.params.fields]).issubset(set(data.keys())): + if set([el + '_params' for el in self.results.params.fields]).issubset(data_keys): self.results.group_results.append(self.results._get_results()) # Reconstruct frequency vector, if information is available to do so diff --git a/specparam/modes/modes.py b/specparam/modes/modes.py index f2752ed8..43b03a0b 100644 --- a/specparam/modes/modes.py +++ b/specparam/modes/modes.py @@ -56,9 +56,8 @@ def check_mode_definition(mode, options): if isinstance(mode, str): assert mode in list(options.keys()), 'Specific Mode not found.' mode = options[mode] - elif isinstance(mode, Mode): - mode = mode - else: + + if not isinstance(mode, Mode): raise ValueError('Mode input not understood.') return mode diff --git a/specparam/objs/components.py b/specparam/objs/components.py index dca5cc2e..25b6ee58 100644 --- a/specparam/objs/components.py +++ b/specparam/objs/components.py @@ -1,6 +1,7 @@ """Define model components object.""" from specparam.utils.array import unlog +from specparam.modutils.errors import NoModelError ################################################################################################### ################################################################################################### diff --git a/specparam/objs/data.py b/specparam/objs/data.py index ae96dbb7..cd3546b4 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -1,5 +1,6 @@ """Define data objects.""" +from warnings import warn from functools import wraps import numpy as np @@ -262,10 +263,10 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): # Check if freqs start at 0 and move up one value if so # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error if freqs[0] == 0.0: + msg = "specparam fit warning - skipping frequency == 0, " \ + "as this causes a problem with fitting." + warn(msg, category=RuntimeWarning) freqs, powers = trim_spectrum(freqs, powers, [freqs[1], freqs.max()]) - if self.verbose: - print("\nFITTING WARNING: Skipping frequency == 0, " - "as this causes a problem with fitting.") # Calculate frequency resolution, and actual frequency range of the data freq_range = [freqs.min(), freqs.max()] diff --git a/specparam/objs/metrics.py b/specparam/objs/metrics.py index f2cc6295..d3068db3 100644 --- a/specparam/objs/metrics.py +++ b/specparam/objs/metrics.py @@ -12,8 +12,8 @@ class Metric(): Parameters ---------- - type : str - The type of measure, e.g. 'error' or 'gof'. + category : str + The category of measure, e.g. 'error' or 'gof'. measure : str The specific measure, e.g. 'r_squared'. func : callable @@ -25,10 +25,10 @@ class Metric(): and returns the desired parameter / computed value. """ - def __init__(self, type, measure, func, kwargs=None): + def __init__(self, category, measure, func, kwargs=None): """Initialize metric.""" - self.type = type + self.category = category self.measure = measure self.func = func self.result = np.nan @@ -45,17 +45,17 @@ def __repr__(self): def label(self): """Define label property.""" - return self.type + '_' + self.measure + return self.category + '_' + self.measure @property def flabel(self): """Define formatted label property.""" - if self.type == 'error': - flabel = '{} ({})'.format(self.type.capitalize(), self.measure.upper()) - if self.type == 'gof': - flabel = '{} ({})'.format(self.type.upper(), self.measure) + if self.category == 'error': + flabel = '{} ({})'.format(self.category.capitalize(), self.measure.upper()) + if self.category == 'gof': + flabel = '{} ({})'.format(self.category.upper(), self.measure) return flabel @@ -162,10 +162,10 @@ def compute_metrics(self, data, results): @property - def types(self): - """Define alias for metric type of all currently defined metrics.""" + def categories(self): + """Define alias for metric categories of all currently defined metrics.""" - return [metric.type for metric in self.metrics] + return [metric.category for metric in self.metrics] @property diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 9d75a215..cdcecf74 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -621,7 +621,8 @@ def get_params(self, name, field=None): column is appended to the returned array, indicating the index that the peak came from. """ - return [get_group_params(gres, self.modes, name, field) for gres in self.event_group_results] + return [get_group_params(gres, self.modes, name, field) \ + for gres in self.event_group_results] def convert_results(self): diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 97e2df39..3b99df68 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -86,5 +86,5 @@ def plot_event_model(event, **plot_kwargs): title='Fit Quality' if ind == 0 else None, drop_xticks=ind < len(event.results.metrics), add_xlabel=ind == len(event.results.metrics), - color=PARAM_COLORS[event.results.metrics.types[ind]], + color=PARAM_COLORS[event.results.metrics.categories[ind]], xlim=xlim, ax=next(axes)) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index eece420f..6fa435d0 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -74,6 +74,6 @@ def plot_time_model(time, **plot_kwargs): time.results.time_results[time.results.metrics.labels[gof_ind]]], labels=[time.results.metrics.flabels[err_ind], time.results.metrics.flabels[gof_ind]], - colors=[PARAM_COLORS[time.results.metrics.types[err_ind]], - PARAM_COLORS[time.results.metrics.types[gof_ind]]], + colors=[PARAM_COLORS[time.results.metrics.categories[err_ind]], + PARAM_COLORS[time.results.metrics.categories[gof_ind]]], xlim=xlim, title='Fit Quality', ax=next(axes)) diff --git a/specparam/plts/utils.py b/specparam/plts/utils.py index 83d2f018..29078d09 100644 --- a/specparam/plts/utils.py +++ b/specparam/plts/utils.py @@ -93,7 +93,7 @@ def add_shades(ax, shades, colors='r', shade_alpha=0.2, shades = [shades] colors = repeat(colors) if not isinstance(colors, list) else colors - shade_alphas = repeat(shade_alpha) if not isinstance(shade_alpha, list) else alpha + shade_alphas = repeat(shade_alpha) if not isinstance(shade_alpha, list) else shade_alpha for shade, color, alpha in zip(shades, colors, shade_alphas): diff --git a/specparam/sim/params.py b/specparam/sim/params.py index 04223559..d5720d90 100644 --- a/specparam/sim/params.py +++ b/specparam/sim/params.py @@ -7,7 +7,6 @@ from specparam.data import SimParams from specparam.modes.modes import check_mode_definition from specparam.modes.definitions import AP_MODES -from specparam.utils.select import groupby from specparam.utils.checks import check_flat from specparam.modutils.errors import InconsistentDataError @@ -33,7 +32,6 @@ def collect_sim_params(aperiodic_params, periodic_params, nlv): """ return SimParams(deepcopy(aperiodic_params), - #sorted(groupby(check_flat(periodic_params), 3)), deepcopy(periodic_params), nlv) diff --git a/specparam/tests/objs/test_metrics.py b/specparam/tests/objs/test_metrics.py index 07028b9f..f06e0a95 100644 --- a/specparam/tests/objs/test_metrics.py +++ b/specparam/tests/objs/test_metrics.py @@ -55,8 +55,8 @@ def test_metrics_obj(tfm): def test_metrics_dict(tfm): - er_met_def = {'type' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} - gof_met_def = {'type' : 'gof', 'measure' : 'rsquared', 'func' : compute_r_squared} + er_met_def = {'category' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} + gof_met_def = {'category' : 'gof', 'measure' : 'rsquared', 'func' : compute_r_squared} metrics = Metrics([er_met_def, gof_met_def]) assert isinstance(metrics, Metrics) @@ -73,8 +73,8 @@ def test_metrics_dict(tfm): def test_metrics_kwargs(tfm): - er_met_def = {'type' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} - ar2_met_def = {'type' : 'gof', 'measure' : 'arsquared', + er_met_def = {'category' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} + ar2_met_def = {'category' : 'gof', 'measure' : 'arsquared', 'func' : compute_adj_r_squared, 'kwargs' : {'n_params' : lambda data, results: \ results.params.peak.size + results.params.aperiodic.size}} From 9b51f6dc150ea4ef89e0b14c045a1b74b33dcc3f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 15 Apr 2025 12:15:16 -0400 Subject: [PATCH 32/40] udpate line spacing and settings docs --- specparam/algorithms/spectral_fit.py | 79 ++++++++++++++++++---------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index fa73b3c6..a546d610 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -28,13 +28,15 @@ 'min_peak_height' : { 'type' : 'float, optional, default: 0', 'description' : \ - 'Absolute threshold for detecting peaks.\n ' \ + 'Absolute threshold for detecting peaks.' + '\n ' 'This threshold is defined in absolute units of the power spectrum (log power).', }, 'peak_threshold' : { 'type' : 'float, optional, default: 2.0', 'description' : \ - 'Relative threshold for detecting peaks.\n ' \ + 'Relative threshold for detecting peaks.' + '\n ' 'Threshold is defined in relative units of the power spectrum (standard deviation).', }, }) @@ -44,38 +46,47 @@ 'ap_percentile_thresh' : { 'type' : 'float', 'description' : \ - 'Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit.\n ' - 'Points are selected at a low percentile value to restrict to non-peak points.' + 'Percentile threshold to select data from flat spectrum for an initial aperiodic fit.' + '\n ' + 'Points are selected at a low percentile value to restrict to non-peak points.', }, 'ap_guess' : { 'type' : 'list of float', 'description' : \ - 'Guess parameters for fitting the aperiodic component.\n ' - 'The length of the guess parameters should match the number & order of the aperiodic parameters.\n ' - 'If \'offset\' is a parameter & guess is None, the first value of the power spectrum is used as the guess.\n ' - 'If \'exponent\' is a parameter & guess is None, the abs(log-log slope) of first & last points is used.' + 'Guess parameters for fitting the aperiodic component.' + '\n ' + 'The guess parameters should match the length and order of the aperiodic parameters.' + '\n ' + 'If \'offset\' is a parameter, default guess is the first value of the power spectrum.' + '\n ' + 'If \'exponent\' is a parameter, ' + 'default guess is the abs(log-log slope) of first & last points.' }, 'ap_bounds' : { 'type' : 'tuple of tuple of float', 'description' : \ - 'Bounds for aperiodic fitting, as ((param1_low_bound, ...) (param1_high_bound, ...)).\n ' - 'By default, aperiodic fitting is unbound, but can be restricted here.' + 'Bounds for aperiodic fitting, as ((param1_low_bound, ...) (param1_high_bound, ...)).' + '\n ' + 'By default, aperiodic fitting is unbound, but can be restricted here.', }, 'cf_bound' : { 'type' : 'float', - 'description' : 'Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev.' + 'description' : \ + 'Parameter bounds for center frequency when fitting gaussians, as +/- std dev.', }, 'bw_std_edge' : { 'type' : 'float', 'description' : \ - 'Threshold for how far a peak has to be from edge to keep.\n ' - 'This is defined in units of gaussian standard deviation.' + 'Threshold for how far a peak has to be from edge to keep.' + '\n ' + 'This is defined in units of gaussian standard deviation.', }, 'gauss_overlap_thresh' : { 'type' : 'float', 'description' : \ - 'Degree of overlap between gaussian guesses for one to be dropped.\n ' - 'This is defined in units of gaussian standard deviation.' + 'Degree of overlap between gaussian guesses for one to be dropped.' + '\n ' + 'This is defined in units of gaussian standard deviation.', }, }) @@ -157,7 +168,8 @@ def _fit(self): self.data.freqs, *np.ndarray.flatten(self.results.params.gaussian)) # Create peak-removed (but not flattened) power spectrum - self.results.model._spectrum_peak_rm = self.data.power_spectrum - self.results.model._peak_fit + self.results.model._spectrum_peak_rm = \ + self.data.power_spectrum - self.results.model._peak_fit # Run final aperiodic fit on peak-removed power spectrum self.results.params.aperiodic = self._simple_ap_fit(\ @@ -167,7 +179,8 @@ def _fit(self): # Create remaining model components: flatspec & full power_spectrum model fit self.results.model._spectrum_flat = self.data.power_spectrum - self.results.model._ap_fit - self.results.model.modeled_spectrum = self.results.model._peak_fit + self.results.model._ap_fit + self.results.model.modeled_spectrum = \ + self.results.model._peak_fit + self.results.model._ap_fit ## PARAMETER UPDATES @@ -255,8 +268,10 @@ def _simple_ap_fit(self, freqs, power_spectrum): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs, power_spectrum, p0=ap_guess, bounds=self._settings.ap_bounds, - maxfev=self._cf_settings.maxfev, check_finite=False, - ftol=self._cf_settings.tol, xtol=self._cf_settings.tol, + maxfev=self._cf_settings.maxfev, + check_finite=False, + ftol=self._cf_settings.tol, + xtol=self._cf_settings.tol, gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding parameters in " @@ -311,8 +326,10 @@ def _robust_ap_fit(self, freqs, power_spectrum): aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs_ignore, spectrum_ignore, p0=popt, bounds=self._settings.ap_bounds, - maxfev=self._cf_settings.maxfev, check_finite=False, - ftol=self._cf_settings.tol, xtol=self._cf_settings.tol, + maxfev=self._cf_settings.maxfev, + check_finite=False, + ftol=self._cf_settings.tol, + xtol=self._cf_settings.tol, gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " @@ -480,8 +497,10 @@ def _fit_peak_guess(self, flatspec, guess): p0=np.ndarray.flatten(guess), bounds=self._get_pe_bounds(guess), jac=self.modes.periodic.jacobian, - maxfev=self._cf_settings.maxfev, check_finite=False, - ftol=self._cf_settings.tol, xtol=self._cf_settings.tol, + maxfev=self._cf_settings.maxfev, + check_finite=False, + ftol=self._cf_settings.tol, + xtol=self._cf_settings.tol, gtol=self._cf_settings.tol) except RuntimeError as excp: @@ -614,14 +633,16 @@ def _create_peak_params(self, gaus_params): # Collect peak parameter data if self.modes.periodic.name == 'gaussian': ## TEMP - peak_params[ii] = [peak[0], - self.results.model.modeled_spectrum[ind] - self.results.model._ap_fit[ind], - peak[2] * 2] + peak_params[ii] = [\ + peak[0], + self.results.model.modeled_spectrum[ind] - self.results.model._ap_fit[ind], + peak[2] * 2] ## TEMP: if self.modes.periodic.name == 'skewnorm': - peak_params[ii] = [peak[0], - self.results.model.modeled_spectrum[ind] - self.results.model._ap_fit[ind], - peak[2] * 2, peak[3]] + peak_params[ii] = [\ + peak[0], + self.results.model.modeled_spectrum[ind] - self.results.model._ap_fit[ind], + peak[2] * 2, peak[3]] return peak_params From edce4ec293af34788cd534fd3f35e8e64ba8fed0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 16 Apr 2025 20:40:49 -0400 Subject: [PATCH 33/40] add AlgorithmCF object --- specparam/algorithms/algorithm.py | 53 ++++++++++++++------ specparam/algorithms/spectral_fit.py | 4 +- specparam/tests/algorithms/test_algorithm.py | 18 +++++++ 3 files changed, 57 insertions(+), 18 deletions(-) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 65457006..5281f505 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -2,25 +2,13 @@ from specparam.utils.checks import check_input_options from specparam.algorithms.settings import SettingsDefinition, SettingsValues +from specparam.modutils.docs import docs_get_section, replace_docstring_sections ################################################################################################### ################################################################################################### DATA_FORMATS = ['spectrum', 'spectra', 'spectrogram', 'spectrograms'] -CURVE_FIT_SETTINGS = SettingsDefinition({ - 'maxfev' : { - 'type' : 'int', - 'description' : 'The maximum number of calls to the curve fitting function.', - }, - 'tol' : { - 'type' : 'float', - 'description' : \ - 'The tolerance setting for curve fitting (see scipy.curve_fit: ftol / xtol / gtol).' - }, -}) - - class Algorithm(): """Template object for defining a fit algorithm. @@ -65,9 +53,6 @@ def __init__(self, name, description, public_settings, private_settings=None, self.private_settings = private_settings self._settings = SettingsValues(self.private_settings.names) - self._cf_settings_desc = CURVE_FIT_SETTINGS - self._cf_settings = SettingsValues(self._cf_settings_desc.names) - check_input_options(data_format, DATA_FORMATS, 'data_format') self.data_format = data_format @@ -150,3 +135,39 @@ def _reset_subobjects(self, modes=None, data=None, results=None): self.data = data if results is not None: self.results = results + + +## AlgorithmCF + +CURVE_FIT_SETTINGS = SettingsDefinition({ + 'maxfev' : { + 'type' : 'int', + 'description' : 'The maximum number of calls to the curve fitting function.', + }, + 'tol' : { + 'type' : 'float', + 'description' : \ + 'The tolerance setting for curve fitting (see scipy.curve_fit: ftol / xtol / gtol).' + }, +}) + +@replace_docstring_sections([docs_get_section(Algorithm.__doc__, 'Parameters')]) +class AlgorithmCF(Algorithm): + """Template object for defining a fit algorithm that uses `curve_fit`. + + Parameters + ---------- + % copied in from Algorithm + """ + + def __init__(self, name, description, public_settings, private_settings=None, + data_format='spectrum', modes=None, data=None, results=None, debug=False): + """Initialize Algorithm object.""" + + Algorithm.__init__(self, name=name, description=description, + public_settings=public_settings, private_settings=private_settings, + data_format=data_format, modes=modes, data=data, results=results, + debug=debug) + + self._cf_settings_desc = CURVE_FIT_SETTINGS + self._cf_settings = SettingsValues(self._cf_settings_desc.names) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index a546d610..990191bd 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -10,7 +10,7 @@ from specparam.utils.select import groupby from specparam.reports.strings import gen_width_warning_str from specparam.measures.params import compute_gauss_std -from specparam.algorithms.algorithm import Algorithm +from specparam.algorithms.algorithm import AlgorithmCF from specparam.algorithms.settings import SettingsDefinition ################################################################################################### @@ -91,7 +91,7 @@ }) -class SpectralFitAlgorithm(Algorithm): +class SpectralFitAlgorithm(AlgorithmCF): """Base object defining model & algorithm for spectral parameterization. Parameters diff --git a/specparam/tests/algorithms/test_algorithm.py b/specparam/tests/algorithms/test_algorithm.py index 8b13b7a3..48f7d661 100644 --- a/specparam/tests/algorithms/test_algorithm.py +++ b/specparam/tests/algorithms/test_algorithm.py @@ -22,6 +22,8 @@ def test_algorithm(): assert algo.description == tdescription assert isinstance(algo.public_settings, SettingsDefinition) assert algo.public_settings == tsettings + for setting in algo.public_settings.names: + assert getattr(algo.settings, setting) is None def test_algorithm_settings(): @@ -43,3 +45,19 @@ def test_algorithm_settings(): settings_out = talgo.get_settings() assert isinstance(settings, model_settings) assert settings_out == settings + +def test_algorithm_cf(): + + tname = 'test_algo' + tdescription = 'Test algorithm description' + tsettings = SettingsDefinition({ + 'a' : {'type' : 'a type desc', 'description' : 'a desc'}, + 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, + }) + + algo = AlgorithmCF(name=tname, description=tdescription, public_settings=tsettings) + + assert isinstance(algo._cf_settings_desc, SettingsDefinition) + assert algo._cf_settings + for setting in algo._cf_settings.names: + assert getattr(algo._cf_settings, setting) is None From e743691e52a2834147ef1d264194c82602c4a849 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 16 Apr 2025 20:55:22 -0400 Subject: [PATCH 34/40] add algo_cf initialize methods --- specparam/algorithms/algorithm.py | 40 ++++++++++++++++++++ specparam/tests/algorithms/test_algorithm.py | 17 +++++++++ 2 files changed, 57 insertions(+) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 5281f505..7e04fa2c 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -1,5 +1,7 @@ """Define object to manage algorithm implementations.""" +import numpy as np + from specparam.utils.checks import check_input_options from specparam.algorithms.settings import SettingsDefinition, SettingsValues from specparam.modutils.docs import docs_get_section, replace_docstring_sections @@ -171,3 +173,41 @@ def __init__(self, name, description, public_settings, private_settings=None, self._cf_settings_desc = CURVE_FIT_SETTINGS self._cf_settings = SettingsValues(self._cf_settings_desc.names) + + + def _initialize_bounds(self, mode): + """Initialize a bounds definition. + + Parameters + ---------- + mode : {'aperiodic', 'periodic'} + Which mode to initialize for. + + Returns + ------- + bounds : tuple of tuple + Guess values. + """ + + n_params = getattr(self.modes, mode).n_params + bounds = (tuple([-np.inf] * n_params), tuple([np.inf] * n_params)) + + return bounds + + def _initialize_guess(self, mode): + """Initialize a guess definition. + + Parameters + ---------- + mode : {'aperiodic', 'periodic'} + Which mode to initialize for. + + Returns + ------- + guess : 1d array + Guess values. + """ + + guess = np.zeros([getattr(self.modes, mode).n_params]) + + return guess diff --git a/specparam/tests/algorithms/test_algorithm.py b/specparam/tests/algorithms/test_algorithm.py index 48f7d661..f46f4cb3 100644 --- a/specparam/tests/algorithms/test_algorithm.py +++ b/specparam/tests/algorithms/test_algorithm.py @@ -1,5 +1,6 @@ """Tests for specparam.algorthms.algorithm.""" +from specparam.modes.modes import Modes from specparam.algorithms.settings import SettingsDefinition from specparam.algorithms.algorithm import * @@ -61,3 +62,19 @@ def test_algorithm_cf(): assert algo._cf_settings for setting in algo._cf_settings.names: assert getattr(algo._cf_settings, setting) is None + +def test_algorithm_cf_initialize(): + + algo = AlgorithmCF(name='test_algo', description='desc', + public_settings={'a' : {'type' : 'a type desc', 'description' : 'a desc'}}, + modes=Modes('fixed', 'gaussian')) + + ap_bounds = algo._initialize_bounds('aperiodic') + assert len(ap_bounds[0]) == algo.modes.aperiodic.n_params + pe_bounds = algo._initialize_bounds('periodic') + assert len(pe_bounds[0]) == algo.modes.periodic.n_params + + ap_guess = algo._initialize_guess('aperiodic') + assert len(ap_guess) == algo.modes.aperiodic.n_params + pe_guess = algo._initialize_guess('periodic') + assert len(pe_guess) == algo.modes.periodic.n_params From 33d0eb01dfdb6533e49f52f675cc7c9942525f9b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 16 Apr 2025 21:18:02 -0400 Subject: [PATCH 35/40] use algocf initialization for ap --- specparam/algorithms/algorithm.py | 2 +- specparam/algorithms/spectral_fit.py | 32 +++++++++++++++------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 7e04fa2c..558a5c2e 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -186,7 +186,7 @@ def _initialize_bounds(self, mode): Returns ------- bounds : tuple of tuple - Guess values. + Bounds values. """ n_params = getattr(self.modes, mode).n_params diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 990191bd..e3ec15a4 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -200,23 +200,16 @@ def _get_ap_guess(self, freqs, power_spectrum): if not self._settings.ap_guess: - ap_guess = [] - for label in self.modes.aperiodic.params.labels: + ap_guess = self._initialize_guess('aperiodic') + + for label, ind in self.modes.aperiodic.params.indices.items(): if label == 'offset': # Offset guess is the power value for lowest available frequency - ap_guess.append(power_spectrum[0]) + ap_guess[ind] = power_spectrum[0] elif 'exponent' in label: # Exponent guess is a quick calculation of the log-log slope - ap_guess.append(np.abs((power_spectrum[-1] - power_spectrum[0]) / - (np.log10(freqs[-1]) - np.log10(freqs[0])))) - elif 'knee' in label: - # Knee guess set to zero (no real guess) - ap_guess.append(0) - else: - # Any other (un-anticipated) parameter set to guess of 0 - ap_guess.append(0) - - ap_guess = np.array(ap_guess) + ap_guess[ind] = np.abs((power_spectrum[-1] - power_spectrum[0]) / + (np.log10(freqs[-1]) - np.log10(freqs[0]))) return ap_guess @@ -224,6 +217,16 @@ def _get_ap_guess(self, freqs, power_spectrum): def _get_ap_bounds(self, ap_bounds): """Set the default bounds for the aperiodic fit. + Parameters + ---------- + bounds : tuple of tuple or None + Bounds definition. If None, creates default bounds. + + Returns + ------- + bounds : tuple of tuple + Bounds definition. + Notes ----- The bounds for aperiodic parameters are set in general, and currently do not update @@ -234,8 +237,7 @@ def _get_ap_bounds(self, ap_bounds): msg = 'Provided aperiodic bounds do not have right length for fit function.' assert len(ap_bounds[0]) == len(ap_bounds[1]) == self.modes.aperiodic.n_params, msg else: - ap_bounds = (tuple([-np.inf] * self.modes.aperiodic.n_params), - tuple([np.inf] * self.modes.aperiodic.n_params)) + ap_bounds = self._initialize_bounds('aperiodic') return ap_bounds From 6df8eb1adae8e33f1d35de43b23c8c1cab506c0c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 17 Apr 2025 01:05:42 -0400 Subject: [PATCH 36/40] update _get_pe_bounds --- specparam/algorithms/algorithm.py | 10 +++- specparam/algorithms/spectral_fit.py | 77 ++++++++++++++++++---------- 2 files changed, 58 insertions(+), 29 deletions(-) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 558a5c2e..9f46e23a 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -185,12 +185,18 @@ def _initialize_bounds(self, mode): Returns ------- - bounds : tuple of tuple + bounds : tuple of array Bounds values. + + Notes + ----- + Output follows the needed bounds definition for curve_fit, which is: + ([low_bound_param1, low_bound_param2], + [high_bound_param1, high_bound_param2]) """ n_params = getattr(self.modes, mode).n_params - bounds = (tuple([-np.inf] * n_params), tuple([np.inf] * n_params)) + bounds = (np.array([-np.inf] * n_params), np.array([np.inf] * n_params)) return bounds diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index e3ec15a4..ff85c29d 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -1,6 +1,7 @@ """Define original spectral fitting algorithm object.""" import warnings +from itertools import repeat import numpy as np from numpy.linalg import LinAlgError @@ -446,34 +447,56 @@ def _fit_peaks(self, flatspec): return gaussian_params - ## TO GENERALIZE FOR MODES def _get_pe_bounds(self, guess): - """Get the bound for the peak fit.""" - - # Set the bounds for CF, enforce positive height value, and set bandwidth limits - # Note that 'guess' is in terms of gaussian std, so peak_width_limit is % by 2 - # This set of list comprehensions is a way to end up with bounds in the form: - # ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*), - # (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*)) - # ^where each value sets the bound on the specified parameter - lo_bound = [[peak[0] - 2 * self._settings.cf_bound * peak[2], 0, \ - self.settings.peak_width_limits[0] / 2] for peak in guess] - hi_bound = [[peak[0] + 2 * self._settings.cf_bound * peak[2], np.inf, \ - self.settings.peak_width_limits[1] / 2] for peak in guess] - - # Check that CF bounds are within frequency range - # If they are not, update them to be restricted to frequency range - lo_bound = [bound if bound[0] > self.data.freq_range[0] else \ - [self.data.freq_range[0], *bound[1:]] for bound in lo_bound] - hi_bound = [bound if bound[0] < self.data.freq_range[1] else \ - [self.data.freq_range[1], *bound[1:]] for bound in hi_bound] - - # Unpacks the embedded lists into flat tuples - # This is what the fit function requires as input - gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist), - tuple(item for sublist in hi_bound for item in sublist)) - - return gaus_param_bounds + """Get the bound for the peak fit. + + Parameters + ---------- + guess : list + Guess parameters from initial peak search. + + Returns + ------- + pe_bounds : tuple of array + Bounds for periodic fit. + """ + + n_pe_params = self.modes.periodic.n_params + bounds = repeat(self._initialize_bounds('periodic')) + bounds_lo = np.empty(len(guess) * n_pe_params) + bounds_hi = np.empty(len(guess) * n_pe_params) + + for p_ind, peak in enumerate(guess): + for label, ind in self.modes.periodic.params.indices.items(): + + pbounds_lo, pbounds_hi = next(bounds) + + if label == 'cf': + # Set boundaries on CF, weighted by the bandwidth + peak_bw = peak[self.modes.periodic.params.indices['bw']] + lcf = peak[ind] - 2 * self._settings.cf_bound * peak_bw + hcf = peak[ind] + 2 * self._settings.cf_bound * peak_bw + # Check that CF bounds are within frequency range - if not restrict to range + pbounds_lo[ind] = lcf if lcf > self.data.freq_range[0] \ + else self.data.freq_range[0] + pbounds_hi[ind] = hcf if hcf < self.data.freq_range[1] \ + else self.data.freq_range[1] + + if label == 'pw': + # Enforce positive values for height + pbounds_lo[ind] = 0 + + if label == 'bw': + # Set bandwidth limits, converting limits from Hz to guess params in std + pbounds_lo[ind] = self.settings.peak_width_limits[0] / 2 + pbounds_hi[ind] = self.settings.peak_width_limits[1] / 2 + + bounds_lo[p_ind*n_pe_params:(p_ind+1)*n_pe_params] = pbounds_lo + bounds_hi[p_ind*n_pe_params:(p_ind+1)*n_pe_params] = pbounds_hi + + pe_bounds = (bounds_lo, bounds_hi) + + return pe_bounds def _fit_peak_guess(self, flatspec, guess): From d96ed69423fd4bc53e92e74d74f57e1a75efc539 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 17 Apr 2025 01:15:00 -0400 Subject: [PATCH 37/40] udpate _fit_peaks for modes --- specparam/algorithms/spectral_fit.py | 36 +++++++++++++--------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index ff85c29d..88af1bcd 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -356,9 +356,8 @@ def _fit_peaks(self, flatspec): Returns ------- - gaussian_params : 2d array - Parameters that define the gaussian fit(s). - Each row is a gaussian, as [mean, height, standard deviation]. + peak_params : 2d array + Parameters that define the peak fit(s). """ # Take a copy of the flattened spectrum to iterate across @@ -379,7 +378,7 @@ def _fit_peaks(self, flatspec): if max_height <= self.settings.peak_threshold * np.std(flat_iter): break - # Set the guess parameters for gaussian fitting, specifying the mean and height + # Set the guess parameters for peak fitting, specifying the mean and height guess_freq = self.data.freqs[max_ind] guess_height = max_height @@ -404,7 +403,7 @@ def _fit_peaks(self, flatspec): for ind in [le_ind, ri_ind] if ind is not None]) # Use the shortest side to estimate full-width, half max (converted to Hz) - # and use this to estimate that guess for gaussian standard deviation + # and use this to estimate that guess for peak standard deviation fwhm = short_side * 2 * self.data.freq_res guess_std = compute_gauss_std(fwhm) @@ -414,23 +413,22 @@ def _fit_peaks(self, flatspec): guess_std = np.mean(self.settings.peak_width_limits) # Check that guess value isn't outside preset limits - restrict if so - # This also converts the peak_width_limits from 2-sided BW to 1-sided gaussian std + # This also converts the peak_width_limits from 2-sided BW to 1-sided std # Note: without this, curve_fitting fails if given guess > or < bounds if guess_std < self.settings.peak_width_limits[0] / 2: guess_std = self.settings.peak_width_limits[0] / 2 if guess_std > self.settings.peak_width_limits[1] / 2: guess_std = self.settings.peak_width_limits[0] / 2 - # Collect guess parameters and subtract this guess gaussian from the data - current_guess_params = (guess_freq, guess_height, guess_std) - - ## TEMP - if self.modes.periodic.name == 'skewnorm': - guess_skew = 0 - current_guess_params = (guess_freq, guess_height, guess_std, guess_skew) + # Collect guess parameters + cur_guess = [0] * self.modes.periodic.n_params + cur_guess[self.modes.periodic.params.indices['cf']] = guess_freq + cur_guess[self.modes.periodic.params.indices['pw']] = guess_height + cur_guess[self.modes.periodic.params.indices['bw']] = guess_std - guess = np.vstack((guess, current_guess_params)) - peak_gauss = self.modes.periodic.func(self.data.freqs, *current_guess_params) + # Fit and subtract guess peak from the spectrum + guess = np.vstack((guess, cur_guess)) + peak_gauss = self.modes.periodic.func(self.data.freqs, *cur_guess) flat_iter = flat_iter - peak_gauss # Check peaks based on edges, and on overlap, dropping any that violate requirements @@ -439,12 +437,12 @@ def _fit_peaks(self, flatspec): # If there are peak guesses, fit the peaks, and sort results if len(guess) > 0: - gaussian_params = self._fit_peak_guess(flatspec, guess) - gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()] + peak_params = self._fit_peak_guess(flatspec, guess) + peak_params = peak_params[peak_params[:, 0].argsort()] else: - gaussian_params = np.empty([0, self.modes.periodic.n_params]) + peak_params = np.empty([0, self.modes.periodic.n_params]) - return gaussian_params + return peak_params def _get_pe_bounds(self, guess): From 6048b8ca1d2f2d973229918a257fced511f8c418 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 17 Apr 2025 01:18:41 -0400 Subject: [PATCH 38/40] update peak tuning for modes --- specparam/algorithms/spectral_fit.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 88af1bcd..048d8260 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -542,7 +542,6 @@ def _fit_peak_guess(self, flatspec, guess): return pe_params - ## TO GENERALIZE FOR MODES def _drop_peak_cf(self, guess): """Check whether to drop peaks based on center's proximity to the edge of the spectrum. @@ -557,8 +556,8 @@ def _drop_peak_cf(self, guess): Guess parameters for periodic peak fits. Shape: [n_peaks, n_params_per_peak]. """ - cf_params = guess[:, 0] - bw_params = guess[:, 2] * self._settings.bw_std_edge + cf_params = guess[:, self.modes.periodic.params.indices['cf']] + bw_params = guess[:, self.modes.periodic.params.indices['bw']] * self._settings.bw_std_edge # Check if peaks within drop threshold from the edge of the frequency range keep_peak = \ @@ -589,14 +588,17 @@ def _drop_peak_overlap(self, guess): For any peaks with an overlap > threshold, the lowest height guess peak is dropped. """ + inds = self.modes.periodic.params.indices + # Sort the peak guesses by increasing frequency # This is so adjacent peaks can be compared from right to left - guess = sorted(guess, key=lambda x: float(x[0])) + guess = sorted(guess, key=lambda x: float(x[inds['cf']])) # Calculate standard deviation bounds for checking amount of overlap # The bounds are the gaussian frequency +/- gaussian standard deviation - bounds = [[peak[0] - peak[2] * self._settings.gauss_overlap_thresh, - peak[0] + peak[2] * self._settings.gauss_overlap_thresh] for peak in guess] + bounds = [[peak[inds['cf']] - peak[inds['bw']] * self._settings.gauss_overlap_thresh, + peak[inds['cf']] + peak[inds['bw']] * self._settings.gauss_overlap_thresh]\ + for peak in guess] # Loop through peak bounds, comparing current bound to that of next peak # If the left peak's upper bound extends pass the right peaks lower bound, @@ -608,7 +610,7 @@ def _drop_peak_overlap(self, guess): # Check if bound of current peak extends into next peak if b_0[1] > b_1[0]: - # If so, get the index of the gaussian with the lowest height (to drop) + # If so, get the index of the peak with the lowest height (to drop) drop_inds.append([ind, ind + 1][np.argmin([guess[ind][1], guess[ind + 1][1]])]) # Drop any peaks guesses that overlap too much, based on threshold From 736a0be1cfd4217d379234270ac433582f407855 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 17 Apr 2025 01:25:11 -0400 Subject: [PATCH 39/40] udpate param conversion to use modes --- specparam/algorithms/spectral_fit.py | 39 +++++++++++++--------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 048d8260..df3e3ef7 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -621,18 +621,18 @@ def _drop_peak_overlap(self, guess): ## TO GENERALIZE FOR MODES - def _create_peak_params(self, gaus_params): + def _create_peak_params(self, fit_peak_params): """Copies over the gaussian params to peak outputs, updating as appropriate. Parameters ---------- - gaus_params : 2d array - Parameters that define the gaussian fit(s), as gaussian parameters. + fit_peak_params : 2d array + Parameters that define the peak parameters directly fit to the spectrum. Returns ------- peak_params : 2d array - Fitted parameter values for the peaks, with each row as [CF, PW, BW]. + Updated parameter values for the peaks. Notes ----- @@ -649,25 +649,22 @@ def _create_peak_params(self, gaus_params): with `freqs`, `modeled_spectrum` and `_ap_fit` all required to be available. """ - peak_params = np.empty((len(gaus_params), self.modes.periodic.n_params)) + inds = self.modes.periodic.params.indices + + peak_params = np.empty((len(fit_peak_params), self.modes.periodic.n_params)) + + for ii, peak in enumerate(fit_peak_params): - for ii, peak in enumerate(gaus_params): + cpeak = peak.copy() # Gets the index of the power_spectrum at the frequency closest to the CF of the peak - ind = np.argmin(np.abs(self.data.freqs - peak[0])) - - # Collect peak parameter data - if self.modes.periodic.name == 'gaussian': ## TEMP - peak_params[ii] = [\ - peak[0], - self.results.model.modeled_spectrum[ind] - self.results.model._ap_fit[ind], - peak[2] * 2] - - ## TEMP: - if self.modes.periodic.name == 'skewnorm': - peak_params[ii] = [\ - peak[0], - self.results.model.modeled_spectrum[ind] - self.results.model._ap_fit[ind], - peak[2] * 2, peak[3]] + cf_ind = np.argmin(np.abs(self.data.freqs - peak[inds['cf']])) + cpeak[inds['pw']] = \ + self.results.model.modeled_spectrum[cf_ind] - self.results.model._ap_fit[cf_ind] + + # Bandwidth is updated to be 'two-sided' (as opposed to one-sided std dev) + cpeak[inds['bw']] = peak[inds['bw']] * 2 + + peak_params[ii] = cpeak return peak_params From c48888870e6d4a6caadbd595d98fa483f6224e4f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 17 Apr 2025 01:31:36 -0400 Subject: [PATCH 40/40] update / clean up descriptions --- specparam/algorithms/spectral_fit.py | 38 ++++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index df3e3ef7..763e78de 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -73,21 +73,21 @@ 'cf_bound' : { 'type' : 'float', 'description' : \ - 'Parameter bounds for center frequency when fitting gaussians, as +/- std dev.', + 'Parameter bounds for center frequency when fitting peaks, as +/- std dev.', }, 'bw_std_edge' : { 'type' : 'float', 'description' : \ 'Threshold for how far a peak has to be from edge to keep.' '\n ' - 'This is defined in units of gaussian standard deviation.', + 'This is defined in units of peak standard deviation.', }, 'gauss_overlap_thresh' : { 'type' : 'float', 'description' : \ - 'Degree of overlap between gaussian guesses for one to be dropped.' + 'Degree of overlap between peak guesses for one to be dropped.' '\n ' - 'This is defined in units of gaussian standard deviation.', + 'This is defined in units of peak standard deviation.', }, }) @@ -147,7 +147,8 @@ def _fit_prechecks(self, verbose=True): if verbose: if 1.5 * self.data.freq_res >= self.settings.peak_width_limits[0]: - print(gen_width_warning_str(self.data.freq_res, self.settings.peak_width_limits[0])) + print(gen_width_warning_str(self.data.freq_res, + self.settings.peak_width_limits[0])) def _fit(self): @@ -159,7 +160,7 @@ def _fit(self): temp_aperiodic_params = self._robust_ap_fit(self.data.freqs, self.data.power_spectrum) temp_ap_fit = self.modes.aperiodic.func(self.data.freqs, *temp_aperiodic_params) - # Find peaks from the flattened power spectrum, and fit them with gaussians + # Find peaks from the flattened power spectrum, and fit them temp_spectrum_flat = self.data.power_spectrum - temp_ap_fit self.results.params.gaussian = self._fit_peaks(temp_spectrum_flat) @@ -185,7 +186,7 @@ def _fit(self): ## PARAMETER UPDATES - # Convert gaussian definitions to peak parameters + # Convert fit peak parameters to updated values self.results.params.peak = self._create_peak_params(self.results.params.gaussian) @@ -428,8 +429,8 @@ def _fit_peaks(self, flatspec): # Fit and subtract guess peak from the spectrum guess = np.vstack((guess, cur_guess)) - peak_gauss = self.modes.periodic.func(self.data.freqs, *cur_guess) - flat_iter = flat_iter - peak_gauss + peak_fit = self.modes.periodic.func(self.data.freqs, *cur_guess) + flat_iter = flat_iter - peak_fit # Check peaks based on edges, and on overlap, dropping any that violate requirements guess = self._drop_peak_cf(guess) @@ -571,7 +572,7 @@ def _drop_peak_cf(self, guess): def _drop_peak_overlap(self, guess): - """Checks whether to drop gaussians based on amount of overlap. + """Checks whether to drop peaks based on amount of overlap. Parameters ---------- @@ -595,7 +596,7 @@ def _drop_peak_overlap(self, guess): guess = sorted(guess, key=lambda x: float(x[inds['cf']])) # Calculate standard deviation bounds for checking amount of overlap - # The bounds are the gaussian frequency +/- gaussian standard deviation + # The bounds are the center frequency +/- width (standard deviation) bounds = [[peak[inds['cf']] - peak[inds['bw']] * self._settings.gauss_overlap_thresh, peak[inds['cf']] + peak[inds['bw']] * self._settings.gauss_overlap_thresh]\ for peak in guess] @@ -620,9 +621,8 @@ def _drop_peak_overlap(self, guess): return guess - ## TO GENERALIZE FOR MODES def _create_peak_params(self, fit_peak_params): - """Copies over the gaussian params to peak outputs, updating as appropriate. + """Copies over the fit peak parameters output parameters, updating as appropriate. Parameters ---------- @@ -636,14 +636,14 @@ def _create_peak_params(self, fit_peak_params): Notes ----- - The gaussian center is unchanged as the peak center frequency. + The center frequency estimate is unchanged as the peak center frequency. - The gaussian height is updated to reflect the height of the peak above - the aperiodic fit. This is returned instead of the gaussian height, as - the gaussian height is harder to interpret, due to peak overlaps. + The peak height is updated to reflect the height of the peak above + the aperiodic fit. This is returned instead of the fit peak height, as + the fit height is harder to interpret, due to peak overlaps. - The gaussian standard deviation is updated to be 'both-sided', to reflect the - 'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided. + The peak bandwidth is updated to be 'both-sided', to reflect the overal width + of the peak, as opposed to the fit parameter, which is 1-sided standard deviation. Performing this conversion requires that the model has been run, with `freqs`, `modeled_spectrum` and `_ap_fit` all required to be available.