Skip to content

Commit 487e939

Browse files
authored
Break up datasets.py (fixie-ai#141)
* Break up datasets.py This splits out types.py and registry.py to move the list of pre-defined datasets to its own file and avoid circular refs. An __all__ import is used to minimize changes to surrounding code. * sr * cr * merge * restore typing
1 parent 041c4fe commit 487e939

20 files changed

+808
-754
lines changed

ultravox/data/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from ultravox.data.data_sample import *
2+
from ultravox.data.datasets import *
3+
from ultravox.data.registry import *
4+
from ultravox.data.types import *
5+
6+
__all__ = [
7+
"SizedIterableDataset",
8+
"EmptyDataset",
9+
"InterleaveDataset",
10+
"Range",
11+
"Dataproc",
12+
"VoiceDataset",
13+
"VoiceDatasetArgs",
14+
"VoiceSample",
15+
"create_dataset",
16+
"register_datasets",
17+
]

ultravox/data/data_sample.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import base64
2+
import dataclasses
3+
import io
4+
from typing import Any, Dict, List, Optional
5+
6+
import librosa
7+
import numpy as np
8+
import soundfile as sf
9+
from numpy import typing as npt
10+
11+
# Sample rate in Hz that the audio loaders resample to; also the default
# sample rate for the WAV encoders and for VoiceSample.
SAMPLE_RATE = 16_000
12+
13+
14+
def audio_from_file(path: str) -> np.ndarray:
    """Read the audio file at `path` as float32 samples at SAMPLE_RATE.

    Decoding and resampling are delegated to `librosa.load`; the rate it
    reports is discarded because it was requested explicitly.
    """
    samples, _rate = librosa.load(path, sr=SAMPLE_RATE)
    assert samples.dtype == np.float32
    return samples
19+
20+
21+
def audio_from_buf(buf: bytes) -> np.ndarray:
    """Decode an encoded audio buffer as float32 samples at SAMPLE_RATE.

    The bytes are wrapped in an in-memory stream and handed to
    `librosa.load`, which decodes and resamples them.
    """
    stream = io.BytesIO(buf)
    samples, _rate = librosa.load(stream, sr=SAMPLE_RATE)
    assert samples.dtype == np.float32
    return samples
26+
27+
28+
def audio_to_wav(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> bytes:
    """Encode float32 audio as the bytes of a 16-bit PCM WAV file.

    `sample_rate` is passed through to the WAV writer and defaults to
    SAMPLE_RATE. The input must already be float32.
    """
    assert audio.dtype == np.float32
    with io.BytesIO() as wav_buffer:
        sf.write(wav_buffer, audio, sample_rate, format="WAV", subtype="PCM_16")
        encoded = wav_buffer.getvalue()
    return encoded
34+
35+
36+
def audio_to_wav_base64(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str:
    """Encode audio as a WAV file and return it as a base64 text string."""
    wav_bytes = audio_to_wav(audio, sample_rate)
    return base64.b64encode(wav_bytes).decode("utf-8")
39+
40+
41+
def audio_to_data_uri(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str:
    """Return the audio as an `audio/wav` data URI with a base64 payload."""
    payload = audio_to_wav_base64(audio, sample_rate)
    return f"data:audio/wav;base64,{payload}"
44+
45+
46+
def messages_from_prompt(prompt: str) -> List[Dict[str, str]]:
    """Wrap `prompt` in a one-message conversation spoken by the "user" role."""
    user_message = {"role": "user", "content": prompt}
    return [user_message]
48+
49+
50+
@dataclasses.dataclass
class VoiceSample:
    """A list of chat messages with optional accompanying audio.

    Audio, when present, is normalized to a 1-D float32 array on construction
    (see `__post_init__`).
    """

    messages: List[Dict[str, str]]
    """List of messages, each with a "role" and "content" field."""
    audio: Optional[npt.NDArray[np.float32]] = None
    """Audio data as float32 PCM @ `sample_rate`."""
    sample_rate: int = SAMPLE_RATE
    """Audio sample rate in Hz."""
    audio_transcript: Optional[str] = None
    """For evaluations, the known transcript of the audio."""

    @staticmethod
    def from_json(data: Dict[str, Any]) -> "VoiceSample":
        """Convert from JSON format; audio is expected as base64ed WAV.

        The "audio" key is optional: `to_json` omits it for samples without
        audio, so a missing (or null) entry yields a text-only sample instead
        of raising KeyError. Decoded audio is resampled to SAMPLE_RATE by
        `audio_from_buf`, which matches the default `sample_rate`.
        """
        encoded_audio = data.get("audio")
        if encoded_audio is None:
            return VoiceSample(data["messages"])
        wav_bytes = base64.b64decode(encoded_audio)
        return VoiceSample(data["messages"], audio_from_buf(wav_bytes))

    @staticmethod
    def from_prompt(prompt: str) -> "VoiceSample":
        """Create a VoiceSample from a prompt only."""
        return VoiceSample(messages_from_prompt(prompt), None)

    @staticmethod
    def from_prompt_and_file(prompt: str, path: str) -> "VoiceSample":
        """Create a VoiceSample from a prompt and an audio file."""
        return VoiceSample(messages_from_prompt(prompt), audio_from_file(path))

    @staticmethod
    def from_prompt_and_buf(prompt: str, buf: bytes) -> "VoiceSample":
        """Create a VoiceSample from a prompt and an encoded audio buffer."""
        return VoiceSample(messages_from_prompt(prompt), audio_from_buf(buf))

    @staticmethod
    def from_prompt_and_raw(
        prompt: str, buf: np.ndarray, sample_rate: int
    ) -> "VoiceSample":
        """Create a VoiceSample from a prompt and raw audio data with sample rate."""
        # Keep in native sample rate; we'll resample later if needed.
        return VoiceSample(messages_from_prompt(prompt), buf, sample_rate)

    def to_json(self) -> Dict[str, Any]:
        """Convert to JSON format; audio is written as base64ed WAV.

        The "audio" key is only present when the sample carries audio.
        """
        obj: Dict[str, Any] = {"messages": self.messages}
        if self.audio is not None:
            obj["audio"] = audio_to_wav_base64(self.audio, self.sample_rate)
        return obj

    def __post_init__(self):
        """Ensure audio is 1-D float32 PCM.

        float64 input is cast; int16 and int32 input is cast and divided by
        the magnitude of the dtype's most negative value (32768 and
        2147483648 respectively). Any other non-float32 dtype, or audio that
        is not 1-D, fails an assertion.
        """
        if self.audio is None:
            return
        dtype = self.audio.dtype
        if dtype == np.float64:
            self.audio = self.audio.astype(np.float32)
        elif dtype == np.int16:
            self.audio = self.audio.astype(np.float32) / np.float32(32768.0)
        elif dtype == np.int32:
            self.audio = self.audio.astype(np.float32) / np.float32(2147483648.0)
        # Callers (and the tests) rely on AssertionError for unsupported input.
        assert (
            self.audio.dtype == np.float32
        ), f"Unexpected audio dtype: {self.audio.dtype}"
        assert self.audio.ndim == 1, f"Unexpected audio shape: {self.audio.shape}"

    def add_past_messages(self, past_messages: List[Dict[str, str]]):
        """Prepend `past_messages` to this sample's message list."""
        self.messages = past_messages + self.messages

ultravox/data/data_sample_test.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from typing import Union
2+
3+
import numpy as np
4+
import pytest
5+
6+
from ultravox.data import data_sample
7+
8+
9+
def _create_sine_wave(
10+
freq: int = 440,
11+
duration: float = 1.0,
12+
sample_rate: int = 16000,
13+
amplitude: float = 0.1,
14+
target_dtype: str = "float32",
15+
) -> Union[
16+
np.typing.NDArray[np.float32],
17+
np.typing.NDArray[np.float64],
18+
np.typing.NDArray[np.int16],
19+
np.typing.NDArray[np.int32],
20+
]:
21+
t = np.arange(sample_rate * duration, dtype=np.float32) / sample_rate
22+
wave = amplitude * np.sin(2 * np.pi * freq * t)
23+
match target_dtype:
24+
case "int16":
25+
wave = np.int16(wave * 32767)
26+
case "int32":
27+
wave = np.int32(wave * 2147483647)
28+
case "float32":
29+
# Already float32, nothing needed.
30+
pass
31+
case "float64":
32+
wave = wave.astype(np.float64)
33+
case _:
34+
raise ValueError(f"Unsupported dtype: {target_dtype}")
35+
return wave
36+
37+
38+
def _create_and_validate_sample(target_dtype: str = "float32"):
    """Build a VoiceSample from a sine wave of `target_dtype` and round-trip it.

    Checks that the sample ends up with float32 audio at 16 kHz, then that
    serializing with `to_json` and reloading with `from_json` preserves the
    sample rate, length, dtype, messages and (within tolerance) the audio.
    """
    # Default sine wave: 440 Hz, 1.0 second at 16 kHz, amplitude 0.1.
    waveform = _create_sine_wave(target_dtype=target_dtype)
    original = data_sample.VoiceSample.from_prompt_and_raw(
        "Transcribe\n<|audio|>", waveform, 16000
    )
    assert original.sample_rate == 16000
    assert original.audio is not None, "sample.audio should not be None"
    assert len(original.audio) == 16000
    assert original.audio.dtype == np.float32
    expected_messages = [
        {"role": "user", "content": "Transcribe\n<|audio|>"},
    ]
    assert original.messages == expected_messages

    # Serialize and deserialize the sample.
    payload = original.to_json()
    restored = data_sample.VoiceSample.from_json(payload)
    assert restored.sample_rate == original.sample_rate
    assert restored.audio is not None, "sample2.audio should not be None"
    assert len(restored.audio) == len(original.audio)
    assert restored.audio.dtype == original.audio.dtype
    assert restored.messages == original.messages
    assert np.allclose(restored.audio, original.audio, rtol=0.0001, atol=0.0001)
61+
62+
63+
def test_create_sample__int16():
    """int16 audio is accepted and passes the create/round-trip checks."""
    _create_and_validate_sample(target_dtype="int16")
65+
66+
67+
def test_create_sample__int32():
    """int32 audio is accepted and passes the create/round-trip checks."""
    _create_and_validate_sample(target_dtype="int32")
69+
70+
71+
def test_create_sample__float32():
    """float32 audio is accepted and passes the create/round-trip checks."""
    _create_and_validate_sample(target_dtype="float32")
73+
74+
75+
def test_create_sample__float64():
    """float64 audio is accepted and passes the create/round-trip checks."""
    _create_and_validate_sample(target_dtype="float64")
77+
78+
79+
def test_create_sample__raises_on_unsupported_dtype():
    """Constructing a sample from uint8 audio must fail with AssertionError."""
    # Contents are irrelevant; only the dtype matters for this check.
    unsupported = np.ndarray(shape=(16000,), dtype=np.uint8)
    with pytest.raises(AssertionError):
        data_sample.VoiceSample.from_prompt_and_raw(
            "Transcribe\n<|audio|>", unsupported, 16000
        )

0 commit comments

Comments
 (0)