# Source code for museval.metrics

# -*- coding: utf-8 -*-
"""BSS Eval toolbox, version 4 (Based on mir_eval.separation)

Source separation algorithms attempt to extract recordings of individual
sources from a recording of a mixture of sources.  Evaluation methods for
source separation compare the extracted sources with reference sources and
attempt to measure the perceptual quality of the separation.

See also the bss_eval MATLAB toolbox:
http://bass-db.gforge.inria.fr/bss_eval/

Conventions
-----------

An audio signal is expected to be in the format of a 2-dimensional array where
the first dimension goes over the samples of the audio signal and the second
dimension goes over the channels (as in stereo left and right).
When providing a group of estimated or reference sources, they should be
provided in a 3-dimensional array, where the first dimension corresponds to the
source number, the second corresponds to the samples and the third to the
channels.

Metrics
-------

* :func:`museval.metrics.bss_eval`: Computes the bss_eval metrics: source
  to distortion (SDR), source to artifacts (SAR), source to interference (SIR)
  ratios, plus the image to spatial ratio (ISR). These are computed on a
  frame-by-frame basis (an infinite window size means the whole signal).

  Optionally, the distortion filters are time-varying, corresponding to
  behavior of BSS Eval version 3. Furthermore, metrics may optionally
  correspond to the bsseval_sources version, as defined in the BSS Eval
  version 2.

References
----------
  .. Antoine Liutkus, Fabian-Robert Stöter and Nobutaka
     Ito, "The 2018 Signal Separation Evaluation Campaign," In Proceedings of
     LVA/ICA 2018.
  .. Emmanuel Vincent, Rémi Gribonval, and Cédric
      Févotte, "Performance measurement in blind audio source separation," IEEE
      Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006.
  .. Cédric Févotte, Rémi Gribonval and Emmanuel
     Vincent, "BSS_EVAL toolbox user guide - Revision 2.0", Technical Report
     1706, IRISA, April 2005."""

import numpy as np
import scipy.fftpack
from scipy.linalg import toeplitz
from scipy.signal import fftconvolve
import itertools
import collections
import warnings

# The maximum allowable number of sources (prevents insane computational load).
# ``validate`` raises a ValueError when the first dimension of either input
# exceeds this value; assign a larger value to this name to lift the limit.
MAX_SOURCES = 100


def validate(reference_sources, estimated_sources):
    """Check that the input data to a metric are valid.

    Raises helpful errors for unusable input and warns about empty input.

    Parameters
    ----------
    reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
        matrix containing true sources
    estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
        matrix containing estimated sources

    Raises
    ------
    ValueError
        If the two shapes differ, if either array has more than 3
        dimensions, if any reference or estimated source is silent, or if
        either array holds more than ``MAX_SOURCES`` sources.

    Warns
    -----
    UserWarning
        If ``reference_sources`` or ``estimated_sources`` is empty.
    """
    if reference_sources.shape != estimated_sources.shape:
        raise ValueError(
            "The shape of estimated sources and the true "
            "sources should match. reference_sources.shape "
            "= {}, estimated_sources.shape "
            "= {}".format(reference_sources.shape, estimated_sources.shape)
        )

    # The check below accepts up to and including 3 dimensions, so the
    # message must say "at most 3".
    if reference_sources.ndim > 3 or estimated_sources.ndim > 3:
        raise ValueError(
            "The number of dimensions is too high (must be at most "
            "3). reference_sources.ndim = {}, "
            "estimated_sources.ndim "
            "= {}".format(reference_sources.ndim, estimated_sources.ndim)
        )

    if reference_sources.size == 0:
        warnings.warn(
            "reference_sources is empty, should be of size "
            "(nsrc, nsample, nchan). sdr, isr, sir, sar, and perm "
            "will all be empty np.ndarrays"
        )
    elif _any_source_silent(reference_sources):
        raise ValueError(
            "All the reference sources should be non-silent (not "
            "all-zeros), but at least one of the reference "
            "sources is all 0s, which introduces ambiguity to the"
            " evaluation. (Otherwise we can add infinitely many "
            "all-zero sources.)"
        )

    if estimated_sources.size == 0:
        warnings.warn(
            "estimated_sources is empty, should be of size "
            "(nsrc, nsample, nchan). sdr, isr, sir, sar, and perm "
            "will all be empty np.ndarrays"
        )
    elif _any_source_silent(estimated_sources):
        raise ValueError(
            "All the estimated sources should be non-silent (not "
            "all-zeros), but at least one of the estimated "
            "sources is all 0s. Since we require each reference "
            "source to be non-silent, having a silent estimated "
            "source will result in an underdetermined system."
        )

    # MAX_SOURCES is read at call time, so reassigning the module attribute
    # changes the limit for subsequent calls.
    if (
        estimated_sources.shape[0] > MAX_SOURCES
        or reference_sources.shape[0] > MAX_SOURCES
    ):
        raise ValueError(
            "The supplied matrices should be of shape (nsrc,"
            " nsampl, nchan) but reference_sources.shape[0] = {} "
            "and estimated_sources.shape[0] = {} which is greater "
            "than museval.metrics.MAX_SOURCES = {}. To "
            "override this check, set "
            "museval.metrics.MAX_SOURCES to a "
            "larger value.".format(
                reference_sources.shape[0], estimated_sources.shape[0], MAX_SOURCES
            )
        )
def _any_source_silent(sources):
    """Return True if any entry along the first dimension is all zeros.

    A source counts as silent only when every one of its values is exactly
    zero. Testing the raw values (rather than the sum across channels)
    avoids flagging a non-silent source whose channels happen to cancel,
    e.g. a stereo signal with ``left == -right``.

    Parameters
    ----------
    sources : np.ndarray
        Array whose first dimension indexes the sources.

    Returns
    -------
    bool-like
        Result of ``np.any`` over the per-source "all zero" flags.
    """
    non_source_axes = tuple(range(1, sources.ndim))
    return np.any(np.all(sources == 0, axis=non_source_axes))
def bss_eval(
    reference_sources,
    estimated_sources,
    window=2 * 44100,
    hop=1.5 * 44100,
    compute_permutation=False,
    filters_len=512,
    framewise_filters=False,
    bsseval_sources_version=False,
):
    """BSS_EVAL version 4.

    Measurement of the separation quality for estimated source signals in
    terms of source to distortion, interference and artifacts ratios,
    (SDR, SIR, SAR) as well as the image to spatial ratio (ISR), as defined
    in [#vincent2005bssevalv3]_.

    The metrics are computed on a framewise basis, with overlap allowed
    between the windows.

    The key difference between this version 4 and BSS Eval version 3 is the
    possibility of using the same distortion filters for all windows when
    matching the sources to their estimates, instead of estimating the
    filters anew at every frame, as done in BSS Eval v3.

    This implementation is fully compatible with BSS Eval v2 and v3 written
    in MATLAB.

    Examples
    --------
    >>> # reference_sources[n] should be a 2D ndarray, with first dimension
    >>> # the samples and second dimension the channels of the n'th
    >>> # reference source; estimated_sources[n] should be the same for the
    >>> # n'th estimated source
    >>> (sdr, isr, sir, sar, perm) = bss_eval(
    ...     reference_sources,
    ...     estimated_sources)

    Parameters
    ----------
    reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
        matrix containing true sources
    estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
        matrix containing estimated sources
    window : int or float, optional
        size of each window for time-varying evaluation. Picking np.inf or
        any integer greater than nsampl will compute metrics on the whole
        signal.
    hop : int or float, optional
        hop size between windows
    compute_permutation : bool, optional
        compute all permutations of estimate/source combinations to compute
        the best scores (False by default). Note that picking True will
        lead to a significant computation overhead.
    filters_len : int, optional
        maximum time lag for the computation of the distortion filters.
        Default is filters_len = 512.
    framewise_filters : bool, optional
        Compute a new distortion filter for each frame (False by default).
        Note that picking True as in BSS Eval v2 and v3 leads to a
        significant computation overhead.
    bsseval_sources_version : bool, optional
        if ``True``, results correspond to the `bss_eval_sources` version
        from the BSS Eval v2 and v3. Note however that this is not
        recommended because this evaluation method modifies the references
        according to the estimated sources, leading to potential problems
        for the estimation of SDR. For instance, zeroing some frequencies
        in the estimates will lead those to also be zeroed in the
        references, and hence not evaluated, artificially boosting results.
        For this reason, SiSEC always uses the `bss_eval_images` version,
        corresponding to ``False``.

    Returns
    -------
    sdr : np.ndarray, shape=(nsrc, nwin)
        matrix of Signal to Distortion Ratios (SDR). One for each source
        and window
    isr : np.ndarray, shape=(nsrc, nwin)
        matrix of source Image to Spatial distortion Ratios (ISR)
    sir : np.ndarray, shape=(nsrc, nwin)
        matrix of Source to Interference Ratios (SIR)
    sar : np.ndarray, shape=(nsrc, nwin)
        matrix of Sources to Artifacts Ratios (SAR)
    perm : np.ndarray
        the best ordering of estimated sources in the mean SIR sense
        (estimated source number ``perm[j,t]`` corresponds to true source
        number ``j`` at window ``t``). Its shape is ``(nsrc, nwin)`` when
        ``framewise_filters`` is ``True`` and ``(nsrc, 1)`` otherwise.
        Note: ``perm`` will be ``(0,1,...,nsrc-1)`` if
        ``compute_permutation`` is ``False``.

    Notes
    -----
    Frames in which any reference or estimated source is silent get
    ``np.nan`` for every metric. If either input is empty, five empty
    arrays are returned.

    References
    ----------
    .. [#liutkus2018bssevalv4] Antoine Liutkus, Fabian-Robert Stöter and
        Nobutaka Ito, "The 2018 Signal Separation Evaluation Campaign," In
        Proceedings of LVA/ICA 2018.
    .. [#vincent2005bssevalv3] Emmanuel Vincent, Rémi Gribonval, and Cédric
        Févotte, "Performance measurement in blind audio source
        separation," IEEE Trans. on Audio, Speech and Language Processing,
        2006.
    """
    # promote inputs of shape (nsampl,) or (nsrc, nsampl) to 3-D
    estimated_sources = np.atleast_3d(estimated_sources)
    reference_sources = np.atleast_3d(reference_sources)

    # validate input
    validate(reference_sources, estimated_sources)

    # If empty matrices were supplied, return empty arrays (special case)
    if reference_sources.size == 0 or estimated_sources.size == 0:
        return (np.array([]), np.array([]), np.array([]), np.array([]), np.array([]))

    # determine shape parameters
    (nsrc, nsampl, nchan) = estimated_sources.shape

    # one row per estimate/source assignment that has to be scored
    if compute_permutation:
        candidate_permutations = np.array(list(itertools.permutations(range(nsrc))))
    else:
        candidate_permutations = np.arange(nsrc)[None, :]

    # initialize variables
    framer = Framing(window, hop, nsampl)
    nwin = framer.nwin
    (SDR, ISR, SIR, SAR) = range(4)
    # s_r[metric, jtrue, jest, t]; only the (jtrue, jest) pairs that occur in
    # candidate_permutations are ever written or read
    s_r = np.empty((4, nsrc, nsrc, nwin))

    # helper functions computing the filters on a window of the signals
    def compute_GsfC(win=slice(0, nsampl)):
        # first compute the reference correlations
        G, sf = _compute_reference_correlations(reference_sources[:, win], filters_len)
        # then the interference distortion filters, one set per estimate
        C = np.zeros((nsrc, nsrc, nchan, filters_len, nchan))
        for jtrue in range(nsrc):
            C[jtrue] = _compute_projection_filters(G, sf, estimated_sources[jtrue, win])
        return (G, sf, C)

    def compute_Cj(G, sf, win=slice(0, nsampl)):
        Cj = np.zeros((nsrc, nsrc, 1, nchan, filters_len, nchan))
        for jtrue in range(nsrc):
            for jest in candidate_permutations[:, jtrue]:
                # projection filters for this (true source, estimate) pair
                Cj[jtrue, jest] = _compute_projection_filters(
                    G[jtrue, jtrue], sf[jtrue], estimated_sources[jest, win]
                )
        return Cj

    if not framewise_filters:
        # filters are estimated once, on the whole signals
        (G, sf, C) = compute_GsfC()
        Cj = compute_Cj(G, sf)

    # loop over all windows
    for t, win in enumerate(framer):
        ref_slice = reference_sources[:, win]
        est_slice = estimated_sources[:, win]

        # a silent frame gets np.nan for every metric; checking this first
        # avoids estimating framewise filters whose result would be discarded
        if _any_source_silent(ref_slice) or _any_source_silent(est_slice):
            s_r[:, :, :, t] = np.nan
            continue

        # time-varying distortion filters: re-estimate them on this window
        if framewise_filters:
            (G, sf, C) = compute_GsfC(win)
            Cj = compute_Cj(G, sf, win)

        # score every (true source, estimate) pair of the candidates once
        done = np.zeros((nsrc, nsrc), dtype=bool)
        for jtrue in range(nsrc):
            for jest in candidate_permutations[:, jtrue]:
                if done[jtrue, jest]:
                    continue
                s_true, e_spat, e_interf, e_artif = _bss_decomp_mtifilt(
                    ref_slice,
                    est_slice[jest],
                    jtrue,
                    C[jest],
                    Cj[jtrue, jest, 0],
                )
                s_r[:, jtrue, jest, t] = _bss_crit(
                    s_true, e_spat, e_interf, e_artif, bsseval_sources_version
                )
                done[jtrue, jest] = True

    # select the best ordering
    if framewise_filters:
        # with framewise filters, output one permutation for each window
        mean_sir = np.empty((len(candidate_permutations), nwin))
        axis_mean = 0
    else:
        # otherwise, output one permutation for the whole signal: the one
        # that is best on average
        mean_sir = np.empty((len(candidate_permutations), 1))
        axis_mean = None

    dum = np.arange(nsrc)
    for i, perm in enumerate(candidate_permutations):
        mean_sir[i] = np.mean(s_r[SIR, dum, perm, :], axis=axis_mean)
    popt = candidate_permutations[np.argmax(mean_sir, axis=0)].T

    # now prepare the output
    if not framewise_filters:
        result = s_r[:, dum, popt[:, 0], :]
    else:
        result = np.empty((4, nsrc, nwin))
        for t in range(nwin):
            result[:, :, t] = s_r[:, dum, popt[:, t], t]

    return (result[SDR], result[ISR], result[SIR], result[SAR], popt)
def bss_eval_sources(reference_sources, estimated_sources, compute_permutation=True):
    """BSS Eval v3 ``bss_eval_sources``.

    Wrapper around ``bss_eval`` that evaluates the whole signal as a single
    window (``window=np.inf``, ``hop=np.inf``) with ``filters_len=512``,
    ``framewise_filters=True`` and ``bsseval_sources_version=True``.

    Calling this function is not recommended. See the description of the
    ``bsseval_sources_version`` parameter of ``bss_eval``.

    Returns
    -------
    tuple
        ``(sdr, sir, sar, perm)`` taken from the result of ``bss_eval``;
        the ISR entry is dropped.
    """
    settings = dict(
        window=np.inf,
        hop=np.inf,
        compute_permutation=compute_permutation,
        filters_len=512,
        framewise_filters=True,
        bsseval_sources_version=True,
    )
    sdr, _isr, sir, sar, perm = bss_eval(
        reference_sources, estimated_sources, **settings
    )
    return (sdr, sir, sar, perm)
def bss_eval_sources_framewise(
    reference_sources,
    estimated_sources,
    window=30 * 44100,
    hop=15 * 44100,
    compute_permutation=False,
):
    """BSS Eval v3 ``bss_eval_sources_framewise``.

    Wrapper around ``bss_eval`` that forwards ``window``, ``hop`` and
    ``compute_permutation`` and fixes ``filters_len=512``,
    ``framewise_filters=True`` and ``bsseval_sources_version=True``.

    Calling this function is not recommended. See the description of the
    ``bsseval_sources_version`` parameter of ``bss_eval``.

    Returns
    -------
    tuple
        ``(sdr, sir, sar, perm)`` taken from the result of ``bss_eval``;
        the ISR entry is dropped.
    """
    settings = dict(
        window=window,
        hop=hop,
        compute_permutation=compute_permutation,
        filters_len=512,
        framewise_filters=True,
        bsseval_sources_version=True,
    )
    sdr, _isr, sir, sar, perm = bss_eval(
        reference_sources, estimated_sources, **settings
    )
    return (sdr, sir, sar, perm)
def bss_eval_images(reference_sources, estimated_sources, compute_permutation=True):
    """BSS Eval v3 ``bss_eval_images``.

    Wrapper around ``bss_eval`` that evaluates the whole signal as a single
    window (``window=np.inf``, ``hop=np.inf``) with ``filters_len=512``,
    ``framewise_filters=True`` and ``bsseval_sources_version=False``.

    Returns
    -------
    tuple
        The result of ``bss_eval``, returned unchanged.
    """
    settings = dict(
        window=np.inf,
        hop=np.inf,
        compute_permutation=compute_permutation,
        filters_len=512,
        framewise_filters=True,
        bsseval_sources_version=False,
    )
    return bss_eval(reference_sources, estimated_sources, **settings)
def bss_eval_images_framewise(
    reference_sources,
    estimated_sources,
    window=30 * 44100,
    hop=15 * 44100,
    compute_permutation=False,
):
    """BSS Eval v3 ``bss_eval_images_framewise``.

    Framewise computation of bss_eval_images. Wrapper around ``bss_eval``
    that forwards ``window``, ``hop`` and ``compute_permutation`` and fixes
    ``filters_len=512``, ``framewise_filters=True`` and
    ``bsseval_sources_version=False``.

    Returns
    -------
    tuple
        The result of ``bss_eval``, returned unchanged.
    """
    settings = dict(
        window=window,
        hop=hop,
        compute_permutation=compute_permutation,
        filters_len=512,
        framewise_filters=True,
        bsseval_sources_version=False,
    )
    return bss_eval(reference_sources, estimated_sources, **settings)
# Helper functions
class Framing:
    """Helper iterator class to do overlapped windowing.

    Iterating yields one ``slice(start, stop)`` per window, where window
    number ``k`` starts at ``floor(k * hop)`` and stops at
    ``floor(min(start + window, length))``. The iteration position is kept
    on the instance and is never reset, so an instance can be iterated once.

    Parameters
    ----------
    window : int or float
        Window size in samples; may be ``np.inf``.
    hop : int or float
        Hop size between window starts; may be ``np.inf``.
    length : int
        Total number of samples to cover.
    """

    def __init__(self, window, hop, length):
        self.current = 0  # index of the next window to yield
        self.window = window
        self.hop = hop
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.current >= self.nwin:
            raise StopIteration

        # 0 * inf is NaN and k * inf is inf: both fall back to the beginning
        start = self.current * self.hop
        if np.isnan(start) or np.isinf(start):
            start = 0

        # derive the end from the sanitized start so that an infinite hop
        # cannot turn a finite window into a NaN end
        stop = min(start + self.window, self.length)
        if np.isnan(stop) or np.isinf(stop):
            stop = self.length

        result = slice(int(np.floor(start)), int(np.floor(stop)))
        self.current += 1
        return result

    @property
    def nwin(self):
        """Number of windows produced by a full iteration."""
        # A window covering the whole signal gives exactly one frame, and so
        # does an infinite hop (a second frame would never start); the
        # formula below would evaluate inf / inf = NaN for an infinite hop.
        if not self.window < self.length or np.isinf(self.hop):
            return 1
        return int(np.floor((self.length - self.window + self.hop) / self.hop))

    # Python 2 style iterator protocol alias
    next = __next__
def _bss_decomp_mtifilt(reference_sources, estimated_source, j, C, Cj):
    """Decompose an estimated source image into four components.

    The components represent respectively the true source image, spatial
    (or filtering) distortion, interference and artifacts, derived from the
    true source images using multichannel time-invariant filters.

    Parameters
    ----------
    reference_sources : np.ndarray, nsrc X nsampl X nchan
    estimated_source : np.ndarray, nsampl X nchan
    j : int
        index of the reference source matched with ``estimated_source``
    C : np.ndarray
        filters projecting all the references (see ``_project``)
    Cj : np.ndarray
        filters projecting reference ``j`` alone (see ``_project``)

    Returns
    -------
    tuple
        ``(s_true, e_spat, e_interf, e_artif)``, each with
        ``nsampl + filters_len - 1`` rows.
    """
    filters_len = Cj.shape[-2]

    # zero pad so that the target is as long as the filtered projections
    s_true = _zeropad(reference_sources[j], filters_len - 1, axis=0)

    # compute appropriate projections
    e_spat = _project(reference_sources[j], Cj) - s_true
    e_interf = _project(reference_sources, C) - s_true - e_spat
    # artifacts are whatever remains of the estimate once the other three
    # components are removed; the estimate only covers the first rows
    e_artif = -s_true - e_spat - e_interf
    e_artif[: estimated_source.shape[0], :] += estimated_source

    return (s_true, e_spat, e_interf, e_artif)


def _zeropad(sig, N, axis=0):
    """Pad ``sig`` with ``N`` zeros at the end of the signal, along ``axis``.

    The result is a new float array allocated with ``np.zeros``.
    """
    # ensures concatenation dimension is the first
    sig = np.moveaxis(sig, axis, 0)
    # zero pad
    out = np.zeros((sig.shape[0] + N,) + sig.shape[1:])
    out[: sig.shape[0], ...] = sig
    # put back axis in place
    return np.moveaxis(out, 0, axis)


def _reshape_G(G):
    """Flatten a correlation array into a square matrix.

    From a correlation matrix of size
    nsrc X nsrc X nchan X nchan X filters_len X filters_len, creates a new
    one of size nsrc*nchan*filters_len X nsrc*nchan*filters_len.
    """
    # group the axes as (nsrc, nchan, filters_len, nsrc, nchan, filters_len)
    G = np.moveaxis(G, (1, 3), (3, 4))
    (nsrc, nchan, filters_len) = G.shape[0:3]
    size = nsrc * nchan * filters_len
    return np.reshape(G, (size, size))


def _compute_reference_correlations(reference_sources, filters_len):
    """Compute the inner products between delayed versions of the references.

    ``reference_sources`` is nsrc X nsampl X nchan.

    Returns
    -------
    G : np.ndarray
        nsrc X nsrc X nchan X nchan X filters_len X filters_len
    sf : np.ndarray
        FFT of the zero-padded references, nsrc X nchan X n_fft, where
        ``n_fft`` is the power of two that is at least
        ``nsampl + filters_len - 1``.
    """
    # reshape references as nsrc X nchan X nsampl
    (nsrc, nsampl, nchan) = reference_sources.shape
    reference_sources = np.moveaxis(reference_sources, 1, 2)

    # zero padding and FFT of references
    reference_sources = _zeropad(reference_sources, filters_len - 1, axis=2)
    n_fft = int(2 ** np.ceil(np.log2(nsampl + filters_len - 1.0)))
    sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=2)

    # compute intercorrelation between sources; each unordered pair of
    # (source, channel) indices is visited once and mirrored by transposition
    G = np.zeros((nsrc, nsrc, nchan, nchan, filters_len, filters_len))
    source_channel_pairs = itertools.product(range(nsrc), range(nchan))
    for (i, c1), (j, c2) in itertools.combinations_with_replacement(
        source_channel_pairs, 2
    ):
        ssf = sf[j, c2] * np.conj(sf[i, c1])
        ssf = np.real(scipy.fftpack.ifft(ssf))
        # first column: lags 0, -1, ..., -(filters_len - 1);
        # first row: lags 0, 1, ..., filters_len - 1
        ss = toeplitz(np.hstack((ssf[0], ssf[-1:-filters_len:-1])), r=ssf[:filters_len])
        G[j, i, c2, c1] = ss
        G[i, j, c1, c2] = ss.T
    return G, sf


def _compute_projection_filters(G, sf, estimated_source):
    """Least-squares projection of an estimated source on the references.

    Projects ``estimated_source`` (nsampl X nchan) on the subspace spanned
    by delayed versions of reference sources, with delays between 0 and
    filters_len-1.

    ``G`` and ``sf`` are as returned by ``_compute_reference_correlations``;
    the per-source slices ``G[j, j]`` (4-D) and ``sf[j]`` are accepted too.

    Returns
    -------
    np.ndarray
        nsrc X nchan X filters_len X nchan filters, or
        nchan X filters_len X nchan when there is a single reference.
    """
    # epsilon
    eps = np.finfo(float).eps

    # shapes
    (nsampl, nchan) = estimated_source.shape
    # handles the case where we are calling this with only one source:
    # G should be nsrc X nsrc X nchan X nchan X filters_len X filters_len
    # and sf should be nsrc X nchan X n_fft
    if len(G.shape) == 4:
        G = G[None, None, ...]
        sf = sf[None, ...]
    nsrc = G.shape[0]
    filters_len = G.shape[-1]

    # zero pad estimates and put chan in first dimension
    estimated_source = _zeropad(estimated_source.T, filters_len - 1, axis=1)

    # compute its FFT
    n_fft = int(2 ** np.ceil(np.log2(nsampl + filters_len - 1.0)))
    sef = scipy.fftpack.fft(estimated_source, n=n_fft)

    # compute the cross-correlations between sources and estimates
    D = np.zeros((nsrc, nchan, filters_len, nchan))
    for j, cj, c in itertools.product(range(nsrc), range(nchan), range(nchan)):
        ssef = sf[j, cj] * np.conj(sef[c])
        ssef = np.real(scipy.fftpack.ifft(ssef))
        D[j, cj, :, c] = np.hstack((ssef[0], ssef[-1:-filters_len:-1]))

    # reshape matrices to build the filters
    D = D.reshape(nsrc * nchan * filters_len, nchan)
    G = _reshape_G(G)

    # Distortion filters: regularized exact solve, falling back to least
    # squares when the system is singular
    filter_shape = (nsrc, nchan, filters_len, nchan)
    try:
        C = np.linalg.solve(G + eps * np.eye(G.shape[0]), D).reshape(filter_shape)
    except np.linalg.LinAlgError:
        # rcond=None selects NumPy's machine-precision based cutoff
        C = np.linalg.lstsq(G, D, rcond=None)[0].reshape(filter_shape)

    # if we asked for one single reference source,
    # return just a nchan X filters_len X nchan array
    if nsrc == 1:
        C = C[0]
    return C


def _project(reference_sources, C):
    """Project images using pre-computed filters ``C``.

    ``reference_sources`` are nsrc X nsampl X nchan and ``C`` is
    nsrc X nchan X filters_len X nchan. A single 2-D source
    (nsampl X nchan) with a 3-D ``C`` is accepted as well.

    Returns
    -------
    np.ndarray
        (nsampl + filters_len - 1) X nchan projection.
    """
    # shapes: ensure that input is 3d (comprising the source index)
    if len(reference_sources.shape) == 2:
        reference_sources = reference_sources[None, ...]
        C = C[None, ...]

    (nsrc, nsampl, nchan) = reference_sources.shape
    filters_len = C.shape[-2]
    out_len = nsampl + filters_len - 1

    # zero pad
    reference_sources = _zeropad(reference_sources, filters_len - 1, axis=1)
    sproj = np.zeros((nchan, out_len))

    # accumulate the contribution of every (source, input channel) pair to
    # every output channel
    for j, cj, c in itertools.product(range(nsrc), range(nchan), range(nchan)):
        sproj[c] += fftconvolve(C[j, cj, :, c], reference_sources[j, :, cj])[:out_len]
    return sproj.T


def _bss_crit(s_true, e_spat, e_interf, e_artif, bsseval_sources_version):
    """Measure the separation quality for a given source.

    Turns the four components of the decomposition into energy ratios.

    Returns
    -------
    tuple
        ``(sdr, isr, sir, sar)`` in dB. With ``bsseval_sources_version``
        the spatial distortion is counted as part of the target and
        ``isr`` is ``np.nan``.
    """
    # energy ratios
    if bsseval_sources_version:
        s_filt = s_true + e_spat
        energy_s_filt = np.sum(s_filt**2)
        sdr = _safe_db(energy_s_filt, np.sum((e_interf + e_artif) ** 2))
        # ISR is not defined for this variant. Do not derive the NaN from
        # ``sdr.shape``: _safe_db may return the plain float ``np.inf``,
        # which has no ``shape`` attribute.
        isr = np.nan
        sir = _safe_db(energy_s_filt, np.sum(e_interf**2))
        sar = _safe_db(np.sum((s_filt + e_interf) ** 2), np.sum(e_artif**2))
    else:
        energy_s_true = np.sum(s_true**2)
        sdr = _safe_db(energy_s_true, np.sum((e_spat + e_interf + e_artif) ** 2))
        isr = _safe_db(energy_s_true, np.sum(e_spat**2))
        sir = _safe_db(np.sum((s_true + e_spat) ** 2), np.sum(e_interf**2))
        sar = _safe_db(np.sum((s_true + e_spat + e_interf) ** 2), np.sum(e_artif**2))
    return (sdr, isr, sir, sar)


def _safe_db(num, den):
    """Return ``10 * log10(num / den)``, or ``np.inf`` when ``den`` is 0.

    Properly handles the potential +Inf dB ratio instead of raising a
    RuntimeWarning.
    """
    if den == 0:
        return np.inf
    return 10 * np.log10(num / den)