"""Utility functions."""
import logging
import numpy as np
from scipy.signal import periodogram
from tensorpac.methods.meth_pac import _kl_hr
from tensorpac.pac import _PacObj, _PacVisual
from tensorpac.io import set_log_level
from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt
logger = logging.getLogger('tensorpac')
def pac_vec(f_pha='mres', f_amp='mres'):
    """Generate cross-frequency coupling vectors.

    Parameters
    ----------
    f_pha, f_amp : str, list, tuple, array_like | 'mres'
        Frequency vector for the phase and amplitude. Here you can use
        several forms to define those vectors :

            * Basic list/tuple (ex: [2, 4] or [8, 12]...)
            * List of frequency bands (ex: [[2, 4], [5, 7]]...)
            * Dynamic definition : (start, stop, width, step)
            * Range definition (ex : np.arange(3) => [[0, 1], [1, 2]])
            * Using a string. `f_pha` and `f_amp` can be 'lres', 'mres',
              'hres' respectively for low, middle and high resolution
              vectors ('demon' and 'hulk' are also accepted for even denser
              vectors). In that case, it uses the definition proposed by
              Bahramisharif et al. 2013
              :cite:`bahramisharif2013propagating` i.e
              f_pha = [f - f / 4, f + f / 4] and
              f_amp = [f - f / 8, f + f / 8]. Center frequencies are
              linearly spaced in [2, 20] for the phase and in [60, 160] for
              the amplitude.

    Returns
    -------
    f_pha, f_amp : array_like
        Arrays containing the pairs of phase and amplitude frequencies. Each
        vector have a shape of (N, 2).
    """
    # number of center frequencies associated to each named resolution
    n_bands = {'lres': 10, 'mres': 30, 'hres': 50, 'demon': 70, 'hulk': 100}

    def _preset_bands(name, lowest, highest, ratio):
        """Bands [f - f / ratio, f + f / ratio] around linearly spaced f."""
        centers = np.linspace(lowest, highest, n_bands[name])
        return np.c_[centers - centers / ratio, centers + centers / ratio]

    if isinstance(f_pha, str):
        # f_pha = [f - f / 4, f + f / 4]
        f_pha = _preset_bands(f_pha, 2, 20, 4.)
    if isinstance(f_amp, str):
        # f_amp = [f - f / 8, f + f / 8]
        f_amp = _preset_bands(f_amp, 60, 160, 8.)

    return _check_freq(f_pha), _check_freq(f_amp)
def _check_freq(f):
"""Check the frequency definition."""
f = np.atleast_2d(np.asarray(f))
#
if len(f.reshape(-1)) == 1:
raise ValueError("The length of f should at least be 2.")
elif 2 in f.shape: # f of shape (N, 2) or (2, N)
if f.shape[1] is not 2:
f = f.T
elif np.squeeze(f).shape == (4,): # (f_start, f_end, f_width, f_step)
f = _pair_vectors(*tuple(np.squeeze(f)))
else: # Sequential
f = f.reshape(-1)
f.sort()
f = np.c_[f[0:-1], f[1::]]
return f
def _pair_vectors(f_start, f_end, f_width, f_step):
# Generate two array for phase and amplitude :
fdown = np.arange(f_start, f_end - f_width, f_step)
fup = np.arange(f_start + f_width, f_end, f_step)
return np.c_[fdown, fup]
def pac_trivec(f_start=60., f_end=160., f_width=10.):
    """Generate triangular vector.

    By contrast with the pac_vec function, this function generate frequency
    vector with an increasing frequency bandwidth : every band edge is paired
    with all the edges located above it.

    Parameters
    ----------
    f_start : float | 60.
        Starting frequency.
    f_end : float | 160.
        Ending frequency.
    f_width : float | 10.
        Frequency bandwidth increase between each band.

    Returns
    -------
    f : array_like
        The triangular vector, one (low, high) band per row.
    tridx : array_like
        The triangular index for the reconstruction, one (row, column) pair
        per band.
    """
    edges = np.arange(f_start, f_end + f_width, f_width)
    n_edges = len(edges)
    bands, positions = [], []
    for col, low in enumerate(edges[:-1]):
        # every edge above `low` closes one band starting at `low`
        n_rows = n_edges - (col + 1)
        bands.append(np.c_[np.full(n_rows, low), edges[col + 1:]])
        # widest band last : row indices run from n_rows - 1 down to 0
        positions.append(np.c_[np.arange(n_rows)[::-1],
                               np.full(n_rows, col)])
    if not bands:
        # fewer than two edges : nothing to pair
        return np.array([]), np.array([])
    return np.concatenate(bands, axis=0), np.concatenate(positions, axis=0)
class PSD(object):
    """Power Spectrum Density for electrophysiological brain data.

    The periodogram of each trial is computed at construction time and then
    exposed through the `freqs` / `psd` properties and the plotting methods.

    Parameters
    ----------
    x : array_like
        Array of data of shape (n_epochs, n_times)
    sf : float
        The sampling frequency.
    """

    def __init__(self, x, sf):
        """Init."""
        assert isinstance(x, np.ndarray) and (x.ndim == 2), (
            "x should be a 2d array of shape (n_epochs, n_times)")
        self._n_trials, self._n_times = x.shape
        logger.info(f"Compute PSD over {self._n_trials} trials and "
                    f"{self._n_times} time points")
        self._freqs, self._psd = periodogram(x, fs=sf, window=None,
                                             nfft=self._n_times,
                                             detrend='constant',
                                             return_onesided=True,
                                             scaling='density', axis=1)

    @staticmethod
    def _nearest_percentile(data, q):
        """Percentile over trials (axis 0) using the nearest sample."""
        try:
            return np.percentile(data, q, axis=0, method='nearest')
        except TypeError:
            # older NumPy releases only know the `interpolation` keyword
            return np.percentile(data, q, axis=0, interpolation='nearest')

    @staticmethod
    def _band_slice(freqs, f_min, f_max):
        """Slice of the bins closest to f_min and f_max (both included)."""
        start = np.abs(freqs - f_min).argmin()
        stop = np.abs(freqs - f_max).argmin()
        return slice(start, stop + 1)

    @staticmethod
    def _style_axis(log, grid):
        """Apply the optional log x-scale and grid to the current axis."""
        import matplotlib.pyplot as plt
        if log:
            from matplotlib.ticker import ScalarFormatter
            # base 10 is matplotlib's default; the former `basex` keyword is
            # not accepted by recent matplotlib versions
            plt.xscale('log')
            plt.gca().xaxis.set_major_formatter(ScalarFormatter())
        if grid:
            plt.grid(color='grey', which='major', linestyle='-',
                     linewidth=1., alpha=0.5)
            plt.grid(color='lightgrey', which='minor', linestyle='--',
                     linewidth=0.5, alpha=0.5)

    def plot(self, f_min=None, f_max=None, confidence=95, interp=None,
             log=False, grid=True, fz_title=18, fz_labels=15):
        """Plot the PSD.

        Parameters
        ----------
        f_min, f_max : (int, float) | None
            Frequency bounds to use for plotting
        confidence : (int, float) | None
            Light gray confidence interval. If None, no interval will be
            displayed
        interp : int | None
            Line interpolation integer. For example, if interp is 10 the number
            of points is going to be multiply by 10
        log : bool | False
            Use a log scale representation
        grid : bool | True
            Add a grid to the plot
        fz_title : int | 18
            Font size for the title
        fz_labels : int | 15
            Font size the x/y labels

        Returns
        -------
        ax : Matplotlib axis
            The matplotlib axis that contains the figure
        """
        import matplotlib.pyplot as plt
        f_types = (int, float)
        # interpolation
        xvec, yvec = self._freqs, self._psd
        if isinstance(interp, int) and (interp > 1):
            from scipy.interpolate import interp1d
            xnew = np.linspace(xvec[0], xvec[-1], len(xvec) * interp)
            f = interp1d(xvec, yvec, kind='quadratic', axis=1)
            yvec = f(xnew)
            xvec = xnew
        # (f_min, f_max) : fall back to the full frequency range
        f_min = xvec[0] if not isinstance(f_min, f_types) else f_min
        f_max = xvec[-1] if not isinstance(f_max, f_types) else f_max
        # plot main psd
        plt.plot(xvec, yvec.mean(0), color='black',
                 label='mean PSD over trials')
        # plot confidence interval
        if isinstance(confidence, (int, float)) and (0 < confidence < 100):
            logger.info(f" Add {confidence}th confidence interval")
            interval = (100. - confidence) / 2
            psd_min = self._nearest_percentile(yvec, interval)
            psd_max = self._nearest_percentile(yvec, 100. - interval)
            plt.fill_between(xvec, psd_max, psd_min, color='lightgray',
                             alpha=0.5,
                             label=f"{confidence}th confidence interval")
        plt.legend(fontsize=fz_labels)
        plt.xlabel("Frequencies (Hz)", fontsize=fz_labels)
        plt.ylabel("Power (V**2/Hz)", fontsize=fz_labels)
        plt.title(f"PSD mean over {self._n_trials} trials", fontsize=fz_title)
        plt.xlim(f_min, f_max)
        self._style_axis(log, grid)

        return plt.gca()

    def plot_st_psd(self, f_min=None, f_max=None, log=False, grid=True,
                    fz_title=18, fz_labels=15, fz_cblabel=15, **kw):
        """Single-trial PSD plot.

        Parameters
        ----------
        f_min, f_max : (int, float) | None
            Frequency bounds to use for plotting
        log : bool | False
            Use a log scale representation
        grid : bool | True
            Add a grid to the plot
        fz_title : int | 18
            Font size for the title
        fz_labels : int | 15
            Font size the x/y labels
        fz_cblabel : int | 15
            Font size the colorbar label labels
        kw : dict | {}
            Additional inputs are passed to the `pacplot` method of
            `_PacVisual` ('xlabel', 'ylabel', 'title' and 'cblabel' get a
            default value when not provided)

        Returns
        -------
        ax : Matplotlib axis
            The matplotlib axis that contains the figure
        """
        import matplotlib.pyplot as plt
        # manage input variables
        kw['fz_labels'] = kw.get('fz_labels', fz_labels)
        kw['fz_title'] = kw.get('fz_title', fz_title)
        # the colorbar font size follows `fz_cblabel` (not `fz_title`)
        kw['fz_cblabel'] = kw.get('fz_cblabel', fz_cblabel)
        kw['xlabel'] = kw.get('xlabel', "Frequencies (Hz)")
        kw['ylabel'] = kw.get('ylabel', "Trials")
        kw['title'] = kw.get('title', "Single-trial PSD")
        kw['cblabel'] = kw.get('cblabel', "Power (V**2/Hz)")
        # (f_min, f_max) : fall back to the full frequency range
        xvec, psd = self._freqs, self._psd
        f_types = (int, float)
        f_min = xvec[0] if not isinstance(f_min, f_types) else f_min
        f_max = xvec[-1] if not isinstance(f_max, f_types) else f_max
        # keep the bins from f_min up to f_max, f_max included
        sl_freq = self._band_slice(xvec, f_min, f_max)
        xvec = xvec[sl_freq]
        psd = psd[:, sl_freq]
        # make the 2D plot
        _viz = _PacVisual()
        trials = np.arange(self._n_trials)
        _viz.pacplot(psd, xvec, trials, **kw)
        self._style_axis(log, grid)

        return plt.gca()

    def show(self):
        """Display the PSD figure."""
        import matplotlib.pyplot as plt
        plt.show()

    @property
    def freqs(self):
        """Get the frequency vector."""
        return self._freqs

    @property
    def psd(self):
        """Get the psd value."""
        return self._psd
class BinAmplitude(_PacObj):
    """Bin the amplitude according to the phase.

    Parameters
    ----------
    x : array_like
        Array of data of shape (n_epochs, n_times)
    sf : float
        The sampling frequency
    f_pha : tuple, list | [2, 4]
        List of two floats describing the frequency bounds for extracting the
        phase
    f_amp : tuple, list | [60, 80]
        List of two floats describing the frequency bounds for extracting the
        amplitude
    n_bins : int | 18
        Number of bins to use to binarize the phase and the amplitude
    dcomplex : {'wavelet', 'hilbert'}
        Method for the complex definition. Use either 'hilbert' or
        'wavelet'.
    cycle : tuple | (3, 6)
        Control the number of cycles for filtering (only if dcomplex is
        'hilbert'). Should be a tuple of integers where the first one
        refers to the number of cycles for the phase and the second for the
        amplitude :cite:`bahramisharif2013propagating`.
    width : int | 7
        Width of the Morlet's wavelet.
    edges : int | None
        Number of samples to discard to avoid edge effects due to filtering
    n_jobs : int | -1
        Forwarded to the filtering step (`self.filter`)
    """

    # NOTE: the list defaults are kept for interface compatibility; they are
    # only read (never modified) inside this class.
    def __init__(self, x, sf, f_pha=[2, 4], f_amp=[60, 80], n_bins=18,
                 dcomplex='hilbert', cycle=(3, 6), width=7, edges=None,
                 n_jobs=-1):
        """Init."""
        _PacObj.__init__(self, f_pha=f_pha, f_amp=f_amp, dcomplex=dcomplex,
                         cycle=cycle, width=width)
        # check
        x = np.atleast_2d(x)
        assert x.ndim <= 2, ("`x` input should be an array of shape "
                             "(n_epochs, n_times)")
        assert isinstance(sf, (int, float)), ("`sf` input should be a integer "
                                              "or a float")
        assert all([isinstance(k, (int, float)) for k in f_pha]), (
            "`f_pha` input should be a list of two integers / floats")
        assert all([isinstance(k, (int, float)) for k in f_amp]), (
            "`f_amp` input should be a list of two integers / floats")
        assert isinstance(n_bins, int), "`n_bins` should be an integer"
        logger.info(f"Binning {f_amp}Hz amplitude according to {f_pha}Hz "
                    "phase")
        # extract phase and amplitude
        kw = dict(keepfilt=False, edges=edges, n_jobs=n_jobs)
        pha = self.filter(sf, x, 'phase', **kw)
        amp = self.filter(sf, x, 'amplitude', **kw)
        # binarize amplitude according to phase
        self._amplitude = _kl_hr(pha, amp, n_bins, mean_bins=False).squeeze()
        self.n_bins = n_bins
        # phase vector (radians) so that the `phase` property is usable
        # before any call to `plot`; `plot` overrides it with the chosen unit
        self._phase = np.linspace(-np.pi, np.pi, n_bins)

    def plot(self, unit='rad', normalize=False, **kw):
        """Plot the amplitude.

        Parameters
        ----------
        unit : {'rad', 'deg'}
            The unit to use for the phase. Use either 'deg' for degree or 'rad'
            for radians
        normalize : bool | False
            Normalize the histogram by the maximum
        kw : dict | {}
            Additional inputs are passed to the matplotlib.pyplot.bar function

        Returns
        -------
        ax : Matplotlib axis
            The matplotlib axis that contains the figure
        """
        import matplotlib.pyplot as plt
        assert unit in ['rad', 'deg']
        if unit == 'rad':
            self._phase = np.linspace(-np.pi, np.pi, self.n_bins)
            width = 2 * np.pi / self.n_bins
        elif unit == 'deg':
            self._phase = np.linspace(-180, 180, self.n_bins)
            width = 360 / self.n_bins
        # `mean` returns a new array, the in-place division below therefore
        # leaves the stored amplitude untouched
        amp_mean = self._amplitude.mean(1)
        if normalize:
            amp_mean /= amp_mean.max()
        plt.bar(self._phase, amp_mean, width=width, **kw)
        plt.xlabel(f"Frequency phase ({self.n_bins} bins)", fontsize=18)
        plt.ylabel("Amplitude", fontsize=18)
        plt.title("Binned amplitude")
        plt.autoscale(enable=True, axis='x', tight=True)

        # return the axis, as documented and as the other plot methods do
        return plt.gca()

    def show(self):
        """Show the figure."""
        import matplotlib.pyplot as plt
        plt.show()

    @property
    def amplitude(self):
        """Get the amplitude value."""
        return self._amplitude

    @property
    def phase(self):
        """Get the phase value."""
        return self._phase
class ITC(_PacObj, _PacVisual):
    """Compute the Inter-Trials Coherence (ITC).

    The Inter-Trials Coherence (ITC) is a measure of phase consistency over
    trials for a single recording site (electrode / sensor etc.).

    Parameters
    ----------
    x : array_like
        Array of data of shape (n_epochs, n_times)
    sf : float
        The sampling frequency
    f_pha : tuple, list | [2, 4]
        List of two floats describing the frequency bounds for extracting the
        phase
    dcomplex : {'wavelet', 'hilbert'}
        Method for the complex definition. Use either 'hilbert' or
        'wavelet'.
    cycle : int | 3
        Control the number of cycles for filtering the phase (only if dcomplex
        is 'hilbert').
    width : int | 7
        Width of the Morlet's wavelet.
    edges : int | None
        Number of samples to discard to avoid edge effects due to filtering
    n_jobs : int | -1
        Forwarded to the filtering step (`self.filter`)
    verbose : | None
        Forwarded to `set_log_level`
    """

    def __init__(self, x, sf, f_pha=[2, 4], dcomplex='hilbert', cycle=3,
                 width=7, edges=None, n_jobs=-1, verbose=None):
        """Init."""
        set_log_level(verbose)
        # the amplitude settings ([60, 80] and 6 cycles) are placeholders
        # required by the parent constructor; only the phase is filtered here
        _PacObj.__init__(self, f_pha=f_pha, f_amp=[60, 80], dcomplex=dcomplex,
                         cycle=(cycle, 6), width=width)
        _PacVisual.__init__(self)
        # check
        x = np.atleast_2d(x)
        assert x.ndim <= 2, ("`x` input should be an array of shape "
                             "(n_epochs, n_times)")
        self._n_trials = x.shape[0]
        logger.info("Inter-Trials Coherence (ITC)")
        logger.info(f" extracting {len(self.xvec)} phases")
        # extract phase
        kw = dict(keepfilt=False, edges=edges, n_jobs=n_jobs)
        pha = self.filter(sf, x, 'phase', **kw)
        # compute itc : modulus of the mean unit phase vector (mean taken
        # over axis 1)
        self._itc = np.abs(np.exp(1j * pha).mean(1)).squeeze()
        self._sf = sf

    def plot(self, times=None, **kw):
        """Plot the Inter-Trials Coherence.

        Parameters
        ----------
        times : array_like | None
            Custom time vector to use
        kw : dict | {}
            Additional inputs are either pass to the matplotlib.pyplot.plot
            function if a single phase band is used, otherwise to the
            matplotlib.pyplot.pcolormesh function

        Returns
        -------
        ax : Matplotlib axis
            The matplotlib axis that contains the figure
        """
        import matplotlib.pyplot as plt
        n_pts = self._itc.shape[-1]
        if not isinstance(times, np.ndarray):
            times = np.arange(n_pts) / self._sf
        # NOTE(review): `self._edges` is presumably the slice set by the
        # filtering step when `edges` is used; confirm that slicing the
        # default time vector (already of length n_pts) is intended.
        times = times[self._edges]
        assert len(times) == n_pts, (
            f"The length of the time vector should be {n_pts}")
        xlab = 'Time'
        title = f"Inter-Trials Coherence ({self._n_trials} trials)"
        if self._itc.ndim == 1:
            plt.plot(times, self._itc, **kw)
        elif self._itc.ndim == 2:
            # pop / setdefault so that user-provided values are not sent
            # twice to pacplot ("multiple values for keyword argument")
            vmin = kw.pop('vmin', np.percentile(self._itc, 1))
            vmax = kw.pop('vmax', np.percentile(self._itc, 99))
            kw.setdefault('ylabel', "Frequency for phase (Hz)")
            kw.setdefault('xlabel', xlab)
            kw.setdefault('title', title)
            self.pacplot(self._itc, times, self.xvec, vmin=vmin, vmax=vmax,
                         **kw)

        return plt.gca()

    def show(self):
        """Show the figure."""
        import matplotlib.pyplot as plt
        plt.show()

    @property
    def itc(self):
        """Get the itc value."""
        return self._itc
class PeakLockedTF(_PacObj, _PacVisual):
    """Peak-Locked Time-frequency representation.

    This class can be used in order to re-align time-frequency representations
    around a time-point (cue) according to the closest phase peak. This type
    of visualization can bring out a cyclic behavior of the amplitude at a
    given phase, potentially indicating the presence of a phase-amplitude
    coupling. Here's the detailed pipeline :

        * Filter around a single phase frequency bands and across multiple
          amplitude frequencies
        * Use a `cue` which define the time-point to use for the realignment
        * Detect in the filtered phase the closest peak to the cue. This step
          is repeated to each trial in order to get a list of length (n_epochs)
          that contains the number of sample (shift) so that if the phase is
          moved, the peak fall onto the cue. A positive shift indicates that
          the phase is moved forward while a negative shift is for a backward
          move
        * Apply, to each trial, this shift to the amplitude
        * Plot the mean re-aligned amplitudes

    Parameters
    ----------
    x : array_like
        Array of data of shape (n_epochs, n_times)
    sf : float
        The sampling frequency
    cue : int, float
        Time-point to use in order to detect the closest phase peak. This
        parameter works in conjunction with the `times` input below. Use
        either :

            * An integer and `times` is None to indicate that you want to
              realign according to a time-point in sample
            * A integer or a float with `times` the time vector if you want
              that Tensorpac automatically infer the sample number around which
              to align
    times : array_like | None
        Time vector
    f_pha : tuple, list | [5, 7]
        List of two floats describing the frequency bounds for extracting the
        phase
    f_amp : tuple, list | 'hres'
        Frequency vector for the amplitude. Here you can use several forms to
        define those vectors :

            * Dynamic definition : (start, stop, width, step)
            * Using a string : `f_amp` can be 'lres', 'mres', 'hres'
              respectively for low, middle and high resolution vectors
    cycle : tuple | (3, 6)
        Control the number of cycles for filtering. Should be a tuple of
        integers where the first one refers to the number of cycles for the
        phase and the second for the amplitude
        :cite:`bahramisharif2013propagating`.
    n_jobs : int | -1
        Forwarded to the filtering step (`self.filter`)
    verbose : | None
        Forwarded to `set_log_level`
    """

    def __init__(self, x, sf, cue, times=None, f_pha=[5, 7], f_amp='hres',
                 cycle=(3, 6), n_jobs=-1, verbose=None):
        """Init."""
        set_log_level(verbose)
        # initialize to retrieve filtering methods
        _PacObj.__init__(self, f_pha=f_pha, f_amp=f_amp, dcomplex='hilbert',
                         cycle=cycle)
        _PacVisual.__init__(self)
        logger.info("PeakLockedTF object defined")
        # inputs checking
        x = np.atleast_2d(x)
        assert isinstance(x, np.ndarray) and (x.ndim == 2)
        assert isinstance(sf, (int, float))
        assert isinstance(cue, (int, float))
        assert isinstance(f_pha, (list, tuple)) and (len(f_pha) == 2)
        n_epochs, n_times = x.shape

        # manage cue conversion
        if times is None:
            cue = int(cue)
            times = np.arange(n_times)
            logger.info(f" align on sample cue={cue}")
        else:
            assert isinstance(times, np.ndarray) and (len(times) == n_times)
            cue_time = cue
            # closest sample to the requested time-point. No `- 1` offset :
            # it made this branch disagree with the sample-based one and
            # wrapped to the last sample when the cue was the first one
            cue = int(np.abs(times - cue).argmin())
            logger.info(f" align on time-point={cue_time} (sample={cue})")
        self.cue, self._times = cue, times

        # extract phase and amplitudes
        logger.info(f" extract phase and amplitudes "
                    f"(n_amps={len(self.yvec)})")
        pha = self.filter(sf, x, 'phase', n_jobs=n_jobs, keepfilt=True)
        amp = self.filter(sf, x, 'amplitude', n_jobs=n_jobs)
        self._pha, self._amp = pha, amp ** 2

        # peak detection
        logger.info(f" running peak detection around sample={cue}")
        self.shifts = self._peak_detection(self._pha.squeeze(), cue)

        # realign phases and amplitudes
        logger.info(f" realign the {n_epochs} phases and amplitudes")
        self.amp_a = self._shift_signals(self._amp, self.shifts, fill_with=0.)
        self.pha_a = self._shift_signals(self._pha, self.shifts, fill_with=0.)

    @staticmethod
    def _peak_detection(pha, cue):
        """Single trial closest to a cue peak detection.

        Parameters
        ----------
        pha : array_like
            Array of single trial phases of shape (n_trials, n_times). A
            1D array is interpreted as a single trial.
        cue : int
            Cue to use as a reference (in sample unit)

        Returns
        -------
        peaks : array_like
            Array of length (n_trials,) describing each delay to apply
            to each trial in order to realign the phases. In detail :

                * Positive delays means that zeros should be prepend
                * Negative delays means that zeros should be append

        Raises
        ------
        ValueError
            If a trial does not contain any local maximum.
        """
        # a squeezed single-trial input comes in as 1D
        pha = np.atleast_2d(pha)
        n_trials, n_times = pha.shape
        # strict local maxima, evaluated on interior samples only so that the
        # first sample is never compared with the last one
        inner = pha[:, 1:-1]
        is_peak = (inner > pha[:, :-2]) & (inner > pha[:, 2:])
        peaks = []
        for tr in range(n_trials):
            # +1 : index 0 of `inner` is sample 1 of the trial
            st_peaks = np.flatnonzero(is_peak[tr]) + 1
            if not st_peaks.size:
                raise ValueError(f"No phase peak detected in trial {tr}")
            # closest peak to the cue (the earliest one in case of a tie)
            min_peak = st_peaks[np.abs(st_peaks - cue).argmin()]
            peaks += [cue - min_peak]

        return np.array(peaks)

    @staticmethod
    def _shift_signals(sig, n_shifts, fill_with=0):
        """Shift an array of signals according to an array of delays.

        Parameters
        ----------
        sig : array_like
            Array of signals of shape (n_freq, n_trials, n_times)
        n_shifts : array_like
            Array of delays to apply to each trial of shape (n_trials,)
        fill_with : int
            Value to prepend / append to each shifted time-series

        Returns
        -------
        sig_shifted : array_like
            Array of shifted signals with the same shape as the input
        """
        n_freqs, n_trials, n_pts = sig.shape
        # start from the padding value, then copy the part of each trial that
        # remains inside the time window
        sig_shifted = np.full_like(sig, fill_with)
        for tr in range(n_trials):
            st_shift = int(n_shifts[tr])
            if st_shift > 0:    # move forward = padding at the beginning
                sig_shifted[:, tr, st_shift:] = sig[:, tr, :-st_shift]
            elif st_shift < 0:  # move backward = padding at the end
                sig_shifted[:, tr, :st_shift] = sig[:, tr, -st_shift:]
            else:               # already aligned : keep the trial unchanged
                sig_shifted[:, tr, :] = sig[:, tr, :]

        return sig_shifted

    def plot(self, zscore=False, baseline=None, edges=0, **kwargs):
        """Integrated Peak-Locked TF plotting function.

        Parameters
        ----------
        zscore : bool | False
            Normalize the power by using a z-score normalization. This can be
            useful in order to compensate the 1 / f effect in the power
            spectrum. If True, the mean and deviation are computed at the
            single trial level and across all time points
        baseline : tuple | None
            Baseline period to use in order to apply the z-score correction.
            Should be in samples.
        edges : int | 0
            Number of pixels to discard to compensate filtering edge effect
            (`power[edges:-edges]`).
        kwargs : dict | {}
            Additional arguments are sent to the
            :class:`tensorpac.utils.PeakLockedTF.pacplot` method

        Returns
        -------
        axes : list
            List of two matplotlib axes : the time-frequency image and the
            phase plot
        """
        # manage additional arguments
        kwargs['colorbar'] = False
        kwargs['ylabel'] = 'Frequency for amplitude (hz)'
        kwargs['xlabel'] = ''
        kwargs['fz_labels'] = kwargs.get('fz_labels', 14)
        kwargs['fz_cblabel'] = kwargs.get('fz_cblabel', 14)
        kwargs['fz_title'] = kwargs.get('fz_title', 16)
        sl_times = slice(edges, len(self._times) - edges)
        times = self._times[sl_times]
        # keep a (n_trials, n_times) layout even with a single trial
        pha_n = np.atleast_2d(self.pha_a[..., sl_times].squeeze())
        # the cue is a sample index of the untrimmed time vector
        cue_time = self._times[self.cue]
        # z-score normalization
        if zscore:
            if baseline is None:
                bsl_idx = sl_times
            else:
                assert len(baseline) == 2
                bsl_idx = slice(baseline[0], baseline[1])
            _mean = self.amp_a[..., bsl_idx].mean(2, keepdims=True)
            _std = self.amp_a[..., bsl_idx].std(2, keepdims=True)
            _std[_std == 0.] = 1.  # correction from NaN
            amp_n = (self.amp_a[..., sl_times] - _mean) / _std
        else:
            amp_n = self.amp_a[..., sl_times]

        # grid definition
        gs = GridSpec(8, 8)
        # image plot
        plt.subplot(gs[slice(0, 6), 0:-1])
        self.pacplot(amp_n.mean(1), times, self.yvec, **kwargs)
        plt.axvline(cue_time, color='w', lw=2)
        plt.tick_params(bottom=False, labelbottom=False)
        ax_1 = plt.gca()
        # external colorbar
        plt.subplot(gs[slice(1, 5), -1])
        cb = plt.colorbar(self._plt_im, pad=0.01, cax=plt.gca())
        cb.set_label('Power (V**2/Hz)', fontsize=kwargs['fz_cblabel'])
        cb.outline.set_visible(False)
        # phase plot
        plt.subplot(gs[slice(6, 8), 0:-1])
        plt.plot(times, pha_n.T, color='lightgray', alpha=.2, lw=1.)
        # legend tweaking : a single labelled line stands for all the
        # unlabelled single-trial lines drawn above
        plt.plot(times, pha_n.mean(0), label='single trial phases', alpha=.2,
                 lw=1.)
        plt.plot(times, pha_n.mean(0), label='mean phases',
                 color='#1f77b4')
        plt.axvline(cue_time, color='k', lw=2)
        plt.autoscale(axis='both', tight=True, enable=True)
        plt.xlabel("Times", fontsize=kwargs['fz_labels'])
        plt.ylabel("V / Hz", fontsize=kwargs['fz_labels'])
        # bottom legend
        plt.legend(loc='center', bbox_to_anchor=(.5, -.5),
                   fontsize='x-large', ncol=2)
        ax_2 = plt.gca()

        return [ax_1, ax_2]