from .audio_classes import MultiTrack, Source, Target
from os import path as op
from pathlib import Path
import stempeg
from urllib.request import urlopen, Request
import collections
import numpy as np
from tqdm import tqdm
import zipfile
import yaml
import errno
import musdb
import shutil
import os
import tempfile
# NOTE: a "[docs]" Sphinx source-link marker (documentation scrape residue) stood here.
class DB(object):
    """
    The musdb DB Object

    Parameters
    ----------
    root : str, optional
        musdb Root path. If set to `None` it will be read
        from the `MUSDB_PATH` environment variable, unless `download` is
        set, in which case `~/MUSDB18/MUSDB18-7` is used.
    setup_file : str, optional
        name of a yaml setup file, resolved relative to `root`.
        Defaults to `None`, which loads `configs/mus.yaml` from the
        installed `musdb` package.
    is_wav : boolean, optional
        expect subfolder with wav files for each source instead stems,
        defaults to `False`
    download : boolean, optional
        download sample version of MUSDB18 which includes 7s excerpts,
        defaults to `False`
    subsets : str or list[str], optional
        select a _musdb_ subset `train` or `test`.
        Default `None` loads `['train', 'test']`
    split : str, optional
        when `subsets=train`, `split` selects the train/validation split.
        `split='train'` loads the training split, `split='valid'` loads the
        validation split. `split=None` applies no splitting.
    sample_rate : float, optional
        sample rate handed to every track and source object.
        Defaults to `None`; how `None` is interpreted is decided by the
        audio classes, not by this class.

    Attributes
    ----------
    root : str
        musdb Root path. In combination with `download`, this path is the
        download destination.
    setup : Dict
        loaded yaml configuration
    url : str
        download url taken from the `sample-url` entry of the setup.
        Only set when `download=True`.
    sources_names : list[str]
        list of names of available sources
    targets_names : list[str]
        list of names of available targets
    is_wav : boolean
        whether tracks are read from wav folders instead of stem files
    sample_rate : Optional(Float)
        sample rate passed on to tracks and sources
    tracks : list
        the tracks returned by `load_mus_tracks` for the requested
        `subsets` and `split`

    Methods
    -------
    load_mus_tracks()
        Iterates through the musdb folder structure and
        returns ``Track`` objects
    """

    def __init__(
        self,
        root=None,
        setup_file=None,
        is_wav=False,
        download=False,
        subsets=None,
        split=None,
        sample_rate=None,
    ):
        # Resolve the dataset root: explicit argument first, then the sample
        # download location, then the environment variable.
        if root is not None:
            self.root = os.path.expanduser(root)
        elif download:
            self.root = os.path.expanduser("~/MUSDB18/MUSDB18-7")
        elif "MUSDB_PATH" in os.environ:
            self.root = os.environ["MUSDB_PATH"]
        else:
            raise RuntimeError("Variable `MUSDB_PATH` has not been set.")

        if setup_file is not None:
            setup_path = op.join(self.root, setup_file)
        else:
            setup_path = os.path.join(musdb.__path__[0], "configs", "mus.yaml")

        with open(setup_path, "r") as f:
            self.setup = yaml.safe_load(f)

        if download:
            self.url = self.setup["sample-url"]
            self.download()

        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found. "
                "You can use download=True to download a sample version of the dataset"
            )

        self.sample_rate = sample_rate
        self.sources_names = list(self.setup["sources"].keys())
        self.targets_names = list(self.setup["targets"].keys())
        self.is_wav = is_wav
        # `subsets=None` is expanded to both subsets by `load_mus_tracks`.
        self.tracks = self.load_mus_tracks(subsets=subsets, split=split)

    def __getitem__(self, index):
        """Return the loaded track(s) stored at `index`."""
        return self.tracks[index]

    def __len__(self):
        """Return the number of loaded tracks."""
        return len(self.tracks)

    def get_validation_track_indices(self, validation_track_names=None):
        """Returns validation track indices by a given list of track names

        Defaults to the `validation_tracks` list of the loaded setup.

        Parameters
        ----------
        validation_track_names : str or list[str], optional
            validation track names by a given `str` or list of tracknames

        Returns
        -------
        list[int]
            return a list of validation track indices

        Raises
        ------
        ValueError
            if one of the names does not match any loaded track
        """
        if validation_track_names is None:
            validation_track_names = self.setup["validation_tracks"]
        return self.get_track_indices_by_names(validation_track_names)

    def get_track_indices_by_names(self, names):
        """Returns musdb track indices by track name

        Can be used to filter the musdb tracks for
        a validation subset by trackname

        Parameters
        ----------
        names : str or list[str]
            select tracks by a given `str` or list of tracknames

        Returns
        -------
        list[int]
            indices into `self.tracks`, in the order of `names`. For a
            name that occurs more than once the first index is returned.

        Raises
        ------
        ValueError
            if one of the names does not match any loaded track
        """
        if isinstance(names, str):
            names = [names]
        # Build the name list once instead of once per requested name.
        track_names = [t.name for t in self.tracks]
        return [track_names.index(name) for name in names]

    def load_mus_tracks(self, subsets=None, split=None):
        """Parses the musdb folder structure, returns list of `Track` objects

        Only the top level of each subset folder is inspected: track
        folders when `is_wav` is set, `*.stem.mp4` files otherwise.

        Parameters
        ----------
        subsets : str or list[str], optional
            select a _musdb_ subset `train` or `test`.
            Default `None` loads [`train, test`].
        split : str
            for subsets='train', `split='train'` applies a train/validation
            split. If `split='valid'` the validation split of the training
            subset will be used

        Returns
        -------
        list[Track]
            return a list of ``Track`` Objects

        Raises
        ------
        RuntimeError
            if `split` is given while `subsets` is not exactly `train`
        """
        if subsets is None:
            subsets = ["train", "test"]
        elif isinstance(subsets, str):
            subsets = [subsets]
        else:
            # accept any iterable of subset names (tuple, generator, ...)
            subsets = list(subsets)

        if subsets != ["train"] and split is not None:
            raise RuntimeError("Subset has to set to `train` when split is used")

        tracks = []
        for subset in subsets:
            subset_folder = op.join(self.root, subset)
            # Tracks live directly below the subset folder. Taking only the
            # first `os.walk` entry avoids turning nested folders/files into
            # tracks whose paths would be joined onto the wrong directory.
            # A missing subset folder yields no entries at all.
            _, folders, files = next(os.walk(subset_folder), (None, [], []))

            if self.is_wav:
                # parse pcm tracks and sort by name
                for track_name in sorted(folders):
                    if self._skip_for_split(subset, split, track_name):
                        continue
                    tracks.append(
                        self._load_wav_track(subset, subset_folder, track_name)
                    )
            else:
                # parse stem files
                for file_name in sorted(files):
                    if not file_name.endswith(".stem.mp4"):
                        continue
                    track_name = file_name.split(".stem.mp4")[0]
                    if self._skip_for_split(subset, split, track_name):
                        continue
                    tracks.append(
                        self._load_stem_track(
                            subset, subset_folder, file_name, track_name
                        )
                    )
        return tracks

    def _skip_for_split(self, subset, split, track_name):
        """Return True when `track_name` is excluded by the requested split."""
        # The split only ever applies to the train subset.
        if subset != "train":
            return False
        is_validation = track_name in self.setup["validation_tracks"]
        if split == "train":
            return is_validation
        if split == "valid":
            return not is_validation
        return False

    def _load_wav_track(self, subset, subset_folder, track_name):
        """Build the track for a folder holding one wav file per source."""
        track_folder = op.join(subset_folder, track_name)
        track = MultiTrack(
            name=track_name,
            path=op.join(track_folder, self.setup["mixture"]),
            subset=subset,
            is_wav=self.is_wav,
            stem_id=self.setup["stem_ids"]["mixture"],
            sample_rate=self.sample_rate,
        )
        # only sources whose wav file is present are attached
        sources = {}
        for src, source_file in self.setup["sources"].items():
            abs_path = op.join(track_folder, source_file)
            if os.path.exists(abs_path):
                sources[src] = Source(
                    track,
                    name=src,
                    path=abs_path,
                    stem_id=self.setup["stem_ids"][src],
                    sample_rate=self.sample_rate,
                )
        track.sources = sources
        track.targets = self.create_targets(track)
        return track

    def _load_stem_track(self, subset, subset_folder, file_name, track_name):
        """Build the track for a single `*.stem.mp4` file."""
        stem_path = op.join(subset_folder, file_name)
        track = MultiTrack(
            name=track_name,
            path=stem_path,
            subset=subset,
            stem_id=self.setup["stem_ids"]["mixture"],
            is_wav=self.is_wav,
            sample_rate=self.sample_rate,
        )
        # every source points at the same stem file and is told apart
        # by its `stem_id`
        sources = {}
        for src in self.setup["sources"]:
            if os.path.exists(stem_path):
                sources[src] = Source(
                    track,
                    name=src,
                    path=stem_path,
                    stem_id=self.setup["stem_ids"][src],
                    sample_rate=self.sample_rate,
                )
        track.sources = sources
        track.targets = self.create_targets(track)
        return track

    def create_targets(self, track):
        """Builds the targets of `track` from the `targets` setup entry

        As a side effect the `gain` of every source that takes part in a
        target is set to the gain configured for that target.

        Parameters
        ----------
        track : Track
            track whose `sources` have already been assigned

        Returns
        -------
        collections.OrderedDict
            maps target name to ``Target``. Targets for which none of the
            configured sources is available are left out.
        """
        targets = collections.OrderedDict()
        for name, target_srcs in self.setup["targets"].items():
            # collect the available sources this target is made of
            target_sources = []
            for source, gain in target_srcs.items():
                if source in track.sources:
                    # add gain to source tracks
                    track.sources[source].gain = float(gain)
                    target_sources.append(track.sources[source])
            if target_sources:
                targets[name] = Target(track, sources=target_sources, name=name)
        return targets

    def save_estimates(self, user_estimates, track, estimates_dir, write_stems=False):
        """Writes `user_estimates` to disk while recreating the musdb file structure in that folder.

        Parameters
        ----------
        user_estimates : Dict[np.array]
            the target estimates.
        track : Track,
            musdb track object
        estimates_dir : str,
            output folder name where to save the estimates.
        write_stems : boolean, optional
            writing stems is not implemented: when set, the output folder
            is created but no audio is written. Defaults to `False`
        """
        track_estimate_dir = op.join(estimates_dir, track.subset, track.name)
        # exist_ok avoids the check-then-create race of a separate exists()
        os.makedirs(track_estimate_dir, exist_ok=True)

        if write_stems:
            # to be implemented
            return

        # write out tracks to disk, one wav file per target
        for target, estimate in user_estimates.items():
            target_path = op.join(track_estimate_dir, target + ".wav")
            stempeg.write_audio(
                path=target_path, data=estimate, sample_rate=track.rate
            )

    def _check_exists(self):
        """Return True when a `train` folder exists below the root path."""
        return os.path.exists(os.path.join(self.root, "train"))

    def download(self, progress: bool = True, suffix: str = ".zip"):
        """Download the MUSDB Sample data

        Does nothing when the dataset is already present. Otherwise the
        archive at `self.url` is fetched into a temporary file and
        extracted into the root path.

        Parameters
        ----------
        progress : boolean, optional
            show a progress bar while downloading, defaults to `True`
        suffix : str, optional
            suffix of the temporary download file, defaults to `.zip`
        """
        if self._check_exists():
            return

        os.makedirs(self.root, exist_ok=True)

        print("Downloading MUSDB 7s Sample Dataset to %s..." % self.root)
        req = Request(self.url, headers={"User-Agent": "musdb_downloader"})
        # the context manager closes the connection even if the download
        # or the extraction fails
        with urlopen(req) as response:
            content_length = response.info().get_all("Content-Length")
            # an unknown size leaves the progress bar without a total
            file_size = int(content_length[0]) if content_length else None

            # We deliberately save it in a temp file and extract it after
            # download is complete. This prevents a local working copy
            # being overridden by a broken download.
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
            try:
                with tqdm(
                    total=file_size,
                    disable=not progress,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as pbar:
                    while True:
                        buffer = response.read(8192)
                        if not buffer:
                            break
                        tmp.write(buffer)
                        pbar.update(len(buffer))
                # flush to disk before the archive is reopened by name
                tmp.close()
                with zipfile.ZipFile(tmp.name, "r") as zip_ref:
                    zip_ref.extractall(self.root)
            finally:
                tmp.close()
                if os.path.exists(tmp.name):
                    os.remove(tmp.name)