```
├── .github/
│   └── workflows/
│       └── check.yml (200 tokens)
├── .gitignore
├── .python-version
├── LICENSE (omitted)
├── README.md (2.1k tokens)
├── assets/
│   ├── asr.wav
│   └── question.wav
├── pyproject.toml (500 tokens)
├── src/
│   └── liquid_audio/
│       ├── __init__.py
│       ├── demo/
│       │   ├── __init__.py
│       │   ├── chat.py (700 tokens)
│       │   └── model.py (100 tokens)
│       ├── model/
│       │   ├── __init__.py
│       │   ├── conformer/
│       │   │   ├── __init__.py
│       │   │   ├── encoder.py (11.2k tokens)
│       │   │   ├── mha.py (3.9k tokens)
│       │   │   ├── modules.py (3.7k tokens)
│       │   │   ├── processor.py (4.9k tokens)
│       │   │   ├── subsampling.py (4.9k tokens)
│       │   │   └── utils.py (1000 tokens)
│       │   ├── lfm2_audio.py (3k tokens)
│       │   ├── mlp.py (200 tokens)
│       │   └── transformer.py (4.2k tokens)
│       ├── moshi/
│       │   ├── __init__.py (100 tokens)
│       │   ├── client.py (1400 tokens)
│       │   ├── client_gradio.py (1100 tokens)
│       │   ├── client_utils.py (1300 tokens)
│       │   ├── conditioners/
│       │   │   ├── __init__.py (100 tokens)
│       │   │   ├── base.py (3.5k tokens)
│       │   │   ├── tensors.py (100 tokens)
│       │   │   └── text.py (1000 tokens)
│       │   ├── models/
│       │   │   ├── __init__.py (100 tokens)
│       │   │   ├── compression.py (3.5k tokens)
│       │   │   ├── lm.py (7.4k tokens)
│       │   │   ├── lm_utils.py (1000 tokens)
│       │   │   ├── loaders.py (3.4k tokens)
│       │   │   └── tts.py (6.3k tokens)
│       │   ├── modules/
│       │   │   ├── __init__.py (100 tokens)
│       │   │   ├── conv.py (2.8k tokens)
│       │   │   ├── conv_test.py (1200 tokens)
│       │   │   ├── gating.py (700 tokens)
│       │   │   ├── lora.py (900 tokens)
│       │   │   ├── resample.py (700 tokens)
│       │   │   ├── rope.py (500 tokens)
│       │   │   ├── seanet.py (3.1k tokens)
│       │   │   ├── seanet_test.py (1200 tokens)
│       │   │   ├── streaming.py (1700 tokens)
│       │   │   └── transformer.py (7.4k tokens)
│       │   ├── py.typed
│       │   ├── quantization/
│       │   │   ├── __init__.py (100 tokens)
│       │   │   ├── base.py (1100 tokens)
│       │   │   ├── core_vq.py (4.3k tokens)
│       │   │   └── vq.py (2.7k tokens)
│       │   ├── run_inference.py (2.4k tokens)
│       │   ├── run_tts.py (2.1k tokens)
│       │   ├── server.py (2.5k tokens)
│       │   └── utils/
│       │       ├── __init__.py (100 tokens)
│       │       ├── autocast.py (300 tokens)
│       │       ├── compile.py (2.2k tokens)
│       │       ├── quantize.py (500 tokens)
│       │       ├── sampling.py (900 tokens)
│       │       └── utils.py (500 tokens)
│       ├── processor.py (1300 tokens)
│       ├── py.typed
│       └── utils.py (200 tokens)
└── uv.lock (omitted)
```
## /.github/workflows/check.yml
```yml path="/.github/workflows/check.yml"
name: Check

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
    types:
      - opened
      - reopened
      - synchronize
      - ready_for_review
      - review_requested
  workflow_dispatch:

permissions:
  contents: read
  pull-requests: read

concurrency:
  group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
  cancel-in-progress: true

jobs:
  check:
    if: >
      github.event_name != 'pull_request' ||
      !github.event.pull_request.draft ||
      github.event.action == 'review_requested'
    runs-on: ubuntu-latest
    timeout-minutes: 10
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          enable-cache: true
          cache-dependency-glob: "uv.lock"
      - name: Install dependencies
        run: uv sync --dev --frozen
      - name: Run ruff check
        run: uv run ruff check
      - name: Run ruff format check
        run: uv run ruff format --check
```
## /.gitignore
```gitignore path="/.gitignore"
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# MacOS
.DS_Store
```
## /.python-version
```python-version path="/.python-version"
3.12
```
## /README.md
# Liquid Audio - Speech-to-Speech models
We present LFM2-Audio-1.5B, [Liquid AI](https://www.liquid.ai/)'s first end-to-end audio foundation model. Built with low latency in mind, the lightweight [LFM2](https://huggingface.co/LiquidAI/LFM2-1.2B) backbone enables real-time speech-to-speech conversations without sacrificing quality.

LFM2-Audio supports two generation modes, interleaved and sequential, to maximize performance and quality across different tasks. Interleaved generation outputs text and audio tokens in a fixed alternating pattern, which minimizes both the time to first audio output and the number of tokens generated, making it ideal for naturally flowing, real-time speech-to-speech interactions on resource-constrained devices. In sequential generation mode, the model decides when to switch modalities via special tokens; this mode suits non-conversational tasks such as speech-to-text (ASR) and text-to-speech (TTS).
## Installation
The package can be installed via `pip`:

```bash
pip install liquid-audio
pip install "liquid-audio[demo]"  # optional, to install the demo dependencies
pip install flash-attn --no-build-isolation  # optional, to use Flash Attention 2; falls back to torch SDPA if not installed
```
## Usage
Generation is handled by two generation modes, interleaved and sequential, accessible through the methods `LFM2AudioModel.generate_interleaved` and `LFM2AudioModel.generate_sequential` respectively. Both are generators that yield `torch.Tensor`s: text tokens are tensors with a single entry, and audio tokens are tensors with 8 entries, corresponding to the 8 [Mimi](https://huggingface.co/docs/transformers/en/model_doc/mimi) codebooks.

The `LFM2AudioModel` class operates on tokens only. The `LFM2AudioProcessor` class is used to convert between tokens and data: for text, this means the conversion from string to tokens and back; for audio inputs, it handles the conversion of waveforms to log-mel features; and for audio outputs, it handles the detokenization of audio tokens back to a waveform.

To facilitate the creation of inputs for the generation methods and to apply the correct chat templates, use the `ChatState` helper class. See the examples below for usage instructions.
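The token shape contract above can be sketched without loading the model. The sketch below is purely illustrative: plain Python lists stand in for the yielded tensors, and `split_by_modality` is a hypothetical helper, not part of the package.

```python
# Illustrative only: lists stand in for the tensors the generators yield.
# A 1-element token is text; an 8-element token is one audio frame
# (one code per Mimi codebook).
def split_by_modality(tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
    text, audio = [], []
    for t in tokens:
        if len(t) == 1:
            text.append(t[0])
        elif len(t) == 8:
            audio.append(t)
        else:
            raise ValueError(f"unexpected token size: {len(t)}")
    return text, audio


text, audio = split_by_modality([[17], [1, 2, 3, 4, 5, 6, 7, 8], [42]])
print(text)        # [17, 42]
print(len(audio))  # 1
```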
### Gradio demo
To use the demo interface, make sure to install the extra dependencies from the `[demo]` group, e.g.

```bash
pip install "liquid-audio[demo]"
```

To launch the demo, run `liquid-audio-demo` in a terminal. The demo interface will then be available at http://localhost:7860.
### Multi-turn, multi-modal chat
For multi-turn chat with text and audio output, we use interleaved generation. The system prompt should be set to `Respond with interleaved text and audio.`. Here we use audio for the first user turn and text for the second one.
<details>
<summary>Conversation transcript</summary>
**User**
https://github.com/user-attachments/assets/e2ffb8c3-c84c-4460-9cb8-f95a13b6eec6
**Assistant**
Sure! How about "Handcrafted Woodworking, Precision Made for You"? Another option could be "Quality Woodworking, Quality Results." If you want something more personal, you might try "Your Woodworking Needs, Our Expertise."
https://github.com/user-attachments/assets/019664b5-3480-4801-b05a-bd62ddcb8d3e
**User**
My business specialized in chairs, can you give me something related to that?
**Assistant**
Sure thing! How about “Comfortable Chairs, Crafted with Care” or “Elegant Seats, Handcrafted for You”? Let me know if you’d like a few more options.
https://github.com/user-attachments/assets/d0d054b2-6d1d-49fb-94df-4aa0b6641990
</details>
```python
import torch
import torchaudio

from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState, LFMModality

# Load models
HF_REPO = "LiquidAI/LFM2-Audio-1.5B"

processor = LFM2AudioProcessor.from_pretrained(HF_REPO).eval()
model = LFM2AudioModel.from_pretrained(HF_REPO).eval()

# Set up inputs for the model
chat = ChatState(processor)

chat.new_turn("system")
chat.add_text("Respond with interleaved text and audio.")
chat.end_turn()

chat.new_turn("user")
wav, sampling_rate = torchaudio.load("assets/question.wav")
chat.add_audio(wav, sampling_rate)
chat.end_turn()

chat.new_turn("assistant")

# Generate text and audio tokens
text_out: list[torch.Tensor] = []
audio_out: list[torch.Tensor] = []
modality_out: list[LFMModality] = []

for t in model.generate_interleaved(**chat, max_new_tokens=512, audio_temperature=1.0, audio_top_k=4):
    if t.numel() == 1:  # text token
        print(processor.text.decode(t), end="", flush=True)
        text_out.append(t)
        modality_out.append(LFMModality.TEXT)
    else:  # audio token (8 Mimi codebooks)
        audio_out.append(t)
        modality_out.append(LFMModality.AUDIO_OUT)
# output: Sure! How about "Handcrafted Woodworking, Precision Made for You"? Another option could be "Quality Woodworking, Quality Results." If you want something more personal, you might try "Your Woodworking Needs, Our Expertise."

# Detokenize audio, removing the last "end-of-audio" codes
# Mimi returns audio at 24kHz
mimi_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
with torch.no_grad():
    waveform = processor.mimi.decode(mimi_codes)[0]
torchaudio.save("answer1.wav", waveform.cpu(), 24_000)

# Append newly generated tokens to the chat history
chat.append(
    text=torch.stack(text_out, 1),
    audio_out=torch.stack(audio_out, 1),
    modality_flag=torch.tensor(modality_out),
)
chat.end_turn()

# Start a new turn
chat.new_turn("user")
chat.add_text("My business specialized in chairs, can you give me something related to that?")
chat.end_turn()

chat.new_turn("assistant")

# Generate second-turn text and audio tokens
audio_out = []
for t in model.generate_interleaved(**chat, max_new_tokens=512, audio_temperature=1.0, audio_top_k=4):
    if t.numel() == 1:
        print(processor.text.decode(t), end="", flush=True)
    else:
        audio_out.append(t)
# output: Sure thing! How about “Comfortable Chairs, Crafted with Care” or “Elegant Seats, Handcrafted for You”? Let me know if you’d like a few more options.

# Detokenize second-turn audio, removing the last "end-of-audio" codes
mimi_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
with torch.no_grad():
    waveform = processor.mimi.decode(mimi_codes)[0]
torchaudio.save("answer2.wav", waveform.cpu(), 24_000)
```
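The "Mimi returns audio at 24kHz" comment above fixes the buffer arithmetic. Assuming the 12.5 Hz codebook frame rate documented for Mimi (an assumption here, not stated in this README), each 8-codebook audio token maps to a fixed slice of waveform:

```python
# Mimi decodes to 24 kHz audio; its codebooks step at 12.5 frames per second
# (assumed from the Mimi model documentation).
SAMPLE_RATE = 24_000  # Hz, Mimi output rate
FRAME_RATE = 12.5     # audio tokens per second

samples_per_token = int(SAMPLE_RATE / FRAME_RATE)
ms_per_token = 1_000 / FRAME_RATE

print(samples_per_token)  # 1920 waveform samples per audio token
print(ms_per_token)       # 80.0 ms of audio per token
```

Under these assumptions, `max_new_tokens=512` spent entirely on audio would bound a response at roughly 41 seconds of speech.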
### ASR
For ASR, we use sequential generation, with the fixed system prompt `Perform ASR.`. The output is capitalized and punctuated.
<details>
<summary>Input audio snippet</summary>
https://github.com/user-attachments/assets/b3cc017f-363d-49f3-8e7d-f6db9556900e
**Model output**: The stale smell of old beer lingers. It takes heat to bring out the odor. A cold dip restores health and zest. A salt pickle tastes fine with ham. Tacos al pastor are my favorite. A zestful food is the hot cross bun.
</details>
```python
import torch
import torchaudio

from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState, LFMModality

# Load models
HF_REPO = "LiquidAI/LFM2-Audio-1.5B"

processor = LFM2AudioProcessor.from_pretrained(HF_REPO).eval()
model = LFM2AudioModel.from_pretrained(HF_REPO).eval()

# Set up inputs for the model
chat = ChatState(processor)

chat.new_turn("system")
chat.add_text("Perform ASR.")
chat.end_turn()

chat.new_turn("user")
wav, sampling_rate = torchaudio.load("assets/asr.wav")
chat.add_audio(wav, sampling_rate)
chat.end_turn()

chat.new_turn("assistant")

# Generate text
for t in model.generate_sequential(**chat, max_new_tokens=512):
    if t.numel() == 1:
        print(processor.text.decode(t), end="", flush=True)
# Output: The stale smell of old beer lingers. It takes heat to bring out the odor. A cold dip restores health and zest. A salt pickle tastes fine with ham. Tacos al pastor are my favorite. A zestful food is the hot cross bun.
```
### TTS
For TTS, we also use sequential generation, with the fixed system prompt `Perform TTS.`. In addition, the voice and style can be prompted with a natural-language description.
<details>
<summary>TTS Sample</summary>
**Voice description**: A male speaker delivers his lines with a low-pitched voice and an animated tone. The recording is of excellent quality with almost no noise and a very close-sounding atmosphere.
**Input sentence**: What is this obsession people have with books? They put them in their houses—like they're trophies. What do you need it for after you read it?
**Output audio**
https://github.com/user-attachments/assets/2fa953cf-d8a8-477a-b841-c4f18d9266e6
</details>
```python
import torch
import torchaudio

from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState, LFMModality

# Load models
HF_REPO = "LiquidAI/LFM2-Audio-1.5B"

processor = LFM2AudioProcessor.from_pretrained(HF_REPO).eval()
model = LFM2AudioModel.from_pretrained(HF_REPO).eval()

# Set up inputs for the model
chat = ChatState(processor)

chat.new_turn("system")
chat.add_text("Perform TTS.\nUse the following voice: A male speaker delivers his lines with a low-pitched voice and an animated tone. The recording is of excellent quality with almost no noise and a very close-sounding atmosphere.")
chat.end_turn()

chat.new_turn("user")
chat.add_text("What is this obsession people have with books? They put them in their houses—like they're trophies. What do you need it for after you read it?")
chat.end_turn()

chat.new_turn("assistant")

# Generate audio tokens
audio_out: list[torch.Tensor] = []
for t in model.generate_sequential(**chat, max_new_tokens=512, audio_temperature=0.8, audio_top_k=64):
    if t.numel() > 1:
        audio_out.append(t)

# Detokenize audio, removing the last "end-of-audio" codes
mimi_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
with torch.no_grad():
    waveform = processor.mimi.decode(mimi_codes)[0]
torchaudio.save("tts.wav", waveform.cpu(), 24_000)
```
## License
The code in this repository and associated weights are licensed under the [LFM Open License v1.0](LICENSE).
The code for the audio encoder is based on [NVIDIA NeMo](https://github.com/NVIDIA-NeMo/NeMo/tree/main), licensed under [Apache 2.0](https://github.com/NVIDIA-NeMo/NeMo/blob/294ddff187f68c055d87ffe9400e65975b38693d/LICENSE), and the [canary-180m-flash](https://huggingface.co/nvidia/canary-180m-flash) checkpoint, licensed under [CC-BY 4.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/cc-by-4.0.md). To simplify dependency resolution, we also ship the Python code of [Kyutai Mimi](https://github.com/kyutai-labs/moshi), licensed under the [MIT License](https://github.com/kyutai-labs/moshi/blob/aee53fc0fc0119e4d7343e5ea4dd6ddafd7f09c4/LICENSE-MIT).
## /assets/asr.wav
Binary file available at https://raw.githubusercontent.com/Liquid4All/liquid-audio/refs/heads/main/assets/asr.wav
## /assets/question.wav
Binary file available at https://raw.githubusercontent.com/Liquid4All/liquid-audio/refs/heads/main/assets/question.wav
## /pyproject.toml
```toml path="/pyproject.toml"
[project]
name = "liquid-audio"
version = "1.0.0"
description = "Liquid Audio - Speech-to-Speech audio models"
readme = "README.md"
authors = [
    { name = "Liquid AI, Inc", email = "support@liquid.ai" }
]
license = "LicenseRef-LFM-Open-License-v1.0"
license-files = ["LICENSE"]
requires-python = ">=3.12"
dependencies = [
    "accelerate>=1.10.1",
    "einops>=0.8.1",
    "librosa>=0.11.0",
    "sentencepiece>=0.2.1",
    "torch>=2.8.0",
    "torchaudio>=2.8.0",
    "transformers>=4.55.4",
]
keywords = ["Liquid AI", "LFM", "LFM2", "Audio", "Speech-to-Speech"]

[project.urls]
Homepage = "https://www.liquid.ai/"
Repository = "https://github.com/Liquid4All/liquid-audio/"
Issues = "https://github.com/Liquid4All/liquid-audio/issues"

[project.scripts]
liquid-audio-demo = "liquid_audio.demo.chat:main [demo]"

[project.optional-dependencies]
demo = [
    "fastrtc[vad]>=0.0.30",
]

[build-system]
requires = ["uv_build>=0.8.13,<0.9.0"]
build-backend = "uv_build"

[dependency-groups]
dev = [
    "ipython>=9.4.0",
    "mypy>=1.17.1",
    "ruff>=0.12.10",
]

[tool.ruff]
line-length = 127  # The GitHub editor is 127 chars wide
extend-exclude = [
    "src/liquid_audio/model/conformer",
    "src/liquid_audio/moshi",
]

[tool.ruff.lint]
select = [
    # pycodestyle
    "E",
    # Pyflakes
    "F",
    # pyupgrade
    "UP",
    # flake8-bugbear
    "B",
    # flake8-simplify
    "SIM",
    # isort
    "I",
    # tryceratops
    "TRY",
    # perflint
    "PERF",
    # refurb
    "FURB",
    # ruff
    "RUF",
]
ignore = [
    # These conflict with ruff format
    "W191",
    "E111",
    "E114",
    "E117",
    "E501",
    "D206",
    "D300",
    "Q000",
    "Q001",
    "Q002",
    "Q003",
    "COM812",
    "COM819",
    "ISC001",
    "ISC002",
    # Probably more readable without ternary operators
    "SIM108",
    # "Ambiguous variable name", probably too strict for kernel variables
    "E741",
    # Allow error messages outside of the exception class
    "TRY003",
    # Allow ABCs without abstract methods
    "B024",
]

[tool.mypy]
mypy_path = "src"
packages = [
    "liquid_audio",
]
exclude = [
    "liquid_audio.model.conformer",
]
warn_unused_configs = true
warn_redundant_casts = true
warn_unused_ignores = true
strict_equality = true
extra_checks = true
check_untyped_defs = true

# TODO: check if needed
[[tool.mypy.overrides]]
module = [
    "accelerate.*",
    "torchaudio.*",
]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
    'liquid_audio.moshi.*'
]
ignore_errors = true
```
## /src/liquid_audio/__init__.py
```py path="/src/liquid_audio/__init__.py"
from liquid_audio.model.lfm2_audio import LFM2AudioModel
from liquid_audio.processor import ChatState, LFM2AudioProcessor
from liquid_audio.utils import LFMModality

__all__ = ["ChatState", "LFM2AudioModel", "LFM2AudioProcessor", "LFMModality"]
```
## /src/liquid_audio/demo/__init__.py
```py path="/src/liquid_audio/demo/__init__.py"
```
## /src/liquid_audio/demo/chat.py
```py path="/src/liquid_audio/demo/chat.py"
from queue import Queue
from threading import Thread

import gradio as gr
import numpy as np
import torch
from fastrtc import AdditionalOutputs, ReplyOnPause, WebRTC

from liquid_audio import ChatState, LFMModality

from .model import lfm2_audio, mimi, proc


def chat_producer(
    q: Queue[torch.Tensor | None],
    chat: ChatState,
    temp: float | None,
    topk: int | None,
):
    print(f"Starting generation with state {chat}.")
    with torch.no_grad(), mimi.streaming(1):
        for t in lfm2_audio.generate_interleaved(
            **chat,
            max_new_tokens=1024,
            audio_temperature=temp,
            audio_top_k=topk,
        ):
            q.put(t)
            if t.numel() > 1:
                if (t == 2048).any():
                    continue
                wav_chunk = mimi.decode(t[None, :, None])[0]
                q.put(wav_chunk)
    q.put(None)


def chat_response(audio: tuple[int, np.ndarray], _id: str, chat: ChatState, temp: float | None = 1.0, topk: int | None = 4):
    if temp == 0:
        temp = None
    if topk == 0:
        topk = None
    if temp is not None:
        temp = float(temp)
    if topk is not None:
        topk = int(topk)

    if len(chat.text) == 1:
        chat.new_turn("system")
        chat.add_text("Respond with interleaved text and audio.")
        chat.end_turn()

    chat.new_turn("user")
    rate, wav = audio
    chat.add_audio(torch.tensor(wav / 32_768, dtype=torch.float), rate)
    chat.end_turn()
    chat.new_turn("assistant")

    q: Queue[torch.Tensor | None] = Queue()
    chat_thread = Thread(target=chat_producer, args=(q, chat, temp, topk))
    chat_thread.start()

    out_text: list[torch.Tensor] = []
    out_audio: list[torch.Tensor] = []
    out_modality: list[LFMModality] = []
    while True:
        t = q.get()
        if t is None:
            break
        elif t.numel() == 1:  # text
            out_text.append(t)
            out_modality.append(LFMModality.TEXT)
            print(proc.text.decode(t), end="")
            cur_string = proc.text.decode(torch.cat(out_text)).removesuffix("<|text_end|>")
            yield AdditionalOutputs(cur_string)
        elif t.numel() == 8:
            out_audio.append(t)
            out_modality.append(LFMModality.AUDIO_OUT)
        elif t.numel() == 1920:
            np_chunk = (t.cpu().numpy() * 32_767).astype(np.int16)
            yield (24_000, np_chunk)
        else:
            raise RuntimeError(f"unexpected shape: {t.shape}")

    chat.append(
        text=torch.stack(out_text, 1),
        audio_out=torch.stack(out_audio, 1),
        modality_flag=torch.tensor(out_modality, device="cuda"),
    )
    chat.end_turn()
    chat.new_turn("user")


def clear():
    gr.Info("Cleared chat history", duration=3)
    return ChatState(proc), None


with gr.Blocks() as demo:
    gr.Markdown("# LFM2-Audio speech-to-speech chat")
    chat_state = gr.State(ChatState(proc))
    webrtc = WebRTC(
        modality="audio",
        mode="send-receive",
        # variant="textbox",
        full_screen=False,
    )
    text_out = gr.Textbox(
        lines=4,
        label="Output",
    )
    clear_btn = gr.Button("Reset chat")

    webrtc.stream(
        ReplyOnPause(
            chat_response,  # type: ignore[arg-type]
            input_sample_rate=24_000,
            output_sample_rate=24_000,
            can_interrupt=False,
        ),
        inputs=[webrtc, chat_state],
        outputs=[webrtc],
    )
    webrtc.on_additional_outputs(
        lambda s: s,
        outputs=[text_out],
    )
    clear_btn.click(clear, outputs=[chat_state, text_out])


def main():
    demo.launch()


if __name__ == "__main__":
    main()
```
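The `wav / 32_768` and `* 32_767` factors in `chat.py` above are standard 16-bit PCM scaling: WebRTC delivers int16 samples, while the model works on floats in [-1, 1). A minimal stdlib sketch of the conversion (the helper names are illustrative, not part of the package):

```python
# int16 PCM <-> float scaling, as used by the demo above.
def pcm16_to_float(sample: int) -> float:
    # int16 spans [-32768, 32767]; dividing by 32768 keeps values in [-1, 1)
    return sample / 32_768


def float_to_pcm16(sample: float) -> int:
    # clamp to the int16 range to avoid wrap-around on overdriven audio
    return max(-32_768, min(32_767, round(sample * 32_767)))


print(pcm16_to_float(-32_768))  # -1.0
print(float_to_pcm16(1.0))      # 32767
```

The asymmetric constants (32 768 on input, 32 767 on output) exist because the int16 range itself is asymmetric; clamping on the way back prevents overflow for samples at exactly +1.0.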
## /src/liquid_audio/demo/model.py
```py path="/src/liquid_audio/demo/model.py"
"""Initialize models"""
import logging
import torch
from liquid_audio import LFM2AudioModel, LFM2AudioProcessor
logger = logging.getLogger(__name__)
__all__ = ["lfm2_audio", "mimi", "proc"]
HF_DIR = "LiquidAI/LFM2-Audio-1.5B"
logging.info("Loading processor")
proc = LFM2AudioProcessor.from_pretrained(HF_DIR).eval()
logging.info("Loading model")
lfm2_audio = LFM2AudioModel.from_pretrained(HF_DIR).eval()
logging.info("Loading tokenizer")
mimi = proc.mimi.eval()
logging.info("Warmup tokenizer")
with mimi.streaming(1), torch.no_grad():
for _ in range(5):
x = torch.randint(2048, (1, 8, 1), device="cuda")
mimi.decode(x)
```
## /src/liquid_audio/model/__init__.py
```py path="/src/liquid_audio/model/__init__.py"
```
## /src/liquid_audio/model/conformer/__init__.py
```py path="/src/liquid_audio/model/conformer/__init__.py"
```
## /src/liquid_audio/model/conformer/encoder.py
```py path="/src/liquid_audio/model/conformer/encoder.py"
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
nemo Conformer
adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/collections/asr/modules/conformer_encoder.py
"""
from dataclasses import dataclass
from torch import nn
import torch
from .mha import MultiHeadAttention, RelPositionalEncoding
from .modules import CausalConv1D, ConformerLayer
from .subsampling import ConvSubsampling
from .utils import CacheAwareStreamingConfig, compute_stochastic_depth_drop_probs
@dataclass(kw_only=True)
class ConformerEncoderConfig:
    feat_in: int
    feat_out: int
    n_layers: int
    d_model: int
    subsampling: str
    subsampling_factor: int
    subsampling_conv_channels: int
    causal_downsampling: bool
    reduction: str | None
    reduction_position: int | None
    reduction_factor: int
    ff_expansion_factor: int
    self_attention_model: str
    n_heads: int
    att_context_size: list[list[int]]
    xscaling: bool
    untie_biases: bool
    pos_emb_max_len: int
    conv_kernel_size: int
    conv_norm_type: str
    conv_context_size: list[int] | None
    dropout: float
    dropout_pre_encoder: float
    dropout_emb: float
    dropout_att: float
class ConformerEncoder(nn.Module):
    """
    The encoder for ASR model of Conformer.
    Based on this paper:
    'Conformer: Convolution-augmented Transformer for Speech Recognition' by Anmol Gulati et al.
    https://arxiv.org/abs/2005.08100

    Args:
        feat_in (int): the size of feature channels
        n_layers (int): number of layers of ConformerBlock
        d_model (int): the hidden size of the model
        feat_out (int): the size of the output features
            Defaults to -1 (means feat_out is d_model)
        subsampling (str): the method of subsampling:
            choices = ['vggnet', 'striding', 'dw-striding', 'stacking', 'stacking_norm']
            Defaults to striding.
        subsampling_factor (int): the subsampling factor, which should be a power of 2
            Defaults to 4.
        subsampling_conv_chunking_factor (int): optionally, force chunking of inputs (helpful for large inputs)
            Should be a power of 2, 1 (auto-chunking, default), or -1 (no chunking)
        subsampling_conv_channels (int): the size of the convolutions in the subsampling module
            Defaults to -1, which would set it to d_model.
        reduction (str, Optional): the method of reduction, choices=['pooling', 'striding']. If no value
            is passed, then no reduction is performed and the model runs with the original 4x subsampling.
        reduction_position (int, Optional): the index of the layer to apply reduction. If -1, apply reduction
            at the end.
        reduction_factor (int): the reduction factor, which should be either 1 or a power of 2
            Defaults to 1.
        ff_expansion_factor (int): the expansion factor in feed forward layers
            Defaults to 4.
        self_attention_model (str): the type of the attention layer and positional encoding.
            'rel_pos':
                relative positional embedding and Transformer-XL
            'rel_pos_local_attn':
                relative positional embedding and Transformer-XL with local attention using
                overlapping chunks. Attention context is determined by the att_context_size parameter.
            'abs_pos':
                absolute positional embedding and Transformer
            Default is rel_pos.
        pos_emb_max_len (int): the maximum length of positional embeddings
            Defaults to 5000
        n_heads (int): number of heads in multi-headed attention layers
            Defaults to 4.
        att_context_size (List[Union[List[int], int]]): specifies the context sizes on each side.
            Each context size should be a list of two integers like `[100, 100]`.
            A list of context sizes like `[[100, 100], [100, 50]]` can also be passed. -1 means unlimited context.
            Defaults to `[-1, -1]`
        att_context_probs (List[float]): a list of probabilities for each of the att_context_size values
            when a list of them is passed. If not specified, a uniform distribution is used.
            Defaults to None
        att_context_style (str): 'regular' or 'chunked_limited'.
            Defaults to 'regular'
        xscaling (bool): enables scaling the inputs to the multi-headed attention layers by `sqrt(d_model)`.
            Defaults to True.
        untie_biases (bool): whether to not share (untie) the bias weights between layers of Transformer-XL
            Defaults to True.
        conv_kernel_size (int): the size of the convolutions in the convolutional modules
            Defaults to 31.
        conv_norm_type (str): the type of the normalization in the convolutional modules
            Defaults to 'batch_norm'.
        conv_context_size (list): it can be 'causal' or a list of two integers
            such that `conv_context_size[0] + conv_context_size[1] + 1 == conv_kernel_size`.
            `None` means `[(conv_kernel_size-1)//2, (conv_kernel_size-1)//2]`, and 'causal' means
            `[(conv_kernel_size-1), 0]`.
            Defaults to None.
        conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used.
            When enabled, the left half of the convolution kernel would get masked in streaming cases.
            Defaults to False.
        use_bias (bool): Use bias in all Linear and Conv1d layers from each ConformerLayer to improve
            activation flow and stabilize training of huge models.
            Defaults to True.
        dropout (float): the dropout rate used in all layers except the attention layers
            Defaults to 0.1.
        dropout_pre_encoder (float): the dropout rate used before the encoder
            Defaults to 0.1.
        dropout_emb (float): the dropout rate used for the positional embeddings
            Defaults to 0.1.
        dropout_att (float): the dropout rate used for the attention layer
            Defaults to 0.0.
        stochastic_depth_drop_prob (float): if non-zero, will randomly drop
            layers during training. The higher this value, the more often layers
            are dropped. Defaults to 0.0.
        stochastic_depth_mode (str): can be either "linear" or "uniform". If
            set to "uniform", all layers have the same probability of drop. If
            set to "linear", the drop probability grows linearly from 0 for the
            first layer to the desired value for the final layer. Defaults to
            "linear".
        stochastic_depth_start_layer (int): starting layer for stochastic depth.
            All layers before this will never be dropped. Note that drop
            probability will be adjusted accordingly if mode is "linear" when
            start layer is > 1. Defaults to 1.
        global_tokens (int): number of tokens to be used for global attention.
            Only relevant if self_attention_model is 'rel_pos_local_attn'.
            Defaults to 0.
        global_tokens_spacing (int): how far apart the global tokens are
            Defaults to 1.
        global_attn_separate (bool): whether the q, k, v layers used for global tokens should be separate.
            Defaults to False.
        use_pytorch_sdpa (bool): use torch sdpa instead of manual attention.
            Defaults to False.
        use_pytorch_sdpa_backends (list[str]): list of backend names to use in sdpa.
            None or an empty list means all backends, e.g. ["MATH"].
            Defaults to None.
        bypass_pre_encode (bool): if True, skip the pre-encoder module and `audio_signal` should be pre-encoded
            embeddings. The `audio_signal` input supports two formats depending on the `bypass_pre_encode`
            boolean flag. This determines the required format of the input variable `audio_signal`.
            Defaults to `bypass_pre_encode=False`. `bypass_pre_encode=True` is used for cases
            where frame-level, context-independent embeddings need to be saved or reused
            (e.g., the speaker cache in streaming speaker diarization).
        sync_max_audio_length (bool): when true, performs an NCCL all_reduce to allocate the same amount of memory for
            positional encoding buffers on all GPUs. Disabling this setting may help with deadlocks in certain
            scenarios such as model parallelism, or generally when this module is not being run on some GPUs
            as a part of the training step.
    """
    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.

        Returns:
            A tuple of input examples.
        """
        dev = next(self.parameters()).device
        if self.export_cache_support:
            window_size = max_dim
            if self.streaming_cfg is not None:
                if isinstance(self.streaming_cfg.chunk_size, list):
                    chunk_size = self.streaming_cfg.chunk_size[1]
                else:
                    chunk_size = self.streaming_cfg.chunk_size
                if isinstance(self.streaming_cfg.pre_encode_cache_size, list):
                    pre_encode_cache_size = self.streaming_cfg.pre_encode_cache_size[1]
                else:
                    pre_encode_cache_size = self.streaming_cfg.pre_encode_cache_size
                window_size = chunk_size + pre_encode_cache_size
            input_example = torch.randn(max_batch, self._feat_in, window_size, device=dev)
            input_example_length = torch.randint(
                window_size // 4, window_size, (max_batch,), device=dev, dtype=torch.int64
            )
            cache_last_channel, cache_last_time, cache_last_channel_len = self.get_initial_cache_state(
                batch_size=max_batch, device=dev, max_dim=max_dim
            )
            all_input_example = tuple(
                [
                    input_example,
                    input_example_length,
                    cache_last_channel.transpose(0, 1),
                    cache_last_time.transpose(0, 1),
                    cache_last_channel_len,
                ]
            )
        else:
            input_example = torch.randn(max_batch, self._feat_in, max_dim, device=dev)
            input_example_length = torch.randint(max_dim // 4, max_dim, (max_batch,), device=dev, dtype=torch.int64)
            all_input_example = tuple([input_example, input_example_length])
        return all_input_example
@property
def input_types(self):
"""Returns definitions of module input ports."""
return OrderedDict(
{
"audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
"length": NeuralType(tuple('B'), LengthsType()),
"cache_last_channel": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True),
"cache_last_time": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True),
"cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True),
"bypass_pre_encode": NeuralType(tuple(), BoolType(), optional=True),
}
)
@property
def input_types_for_export(self):
"""Returns definitions of module input ports."""
return OrderedDict(
{
"audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
"length": NeuralType(tuple('B'), LengthsType()),
"cache_last_channel": NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True),
"cache_last_time": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True),
"cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True),
"bypass_pre_encode": NeuralType(tuple(), BoolType(), optional=True),
}
)
@property
def output_types(self):
"""Returns definitions of module output ports."""
return OrderedDict(
{
"outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
"encoded_lengths": NeuralType(tuple('B'), LengthsType()),
"cache_last_channel_next": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True),
"cache_last_time_next": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True),
"cache_last_channel_next_len": NeuralType(tuple('B'), LengthsType(), optional=True),
}
)
@property
def output_types_for_export(self):
"""Returns definitions of module output ports."""
return OrderedDict(
{
"outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
"encoded_lengths": NeuralType(tuple('B'), LengthsType()),
"cache_last_channel_next": NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True),
"cache_last_time_next": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True),
"cache_last_channel_next_len": NeuralType(tuple('B'), LengthsType(), optional=True),
}
)
@property
def disabled_deployment_input_names(self):
if not self.export_cache_support:
            return {"cache_last_channel", "cache_last_time", "cache_last_channel_len"}
else:
return set()
@property
def disabled_deployment_output_names(self):
if not self.export_cache_support:
            return {"cache_last_channel_next", "cache_last_time_next", "cache_last_channel_next_len"}
else:
return set()
def __init__(
self,
feat_in,
n_layers,
d_model,
feat_out=-1,
causal_downsampling=False,
subsampling='striding',
subsampling_factor=4,
subsampling_conv_chunking_factor=1,
subsampling_conv_channels=-1,
reduction=None,
reduction_position=None,
reduction_factor=1,
ff_expansion_factor=4,
self_attention_model='rel_pos',
n_heads=4,
att_context_size=None,
att_context_probs=None,
att_context_style='regular',
xscaling=True,
untie_biases=True,
pos_emb_max_len=5000,
conv_kernel_size=31,
conv_norm_type='batch_norm',
conv_context_size=None,
use_bias=True,
dropout=0.1,
dropout_pre_encoder=0.1,
dropout_emb=0.1,
dropout_att=0.0,
stochastic_depth_drop_prob: float = 0.0,
stochastic_depth_mode: str = "linear",
stochastic_depth_start_layer: int = 1,
global_tokens: int = 0,
global_tokens_spacing: int = 1,
global_attn_separate: bool = False,
use_pytorch_sdpa: bool = False,
use_pytorch_sdpa_backends=None,
sync_max_audio_length: bool = True,
):
super().__init__()
d_ff = d_model * ff_expansion_factor
self.d_model = d_model
self.n_layers = n_layers
self._feat_in = feat_in
self.att_context_style = att_context_style
self.subsampling_factor = subsampling_factor
self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
self.self_attention_model = self_attention_model
self.global_tokens = global_tokens
self.global_attn_separate = global_attn_separate
self.global_tokens_spacing = global_tokens_spacing
self.use_pytorch_sdpa = use_pytorch_sdpa
if use_pytorch_sdpa_backends is None:
use_pytorch_sdpa_backends = []
self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends
self.sync_max_audio_length = sync_max_audio_length
# Setting up the att_context_size
(
self.att_context_size_all,
self.att_context_size,
self.att_context_probs,
self.conv_context_size,
) = self._calc_context_sizes(
att_context_style=att_context_style,
att_context_size=att_context_size,
att_context_probs=att_context_probs,
conv_context_size=conv_context_size,
conv_kernel_size=conv_kernel_size,
)
if xscaling:
self.xscale = math.sqrt(d_model)
else:
self.xscale = None
# Subsampling
if subsampling_conv_channels == -1:
subsampling_conv_channels = d_model
if subsampling and subsampling_factor > 1:
if subsampling in ['stacking', 'stacking_norm']:
                # stacking_norm has an extra layer norm after stacking compared to stacking
self.pre_encode = StackingSubsampling(
subsampling_factor=subsampling_factor,
feat_in=feat_in,
feat_out=d_model,
norm=True if subsampling == 'stacking_norm' else False,
)
else:
self.pre_encode = ConvSubsampling(
subsampling=subsampling,
subsampling_factor=subsampling_factor,
feat_in=feat_in,
feat_out=d_model,
conv_channels=subsampling_conv_channels,
subsampling_conv_chunking_factor=subsampling_conv_chunking_factor,
activation=nn.ReLU(True),
is_causal=causal_downsampling,
)
else:
self.pre_encode = nn.Linear(feat_in, d_model)
# Reduction
if reduction and reduction_factor > 1:
assert reduction_position >= -1 and reduction_position < n_layers
self.reduction_subsampling = SubsamplingReductionModule(
reduction=reduction,
d_model=d_model,
reduction_factor=reduction_factor,
)
self.reduction_position = reduction_position
else:
self.reduction_subsampling = None
self.reduction_position = None
self._feat_out = d_model
# Biases for relative positional encoding
if not untie_biases and self_attention_model == "rel_pos":
d_head = d_model // n_heads
pos_bias_u = nn.Parameter(torch.Tensor(n_heads, d_head))
pos_bias_v = nn.Parameter(torch.Tensor(n_heads, d_head))
nn.init.zeros_(pos_bias_u)
nn.init.zeros_(pos_bias_v)
else:
pos_bias_u = None
pos_bias_v = None
# Positional encodings
self.pos_emb_max_len = pos_emb_max_len
if self_attention_model == "rel_pos":
self.pos_enc = RelPositionalEncoding(
d_model=d_model,
dropout_rate=dropout_pre_encoder,
max_len=pos_emb_max_len,
xscale=self.xscale,
dropout_rate_emb=dropout_emb,
)
elif self_attention_model == 'rel_pos_local_attn':
if max(att_context_size) <= 0:
raise ValueError("When using local attention, context size must be set > 0")
self.pos_enc = LocalAttRelPositionalEncoding(
att_context_size=att_context_size,
d_model=d_model,
dropout_rate=dropout,
max_len=pos_emb_max_len,
xscale=self.xscale,
dropout_rate_emb=dropout_emb,
)
elif self_attention_model == "abs_pos":
pos_bias_u = None
pos_bias_v = None
self.pos_enc = PositionalEncoding(
d_model=d_model, dropout_rate=dropout_pre_encoder, max_len=pos_emb_max_len, xscale=self.xscale
)
else:
            raise ValueError(f"Not a valid self_attention_model: '{self_attention_model}'!")
self.layers = nn.ModuleList()
for i in range(n_layers):
layer = ConformerLayer(
d_model=d_model,
d_ff=d_ff,
self_attention_model=self_attention_model,
global_tokens=global_tokens,
global_tokens_spacing=global_tokens_spacing,
global_attn_separate=global_attn_separate,
n_heads=n_heads,
conv_kernel_size=conv_kernel_size,
conv_norm_type=conv_norm_type,
conv_context_size=self.conv_context_size,
dropout=dropout,
dropout_att=dropout_att,
pos_bias_u=pos_bias_u,
pos_bias_v=pos_bias_v,
att_context_size=self.att_context_size,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
self.layers.append(layer)
if feat_out > 0 and feat_out != self._feat_out:
self.out_proj = nn.Linear(self._feat_out, feat_out)
self._feat_out = feat_out
else:
self.out_proj = None
self._feat_out = d_model
self.set_max_audio_length(self.pos_emb_max_len)
self.use_pad_mask = True
self.setup_streaming_params()
self.export_cache_support = False
self.layer_drop_probs = compute_stochastic_depth_drop_probs(
len(self.layers), stochastic_depth_drop_prob, stochastic_depth_mode, stochastic_depth_start_layer
)
# will be set in self.forward() if defined in AccessMixin config
self.interctc_capture_at_layers = None
def forward_for_export(
self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
):
"""
Forward function for model export. Please see `forward()` for more details.
"""
if cache_last_channel is not None:
cache_last_channel = cache_last_channel.transpose(0, 1)
cache_last_time = cache_last_time.transpose(0, 1)
rets = self.forward_internal(
audio_signal,
length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)
rets = self.streaming_post_process(rets, keep_all_outputs=False)
if len(rets) == 2:
return rets
elif rets[2] is None and rets[3] is None and rets[4] is None:
return (rets[0], rets[1])
else:
return (
rets[0],
rets[1],
rets[2].transpose(0, 1),
rets[3].transpose(0, 1),
rets[4],
)
def streaming_post_process(self, rets, keep_all_outputs=True):
"""
Post-process the output of the forward function for streaming.
Args:
rets: The output of the forward function.
keep_all_outputs: Whether to keep all outputs.
"""
if len(rets) == 2:
return rets[0], rets[1], None, None, None
(encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len) = rets
if cache_last_channel_next is not None and self.streaming_cfg.last_channel_cache_size >= 0:
if self.streaming_cfg.last_channel_cache_size > 0:
cache_last_channel_next = cache_last_channel_next[
:, :, -self.streaming_cfg.last_channel_cache_size :, :
]
if self.streaming_cfg.valid_out_len > 0 and (not keep_all_outputs or self.att_context_style == "regular"):
encoded = encoded[:, :, : self.streaming_cfg.valid_out_len]
encoded_len = torch.clamp(encoded_len, max=self.streaming_cfg.valid_out_len)
return (encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len)
def forward(
self,
audio_signal,
length,
cache_last_channel=None,
cache_last_time=None,
cache_last_channel_len=None,
bypass_pre_encode=False,
):
"""
Forward function for the ConformerEncoder accepting an audio signal and its corresponding length.
The `audio_signal` input supports two formats depending on the `bypass_pre_encode` boolean flag.
This determines the required format of the input variable `audio_signal`:
(1) bypass_pre_encode = False (default):
`audio_signal` must be a tensor containing audio features.
Shape: (batch, self._feat_in, n_frames)
(2) bypass_pre_encode = True:
`audio_signal` must be a tensor containing pre-encoded embeddings.
Shape: (batch, n_frame, self.d_model)
"""
if not bypass_pre_encode and audio_signal.shape[-2] != self._feat_in:
raise ValueError(
f"If bypass_pre_encode is False, audio_signal should have shape "
                f"(batch, {self._feat_in}, n_frame) but got feature dimension {audio_signal.shape[-2]}."
)
if bypass_pre_encode and audio_signal.shape[-1] != self.d_model:
raise ValueError(
f"If bypass_pre_encode is True, audio_signal should have shape "
f"(batch, n_frame, {self.d_model}) but got last dimension {audio_signal.shape[-1]}."
)
if bypass_pre_encode:
            # pre-encoded input is (batch, n_frame, d_model), so the frame axis is dim 1
            self.update_max_seq_length(
                seq_length=audio_signal.size(1) * self.subsampling_factor, device=audio_signal.device
            )
else:
self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
return self.forward_internal(
audio_signal,
length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
bypass_pre_encode=bypass_pre_encode,
)
def forward_internal(
self,
audio_signal,
length,
cache_last_channel=None,
cache_last_time=None,
cache_last_channel_len=None,
bypass_pre_encode=False,
):
"""
The `audio_signal` input supports two formats depending on the `bypass_pre_encode` boolean flag.
This determines the required format of the input variable `audio_signal`:
(1) bypass_pre_encode = False (default):
`audio_signal` must be a tensor containing audio features.
Shape: (batch, self._feat_in, n_frames)
(2) bypass_pre_encode = True:
`audio_signal` must be a tensor containing pre-encoded embeddings.
Shape: (batch, n_frame, self.d_model)
`bypass_pre_encode=True` is used in cases where frame-level, context-independent embeddings are
needed to be saved or reused (e.g., speaker cache in streaming speaker diarization).
"""
if length is None:
length = audio_signal.new_full(
(audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
)
        # during training, select a random att_context_size with the distribution specified by att_context_probs
        # for non-training cases (validation, test, inference), the first mode in self.att_context_size is used
if self.training and len(self.att_context_size_all) > 1:
cur_att_context_size = random.choices(self.att_context_size_all, weights=self.att_context_probs)[0]
else:
cur_att_context_size = self.att_context_size
if not bypass_pre_encode:
audio_signal = torch.transpose(audio_signal, 1, 2)
if isinstance(self.pre_encode, nn.Linear):
audio_signal = self.pre_encode(audio_signal)
else:
audio_signal, length = self.pre_encode(x=audio_signal, lengths=length)
length = length.to(torch.int64)
# `self.streaming_cfg` is set by setup_streaming_cfg(), called in the init
if self.streaming_cfg.drop_extra_pre_encoded > 0 and cache_last_channel is not None:
audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :]
length = (length - self.streaming_cfg.drop_extra_pre_encoded).clamp(min=0)
if self.reduction_position is not None and cache_last_channel is not None:
raise ValueError("Caching with reduction feature is not supported yet!")
max_audio_length = audio_signal.size(1)
if cache_last_channel is not None:
cache_len = self.streaming_cfg.last_channel_cache_size
cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size
max_audio_length = max_audio_length + cache_len
padding_length = length + cache_len
offset = torch.neg(cache_last_channel_len) + cache_len
else:
padding_length = length
cache_last_channel_next = None
cache_len = 0
offset = None
audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
# Create the self-attention and padding masks
pad_mask, att_mask = self._create_masks(
att_context_size=cur_att_context_size,
padding_length=padding_length,
max_audio_length=max_audio_length,
offset=offset,
device=audio_signal.device,
)
if cache_last_channel is not None:
pad_mask = pad_mask[:, cache_len:]
if att_mask is not None:
att_mask = att_mask[:, cache_len:]
# Convert caches from the tensor to list
cache_last_time_next = []
cache_last_channel_next = []
for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)):
original_signal = audio_signal
if cache_last_channel is not None:
cache_last_channel_cur = cache_last_channel[lth]
cache_last_time_cur = cache_last_time[lth]
else:
cache_last_channel_cur = None
cache_last_time_cur = None
audio_signal = layer(
x=audio_signal,
att_mask=att_mask,
pos_emb=pos_emb,
pad_mask=pad_mask,
cache_last_channel=cache_last_channel_cur,
cache_last_time=cache_last_time_cur,
)
if cache_last_channel_cur is not None:
(audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal
cache_last_channel_next.append(cache_last_channel_cur)
cache_last_time_next.append(cache_last_time_cur)
# applying stochastic depth logic from https://arxiv.org/abs/2102.03216
if self.training and drop_prob > 0.0:
should_drop = torch.rand(1) < drop_prob
# adjusting to match expectation
if should_drop:
# that's not efficient, but it's hard to implement distributed
# version of dropping layers without deadlock or random seed meddling
# so multiplying the signal by 0 to ensure all weights get gradients
audio_signal = audio_signal * 0.0 + original_signal
else:
# not doing this operation if drop prob is 0 as it's identity in that case
audio_signal = (audio_signal - original_signal) / (1.0 - drop_prob) + original_signal
if self.reduction_position == lth:
audio_signal, length = self.reduction_subsampling(x=audio_signal, lengths=length)
max_audio_length = audio_signal.size(1)
# Don't update the audio_signal here because then it will again scale the audio_signal
# and cause an increase in the WER
_, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
pad_mask, att_mask = self._create_masks(
att_context_size=cur_att_context_size,
padding_length=length,
max_audio_length=max_audio_length,
offset=offset,
device=audio_signal.device,
)
# saving tensors if required for interctc loss
# if self.is_access_enabled(getattr(self, "model_guid", None)):
# if self.interctc_capture_at_layers is None:
# self.interctc_capture_at_layers = self.access_cfg.get('interctc', {}).get('capture_layers', [])
# if lth in self.interctc_capture_at_layers:
# lth_audio_signal = audio_signal
# if self.out_proj is not None:
# lth_audio_signal = self.out_proj(audio_signal)
# # shape is the same as the shape of audio_signal output, i.e. [B, D, T]
# self.register_accessible_tensor(
# name=f'interctc/layer_output_{lth}', tensor=torch.transpose(lth_audio_signal, 1, 2)
# )
# self.register_accessible_tensor(name=f'interctc/layer_length_{lth}', tensor=length)
if self.out_proj is not None:
audio_signal = self.out_proj(audio_signal)
# Reduction
if self.reduction_position == -1:
audio_signal, length = self.reduction_subsampling(x=audio_signal, lengths=length)
audio_signal = torch.transpose(audio_signal, 1, 2)
length = length.to(dtype=torch.int64)
if cache_last_channel is not None:
cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0)
cache_last_time_next = torch.stack(cache_last_time_next, dim=0)
return (
audio_signal,
length,
cache_last_channel_next,
cache_last_time_next,
torch.clamp(cache_last_channel_len + cache_keep_size, max=cache_len),
)
else:
return audio_signal, length
def update_max_seq_length(self, seq_length: int, device):
"""
Updates the maximum sequence length for the model.
Args:
seq_length (int): New maximum sequence length.
device (torch.device): Device to use for computations.
"""
# Find global max audio length across all nodes
if self.sync_max_audio_length and torch.distributed.is_initialized():
global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)
# Update across all ranks in the distributed system
torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)
seq_length = global_max_len.int().item()
if seq_length > self.max_audio_length:
self.set_max_audio_length(seq_length)
def set_max_audio_length(self, max_audio_length):
"""
Sets maximum input length.
Pre-calculates internal seq_range mask.
Args:
max_audio_length (int): New maximum sequence length.
"""
self.max_audio_length = max_audio_length
device = next(self.parameters()).device
dtype = next(self.parameters()).dtype
self.pos_enc.extend_pe(max_audio_length, device, dtype)
def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device):
if self.self_attention_model != "rel_pos_local_attn":
att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device)
if self.att_context_style == "regular":
if att_context_size[0] >= 0:
att_mask = att_mask.triu(diagonal=-att_context_size[0])
if att_context_size[1] >= 0:
att_mask = att_mask.tril(diagonal=att_context_size[1])
elif self.att_context_style == "chunked_limited":
                # When the right context is unlimited, only the left side of the mask needs to be updated
if att_context_size[1] == -1:
if att_context_size[0] >= 0:
att_mask = att_mask.triu(diagonal=-att_context_size[0])
else:
chunk_size = att_context_size[1] + 1
# left_chunks_num specifies the number of chunks to be visible by each chunk on the left side
if att_context_size[0] >= 0:
left_chunks_num = att_context_size[0] // chunk_size
else:
left_chunks_num = 10000
chunk_idx = torch.arange(0, max_audio_length, dtype=torch.int, device=att_mask.device)
chunk_idx = torch.div(chunk_idx, chunk_size, rounding_mode="trunc")
diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)
chunked_limited_mask = torch.logical_and(
torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0)
)
att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0))
else:
att_mask = None
# pad_mask is the masking to be used to ignore paddings
pad_mask = torch.arange(0, max_audio_length, device=device).expand(
padding_length.size(0), -1
) < padding_length.unsqueeze(-1)
if offset is not None:
pad_mask_off = torch.arange(0, max_audio_length, device=device).expand(
padding_length.size(0), -1
) >= offset.unsqueeze(-1)
pad_mask = pad_mask_off.logical_and(pad_mask)
if att_mask is not None:
# pad_mask_for_att_mask is the mask which helps to ignore paddings
pad_mask_for_att_mask = pad_mask.unsqueeze(1).repeat([1, max_audio_length, 1])
pad_mask_for_att_mask = torch.logical_and(pad_mask_for_att_mask, pad_mask_for_att_mask.transpose(1, 2))
# att_mask is the masking to be used by the MHA layers to ignore the tokens not supposed to be visible
att_mask = att_mask[:, :max_audio_length, :max_audio_length]
# paddings should also get ignored, so pad_mask_for_att_mask is used to ignore their corresponding scores
att_mask = torch.logical_and(pad_mask_for_att_mask, att_mask.to(pad_mask_for_att_mask.device))
att_mask = ~att_mask
pad_mask = ~pad_mask
return pad_mask, att_mask
def enable_pad_mask(self, on=True):
        """
        Enables or disables the pad mask by assigning it the boolean state `on`.
        Returns:
            mask (bool): The previous state of the pad mask.
        """
        # During inference, the user may choose to disable the pad mask
mask = self.use_pad_mask
self.use_pad_mask = on
return mask
def _calc_context_sizes(
self, att_context_size, att_context_probs, att_context_style, conv_context_size, conv_kernel_size
):
# convert att_context_size to a standard list of lists
if att_context_size:
att_context_size_all = list(att_context_size)
if isinstance(att_context_size_all[0], int):
att_context_size_all = [att_context_size_all]
for i, att_cs in enumerate(att_context_size_all):
# if isinstance(att_cs, ListConfig):
# att_context_size_all[i] = list(att_cs)
if att_context_style == "chunked_limited":
if att_cs[0] > 0 and att_cs[0] % (att_cs[1] + 1) > 0:
raise ValueError(f"att_context_size[{i}][0] % (att_context_size[{i}][1] + 1) should be zero!")
if att_cs[1] < 0 and len(att_context_size_all) <= 1:
raise ValueError(
                            f"Right context (att_context_size[{i}][1]) cannot be unlimited for the chunked_limited style!"
)
else:
att_context_size_all = [[-1, -1]]
if att_context_probs:
if len(att_context_probs) != len(att_context_size_all):
raise ValueError("The size of the att_context_probs should be the same as att_context_size.")
att_context_probs = list(att_context_probs)
if sum(att_context_probs) != 1:
raise ValueError(
"The sum of numbers in att_context_probs should be equal to one to be a distribution."
)
else:
att_context_probs = [1.0 / len(att_context_size_all)] * len(att_context_size_all)
if conv_context_size is not None:
# if isinstance(conv_context_size, ListConfig):
# conv_context_size = list(conv_context_size)
if not isinstance(conv_context_size, list) and not isinstance(conv_context_size, str):
raise ValueError(
"Invalid conv_context_size! It should be the string 'causal' or a list of two integers."
)
if conv_context_size == "causal":
conv_context_size = [conv_kernel_size - 1, 0]
else:
if conv_context_size[0] + conv_context_size[1] + 1 != conv_kernel_size:
                    raise ValueError(f"Invalid conv_context_size: {conv_context_size}!")
else:
conv_context_size = [(conv_kernel_size - 1) // 2, (conv_kernel_size - 1) // 2]
return att_context_size_all, att_context_size_all[0], att_context_probs, conv_context_size
def set_default_att_context_size(self, att_context_size):
"""
Sets the default attention context size from `att_context_size` argument.
Args:
att_context_size (list): The attention context size to be set.
"""
if att_context_size not in self.att_context_size_all:
logging.warning(
f"att_context_size={att_context_size} is not among the list of the supported "
f"look-aheads: {self.att_context_size_all}"
)
if att_context_size is not None:
self.att_context_size = att_context_size
self.setup_streaming_params()
def setup_streaming_params(
self,
chunk_size: int = None,
shift_size: int = None,
left_chunks: int = None,
att_context_size: list = None,
max_context: int = 10000,
):
"""
This function sets the needed values and parameters to perform streaming.
The configuration would be stored in self.streaming_cfg.
The streaming configuration is needed to simulate streaming inference.
Args:
chunk_size (int): overrides the chunk size
shift_size (int): overrides the shift size for chunks
left_chunks (int): overrides the number of left chunks visible to each chunk
            att_context_size (list): overrides the attention context size;
                defaults to self.att_context_size when not specified
            max_context (int): the value used for the cache size of last_channel layers
                when the left context is set to infinity (-1).
                Defaults to 10000.
"""
streaming_cfg = CacheAwareStreamingConfig()
# When att_context_size is not specified, it uses the default_att_context_size
if att_context_size is None:
att_context_size = self.att_context_size
if chunk_size is not None:
if chunk_size < 1:
                raise ValueError("chunk_size needs to be greater than or equal to one.")
lookahead_steps = chunk_size - 1
streaming_cfg.cache_drop_size = chunk_size - shift_size
elif self.att_context_style == "chunked_limited":
lookahead_steps = att_context_size[1]
streaming_cfg.cache_drop_size = 0
elif self.att_context_style == "regular":
lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers
streaming_cfg.cache_drop_size = lookahead_steps
else:
streaming_cfg.cache_drop_size = 0
lookahead_steps = None
if chunk_size is None:
streaming_cfg.last_channel_cache_size = att_context_size[0] if att_context_size[0] >= 0 else max_context
else:
if left_chunks is None:
streaming_cfg.last_channel_cache_size = (
att_context_size[0] if att_context_size[0] >= 0 else max_context
)
logging.warning(
f"left_chunks is not set. Setting it to default: {streaming_cfg.last_channel_cache_size}."
)
else:
streaming_cfg.last_channel_cache_size = left_chunks * chunk_size
if hasattr(self.pre_encode, "get_sampling_frames"):
sampling_frames = self.pre_encode.get_sampling_frames()
else:
sampling_frames = 0
if isinstance(sampling_frames, list):
streaming_cfg.chunk_size = [
sampling_frames[0] + self.subsampling_factor * lookahead_steps,
sampling_frames[1] + self.subsampling_factor * lookahead_steps,
]
else:
streaming_cfg.chunk_size = sampling_frames * (1 + lookahead_steps)
if isinstance(sampling_frames, list):
streaming_cfg.shift_size = [
sampling_frames[0] + sampling_frames[1] * (lookahead_steps - streaming_cfg.cache_drop_size),
sampling_frames[1] + sampling_frames[1] * (lookahead_steps - streaming_cfg.cache_drop_size),
]
else:
streaming_cfg.shift_size = sampling_frames * (1 + lookahead_steps - streaming_cfg.cache_drop_size)
if isinstance(streaming_cfg.shift_size, list):
streaming_cfg.valid_out_len = (
streaming_cfg.shift_size[1] - sampling_frames[1]
) // self.subsampling_factor + 1
else:
streaming_cfg.valid_out_len = streaming_cfg.shift_size // self.subsampling_factor
if hasattr(self.pre_encode, "get_streaming_cache_size"):
streaming_cfg.pre_encode_cache_size = self.pre_encode.get_streaming_cache_size()
else:
streaming_cfg.pre_encode_cache_size = 0
if isinstance(streaming_cfg.pre_encode_cache_size, list):
if streaming_cfg.pre_encode_cache_size[1] >= 1:
streaming_cfg.drop_extra_pre_encoded = (
1 + (streaming_cfg.pre_encode_cache_size[1] - 1) // self.subsampling_factor
)
else:
streaming_cfg.drop_extra_pre_encoded = 0
else:
streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor
for m in self.layers.modules():
if hasattr(m, "_max_cache_len"):
if isinstance(m, MultiHeadAttention):
m.cache_drop_size = streaming_cfg.cache_drop_size
if isinstance(m, CausalConv1D):
m.cache_drop_size = streaming_cfg.cache_drop_size
self.streaming_cfg = streaming_cfg
def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None, max_dim=0):
if device is None:
device = next(self.parameters()).device
if max_dim > 0:
create_tensor = torch.randn
else:
create_tensor = torch.zeros
last_time_cache_size = self.conv_context_size[0]
cache_last_channel = create_tensor(
(
len(self.layers),
batch_size,
self.streaming_cfg.last_channel_cache_size,
self.d_model,
),
device=device,
dtype=dtype,
)
cache_last_time = create_tensor(
(len(self.layers), batch_size, self.d_model, last_time_cache_size),
device=device,
dtype=dtype,
)
if max_dim > 0:
cache_last_channel_len = torch.randint(
0,
min(max_dim, self.streaming_cfg.last_channel_cache_size),
(batch_size,),
device=device,
dtype=torch.int64,
)
for i in range(batch_size):
cache_last_channel[:, i, cache_last_channel_len[i] :, :] = 0
# what is the right rule to zero out cache_last_time?
if cache_last_channel_len[i] == 0:
cache_last_time[:, i, :, :] = 0
else:
cache_last_channel_len = torch.zeros(batch_size, device=device, dtype=torch.int64)
return cache_last_channel, cache_last_time, cache_last_channel_len
def change_attention_model(
self,
self_attention_model: str = None,
att_context_size: list[int] = None,
update_config: bool = True,
device: torch.device = None,
):
"""
Update the self_attention_model which changes the positional encoding and attention layers.
Args:
self_attention_model (str): type of the attention layer and positional encoding
'rel_pos':
relative positional embedding and Transformer-XL
'rel_pos_local_attn':
relative positional embedding and Transformer-XL with local attention using
overlapping windows. Attention context is determined by att_context_size parameter.
'abs_pos':
absolute positional embedding and Transformer
If None is provided, the self_attention_model isn't changed. Defaults to None.
att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes,
or None to keep as it is. Defaults to None.
update_config (bool): Whether to update the config or not with the new attention model.
Defaults to True.
device (torch.device): If provided, new layers will be moved to the device.
Defaults to None.
"""
if att_context_size:
att_context_size = list(att_context_size)
else:
att_context_size = self.att_context_size
if self_attention_model is None:
self_attention_model = self.self_attention_model
if self_attention_model == 'rel_pos_local_attn' and max(att_context_size) <= 0:
raise ValueError("When using local attention, context size must be set > 0")
if self_attention_model == "rel_pos":
new_pos_enc = RelPositionalEncoding(
d_model=self._cfg.d_model,
dropout_rate=self._cfg.dropout,
max_len=self._cfg.pos_emb_max_len,
xscale=self.xscale,
dropout_rate_emb=self._cfg.dropout_emb,
)
elif self_attention_model == 'rel_pos_local_attn':
new_pos_enc = LocalAttRelPositionalEncoding(
att_context_size=att_context_size,
d_model=self._cfg.d_model,
dropout_rate=self._cfg.dropout,
max_len=self._cfg.pos_emb_max_len,
xscale=self.xscale,
dropout_rate_emb=self._cfg.dropout_emb,
)
elif self_attention_model == "abs_pos":
new_pos_enc = PositionalEncoding(
d_model=self._cfg.d_model,
dropout_rate=self._cfg.dropout,
max_len=self._cfg.pos_emb_max_len,
xscale=self.xscale,
)
else:
            raise ValueError(f"Not a valid self_attention_model: '{self_attention_model}'!")
if device is not None:
new_pos_enc = new_pos_enc.to(device=device)
del self.pos_enc
self.pos_enc = new_pos_enc
self.self_attention_model = self_attention_model
self.att_context_size = att_context_size
self.set_max_audio_length(self.pos_emb_max_len)
for _, m in self.named_modules():
if type(m) == ConformerLayer:
if self_attention_model == 'rel_pos':
new_attn = RelPositionMultiHeadAttention(
n_head=self._cfg.n_heads,
n_feat=self._cfg.d_model,
dropout_rate=self._cfg.dropout_att,
max_cache_len=att_context_size[0],
pos_bias_u=None,
pos_bias_v=None,
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
elif self_attention_model == 'rel_pos_local_attn':
new_attn = RelPositionMultiHeadAttentionLongformer(
n_head=self._cfg.n_heads,
n_feat=self._cfg.d_model,
dropout_rate=self._cfg.dropout_att,
max_cache_len=att_context_size[0],
att_context_size=att_context_size,
pos_bias_u=None,
pos_bias_v=None,
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
elif self_attention_model == 'abs_pos':
new_attn = MultiHeadAttention(
n_head=self._cfg.n_heads,
n_feat=self._cfg.d_model,
dropout_rate=self._cfg.dropout_att,
max_cache_len=att_context_size[0],
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
else:
raise ValueError(
                        f"'{self_attention_model}' is not a valid value for 'self_attention_model'; "
                        f"valid values are ['rel_pos', 'rel_pos_local_attn', 'abs_pos']."
)
if device is not None:
new_attn = new_attn.to(device=device)
new_attn.load_state_dict(m.self_attn.state_dict(), strict=False)
del m.self_attn
m.self_attn = new_attn
m.self_attention_model = self_attention_model
if update_config:
with open_dict(self._cfg):
self._cfg.self_attention_model = self_attention_model
self._cfg.att_context_size = att_context_size
def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
        """
        Update the subsampling_conv_chunking_factor (int).
        Default is 1 (auto).
        Set it to -1 (disabled) or to a specific power of 2 if you hit OOM errors in the conv subsampling layers.
        Args:
            subsampling_conv_chunking_factor (int)
        """
if not hasattr(self.pre_encode, "change_subsampling_conv_chunking_factor"):
            logging.info("Model pre_encoder doesn't have a change_subsampling_conv_chunking_factor method")
return
self.pre_encode.change_subsampling_conv_chunking_factor(
subsampling_conv_chunking_factor=subsampling_conv_chunking_factor
)
```
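For reference, the `chunked_limited` attention mask that `_create_masks` builds with torch ops above can be sketched in plain Python. This is an illustration only, not part of the repo; padding masks and streaming cache offsets are omitted.

```python
# Minimal pure-Python sketch of the "chunked_limited" mask logic in
# ConformerEncoder._create_masks. Illustration only; the real code operates
# on torch tensors and also folds in the padding mask.

def chunked_limited_mask(seq_len: int, att_context_size: list) -> list:
    left, right = att_context_size
    chunk_size = right + 1  # each chunk spans right+1 frames
    # number of past chunks visible from each chunk (10000 ~ "unlimited")
    left_chunks = left // chunk_size if left >= 0 else 10000
    chunk = [t // chunk_size for t in range(seq_len)]
    # frame i may attend to frame j iff j's chunk is not in the future and
    # lies at most `left_chunks` chunks behind i's chunk
    return [
        [0 <= chunk[i] - chunk[j] <= left_chunks for j in range(seq_len)]
        for i in range(seq_len)
    ]

mask = chunked_limited_mask(6, [2, 1])  # chunk_size=2, one left chunk visible
```

With `att_context_size=[2, 1]`, frames 0 and 1 (chunk 0) attend only to each other, while frame 4 (chunk 2) sees chunks 1 and 2, i.e. frames 2 through 5; no frame ever sees a future chunk.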
## /src/liquid_audio/model/conformer/mha.py
```py path="/src/liquid_audio/model/conformer/mha.py"
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""NeMo multi-head attention
adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/collections/asr/parts/submodules/multi_head_attention.py
"""
import torch
from torch import nn
import math
from .utils import avoid_float16_autocast_context
INF_VAL = 10000.0
class PositionalEncoding(torch.nn.Module):
"""Fixed sinusoidal positional encoding.
Args:
d_model (int): embedding dim
dropout_rate (float): dropout rate
max_len (int): maximum input length
xscale (bool): whether to scale the input by sqrt(d_model)
dropout_rate_emb (float): dropout rate for the positional embeddings
"""
def __init__(self, d_model, dropout_rate, max_len=5000, xscale=None, dropout_rate_emb=0.0):
"""Construct a PositionalEncoding object."""
super().__init__()
self.d_model = d_model
self.xscale = xscale
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.max_len = max_len
if dropout_rate_emb > 0:
self.dropout_emb = nn.Dropout(dropout_rate_emb)
else:
self.dropout_emb = None
def create_pe(self, positions, dtype):
pos_length = positions.size(0)
pe = torch.zeros(pos_length, self.d_model, device=positions.device)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32, device=positions.device)
* -(math.log(INF_VAL) / self.d_model)
)
pe[:, 0::2] = torch.sin(positions * div_term)
pe[:, 1::2] = torch.cos(positions * div_term)
pe = pe.unsqueeze(0).to(dtype)
if hasattr(self, 'pe'):
self.pe = pe
else:
self.register_buffer('pe', pe, persistent=False)
def extend_pe(self, length, device, dtype):
"""Reset and extend the positional encodings if needed."""
if hasattr(self, 'pe') and self.pe.size(1) >= length:
return
positions = torch.arange(0, length, dtype=torch.float32, device=device).unsqueeze(1)
self.create_pe(positions=positions, dtype=dtype)
def forward(self, x: torch.Tensor, cache_len=0):
"""Adds positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, feature_size)
cache_len (int): the size of the cache which is used to shift positions
Returns:
x+pos_emb (torch.Tensor): Its shape is (batch, time, feature_size)
pos_emb (torch.Tensor): Its shape is (1, time, feature_size)
"""
input_len = x.size(1) + cache_len
if self.xscale:
x = x * self.xscale
pos_emb = self.pe[:, :input_len]
if self.dropout_emb:
pos_emb = self.dropout_emb(pos_emb)
x = x + pos_emb
return self.dropout(x), pos_emb
class RelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding for TransformerXL's layers
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): embedding dim
dropout_rate (float): dropout rate
max_len (int): maximum input length
xscale (bool): whether to scale the input by sqrt(d_model)
dropout_rate_emb (float): dropout rate for the positional embeddings
"""
def extend_pe(self, length, device, dtype):
"""Reset and extend the positional encodings if needed."""
needed_size = 2 * length - 1
if hasattr(self, 'pe') and self.pe.size(1) >= needed_size:
return
# positions would be from negative numbers to positive
# positive positions would be used for left positions and negative for right positions
positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1)
self.create_pe(positions=positions, dtype=dtype)
def forward(self, x, cache_len=0):
"""Compute positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, feature_size)
cache_len (int): the size of the cache which is used to shift positions
Returns:
x (torch.Tensor): Its shape is (batch, time, feature_size)
pos_emb (torch.Tensor): Its shape is (1, time, feature_size)
"""
if self.xscale:
x = x * self.xscale
# center_pos would be the index of position 0
# negative positions would be used for right and positive for left tokens
# for input of length L, 2*L-1 positions are needed, positions from (L-1) to -(L-1)
input_len = x.size(1) + cache_len
center_pos = self.pe.size(1) // 2 + 1
start_pos = center_pos - input_len
end_pos = center_pos + input_len - 1
pos_emb = self.pe[:, start_pos:end_pos]
if self.dropout_emb:
pos_emb = self.dropout_emb(pos_emb)
return self.dropout(x), pos_emb
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention layer of Transformer.
Args:
n_head (int): number of heads
n_feat (int): size of the features
dropout_rate (float): dropout rate
use_bias (bool): whether to apply bias in linear and conv layers
use_pytorch_sdpa (bool): use torch sdpa instead of manual attention
use_pytorch_sdpa_backends (list[str]): list of backend names to use in sdpa. None or empty list means all backends. e.g. ["MATH"]
"""
def __init__(
self,
n_head,
n_feat,
dropout_rate,
max_cache_len=0,
use_bias=True,
use_pytorch_sdpa=False,
use_pytorch_sdpa_backends=None,
):
"""Construct a MultiHeadAttention object."""
super(MultiHeadAttention, self).__init__()
self.use_pytorch_sdpa = use_pytorch_sdpa
if self.use_pytorch_sdpa and use_pytorch_sdpa_backends:
use_pytorch_sdpa_backends = list(
map(
lambda backend_name: getattr(torch.nn.attention.SDPBackend, backend_name),
use_pytorch_sdpa_backends,
)
)
self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends
self.cache_drop_size = None
self.use_bias = use_bias
self.dropout_rate = dropout_rate
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.s_d_k = math.sqrt(self.d_k)
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat, bias=use_bias)
self.linear_k = nn.Linear(n_feat, n_feat, bias=use_bias)
self.linear_v = nn.Linear(n_feat, n_feat, bias=use_bias)
self.linear_out = nn.Linear(n_feat, n_feat, bias=use_bias)
self.dropout = nn.Dropout(p=dropout_rate)
self._max_cache_len = max_cache_len
def forward_qkv(self, query, key, value):
"""Transforms query, key and value.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value (torch.Tensor): (batch, time2, size)
returns:
q (torch.Tensor): (batch, head, time1, size)
k (torch.Tensor): (batch, head, time2, size)
v (torch.Tensor): (batch, head, time2, size)
"""
n_batch = query.size(0)
t1 = query.size(1)
t2 = key.size(1)
q = self.linear_q(query).view(n_batch, t1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, t2, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, t2, self.h, self.d_k)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
return q, k, v
def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
Args:
value (torch.Tensor): (batch, head, time2, size)
scores (torch.Tensor): (batch, head, time1, time2)
mask (torch.Tensor): (batch, time1, time2)
returns:
value (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the attention scores
"""
n_batch = value.size(0)
time = scores.size(2)
if mask is not None:
mask = mask.unsqueeze(1) # (batch, 1, time1, time2)
scores = scores.masked_fill(mask, -INF_VAL)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = x.transpose(1, 2).reshape(n_batch, time, self.h * self.d_k) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query, key, value, mask, pos_emb=None, cache=None):
"""Compute 'Scaled Dot Product Attention'.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
cache (torch.Tensor) : (batch, time_cache, size)
returns:
output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
cache (torch.Tensor) : (batch, time_cache_next, size)
"""
key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
if torch.is_autocast_enabled():
query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
# temporary until we solve this more gracefully
with avoid_float16_autocast_context():
q, k, v = self.forward_qkv(query, key, value)
if self.use_pytorch_sdpa:
n_batch = value.size(0)
if mask is not None:
mask = ~mask.unsqueeze(1)
dropout_rate = self.dropout_rate if self.training else 0
if self.use_pytorch_sdpa_backends:
with torch.nn.attention.sdpa_kernel(self.use_pytorch_sdpa_backends):
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=dropout_rate
)
else:
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=dropout_rate
)
# this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
if mask is not None:
all_masked_rows = torch.all(~mask, dim=-1)
all_masked_rows.unsqueeze_(-1)
out = out.masked_fill(all_masked_rows, 0.0)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
out = self.forward_attention(v, scores, mask)
if cache is None:
return out
else:
return out, cache
def update_cache(self, key, value, query, cache):
if cache is not None:
key = value = torch.cat([cache, key], dim=1)
q_keep_size = query.shape[1] - self.cache_drop_size
cache = torch.cat([cache[:, q_keep_size:, :], query[:, :q_keep_size, :]], dim=1)
return key, value, query, cache
class RelPositionMultiHeadAttention(MultiHeadAttention):
"""Multi-Head Attention layer of Transformer-XL with support of relative positional encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): number of heads
n_feat (int): size of the features
dropout_rate (float): dropout rate
use_bias (bool): whether to apply bias in linear and conv layers of MultiHeadAttention
"""
def __init__(
self,
n_head,
n_feat,
dropout_rate,
pos_bias_u,
pos_bias_v,
max_cache_len=0,
use_bias=True,
use_pytorch_sdpa=False,
use_pytorch_sdpa_backends=None,
):
"""Construct a RelPositionMultiHeadAttention object."""
super().__init__(
n_head=n_head,
n_feat=n_feat,
dropout_rate=dropout_rate,
max_cache_len=max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=use_pytorch_sdpa,
use_pytorch_sdpa_backends=use_pytorch_sdpa_backends,
)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
if pos_bias_u is None or pos_bias_v is None:
self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
# nn.init.normal_(self.pos_bias_u, 0.0, 0.02)
# nn.init.normal_(self.pos_bias_v, 0.0, 0.02)
nn.init.zeros_(self.pos_bias_u)
nn.init.zeros_(self.pos_bias_v)
else:
self.pos_bias_u = pos_bias_u
self.pos_bias_v = pos_bias_v
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x (torch.Tensor): (batch, nheads, time, 2*time-1)
"""
b, h, qlen, pos_len = x.size() # (b, h, t1, t2)
# need to add a column of zeros on the left side of last dimension to perform the relative shifting
x = torch.nn.functional.pad(x, pad=(1, 0)) # (b, h, t1, t2+1)
x = x.view(b, h, pos_len+1, qlen) # (b, h, t2+1, t1)
# need to drop the first row
x = x[:, :, 1:].view(b, h, qlen, pos_len) # (b, h, t1, t2)
return x
def forward(self, query, key, value, mask, pos_emb, cache=None):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
pos_emb (torch.Tensor) : (batch, time1, size)
cache (torch.Tensor) : (batch, time_cache, size)
Returns:
output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
cache (torch.Tensor) : (batch, time_cache_next, size)
"""
key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
if torch.is_autocast_enabled():
query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
# temporary until we solve this more gracefully
with avoid_float16_autocast_context():
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
n_batch = value.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
if self.use_pytorch_sdpa:
scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1))
matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor
if mask is not None:
mask = mask.unsqueeze(1)
matrix_bd.masked_fill_(mask, -INF_VAL)
dropout_rate = self.dropout_rate if self.training else 0
if self.use_pytorch_sdpa_backends:
with torch.nn.attention.sdpa_kernel(self.use_pytorch_sdpa_backends):
out = torch.nn.functional.scaled_dot_product_attention(
q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate
)
else:
out = torch.nn.functional.scaled_dot_product_attention(
q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate
)
# this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
if mask is not None:
all_masked_rows = torch.all(mask, dim=-1)
all_masked_rows.unsqueeze_(-1)
all_masked_rows = all_masked_rows.expand(-1, out.size(1), -1, out.size(-1))
out = out.masked_fill(all_masked_rows, 0.0)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
# drops extra elements in the matrix_bd to match the matrix_ac's size
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
scores = (matrix_ac + matrix_bd) / self.s_d_k # (batch, head, time1, time2)
out = self.forward_attention(v, scores, mask)
if cache is None:
return out
else:
return out, cache
```
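The `rel_shift` pad/reshape/slice trick above is easy to misread. A pure-Python sketch of the same index gymnastics (illustrative only; it operates on one `(qlen, 2*qlen-1)` score matrix instead of the 4-D `(b, h, qlen, pos_len)` tensor):

```python
def rel_shift(x):
    """Mirror of RelPositionMultiHeadAttention.rel_shift for nested
    lists: left-pad each row with one zero (F.pad(x, (1, 0))),
    reinterpret the flat buffer as (pos_len + 1, qlen), drop the
    first row, and reinterpret back as (qlen, pos_len)."""
    qlen, pos_len = len(x), len(x[0])  # pos_len == 2 * qlen - 1
    flat = []
    for row in x:
        flat.append(0)   # the left zero-pad
        flat.extend(row)
    flat = flat[qlen:]   # drop the first row of the (pos_len + 1, qlen) view
    return [flat[i * pos_len:(i + 1) * pos_len] for i in range(qlen)]

# If column d of the input holds the score for relative position
# qlen - 1 - d (how the pe buffer is laid out), then after the shift
# entry (i, j) holds the score for relative position i - j, i.e.
# query i attending to key j.
```

With `qlen = 3` and every row equal to `[2, 1, 0, -1, -2]` (each cell storing its own relative position), the first three columns of the output come out as `[[0, -1, -2], [1, 0, -1], [2, 1, 0]]`, exactly `i - j`; the trailing columns are the leftover entries that `matrix_bd[:, :, :, : matrix_ac.size(-1)]` slices away.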
## /src/liquid_audio/model/conformer/modules.py
```py path="/src/liquid_audio/model/conformer/modules.py"
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""NeMo conformer modules and layers
adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/collections/asr/parts/submodules/conformer_modules.py
"""
import torch
from torch import nn
from torch.nn import functional as F
from .mha import MultiHeadAttention, RelPositionMultiHeadAttention
class ConformerLayer(torch.nn.Module):
"""A single block of the Conformer encoder.
Args:
d_model (int): input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward
d_ff (int): hidden dimension of PositionwiseFeedForward
self_attention_model (str): type of the attention layer and positional encoding
'rel_pos': relative positional embedding and Transformer-XL
'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using
overlapping chunks. Attention context is determined by att_context_size parameter.
'abs_pos': absolute positional embedding and Transformer
Default is rel_pos.
global_tokens (int): number of tokens to be used for global attention.
Only relevant if self_attention_model is 'rel_pos_local_attn'.
Defaults to 0.
global_tokens_spacing (int): how far apart the global tokens are
Defaults to 1.
global_attn_separate (bool): whether the q, k, v layers used for global tokens should be separate.
Defaults to False.
n_heads (int): number of heads for multi-head attention
conv_kernel_size (int): kernel size for depthwise convolution in convolution module
dropout (float): dropout probabilities for linear layers
dropout_att (float): dropout probabilities for attention distributions
use_bias (bool): Apply bias to all Linear and Conv1d layers from each ConformerLayer to improve activation flow and stabilize training of huge models.
Defaults to True.
"""
def __init__(
self,
d_model,
d_ff,
self_attention_model='rel_pos',
global_tokens=0,
global_tokens_spacing=1,
global_attn_separate=False,
n_heads=4,
conv_kernel_size=31,
conv_norm_type='batch_norm',
conv_context_size=None,
dropout=0.1,
dropout_att=0.1,
pos_bias_u=None,
pos_bias_v=None,
att_context_size=[-1, -1],
use_bias=True,
use_pytorch_sdpa=False,
use_pytorch_sdpa_backends=None,
):
super().__init__()
self.use_pytorch_sdpa = use_pytorch_sdpa
if use_pytorch_sdpa_backends is None:
use_pytorch_sdpa_backends = []
self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends
self.self_attention_model = self_attention_model
self.n_heads = n_heads
self.fc_factor = 0.5
# first feed forward module
self.norm_feed_forward1 = nn.LayerNorm(d_model)
self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout, use_bias=use_bias)
# convolution module
self.norm_conv = nn.LayerNorm(d_model)
self.conv = ConformerConvolution(
d_model=d_model,
kernel_size=conv_kernel_size,
norm_type=conv_norm_type,
conv_context_size=conv_context_size,
use_bias=use_bias,
)
# multi-headed self-attention module
self.norm_self_att = nn.LayerNorm(d_model)
MHA_max_cache_len = att_context_size[0]
if self_attention_model == 'rel_pos':
self.self_attn = RelPositionMultiHeadAttention(
n_head=n_heads,
n_feat=d_model,
dropout_rate=dropout_att,
pos_bias_u=pos_bias_u,
pos_bias_v=pos_bias_v,
max_cache_len=MHA_max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
elif self_attention_model == 'rel_pos_local_attn':
self.self_attn = RelPositionMultiHeadAttentionLongformer(
n_head=n_heads,
n_feat=d_model,
dropout_rate=dropout_att,
pos_bias_u=pos_bias_u,
pos_bias_v=pos_bias_v,
max_cache_len=MHA_max_cache_len,
att_context_size=att_context_size,
global_tokens=global_tokens,
global_tokens_spacing=global_tokens_spacing,
global_attn_separate=global_attn_separate,
use_bias=use_bias,
)
elif self_attention_model == 'abs_pos':
self.self_attn = MultiHeadAttention(
n_head=n_heads,
n_feat=d_model,
dropout_rate=dropout_att,
max_cache_len=MHA_max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
else:
raise ValueError(
f"'{self_attention_model}' is not a valid value for 'self_attention_model', "
f"valid values are ['rel_pos', 'rel_pos_local_attn', 'abs_pos']"
)
# second feed forward module
self.norm_feed_forward2 = nn.LayerNorm(d_model)
self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout, use_bias=use_bias)
self.dropout = nn.Dropout(dropout)
self.norm_out = nn.LayerNorm(d_model)
def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None):
"""
Args:
x (torch.Tensor): input signals (B, T, d_model)
att_mask (torch.Tensor): attention masks(B, T, T)
pos_emb (torch.Tensor): (L, 1, d_model)
pad_mask (torch.tensor): padding mask
cache_last_channel (torch.tensor) : cache for MHA layers (B, T_cache, d_model)
cache_last_time (torch.tensor) : cache for convolutional layers (B, d_model, T_cache)
Returns:
x (torch.Tensor): (B, T, d_model)
cache_last_channel (torch.tensor) : next cache for MHA layers (B, T_cache, d_model)
cache_last_time (torch.tensor) : next cache for convolutional layers (B, d_model, T_cache)
"""
residual = x
x = self.norm_feed_forward1(x)
x = self.feed_forward1(x)
residual = residual + self.dropout(x) * self.fc_factor
x = self.norm_self_att(residual)
if self.self_attention_model == 'rel_pos':
x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel)
elif self.self_attention_model == 'rel_pos_local_attn':
x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel)
elif self.self_attention_model == 'abs_pos':
x = self.self_attn(query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel)
else:
x = None
if x is not None and cache_last_channel is not None:
(x, cache_last_channel) = x
residual = residual + self.dropout(x)
# if self.is_adapter_available():
# # Call the MHA adapters
# pack_input = {
# 'x': residual,
# 'loc': 'mha',
# 'att_mask': att_mask,
# 'pos_emb': pos_emb,
# }
# pack_input = self.forward_enabled_adapters(pack_input)
# residual = pack_input['x']
x = self.norm_conv(residual)
x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time)
if cache_last_time is not None:
(x, cache_last_time) = x
residual = residual + self.dropout(x)
x = self.norm_feed_forward2(residual)
x = self.feed_forward2(x)
residual = residual + self.dropout(x) * self.fc_factor
x = self.norm_out(residual)
# if self.is_adapter_available():
# # Call the adapters
# pack_input = {
# 'x': x,
# 'loc': 'post',
# }
# pack_input = self.forward_enabled_adapters(pack_input)
# x = pack_input['x']
# if self.is_access_enabled(getattr(self, "model_guid", None)) and self.access_cfg.get(
# 'save_encoder_tensors', False
# ):
# self.register_accessible_tensor(name='encoder', tensor=x)
if cache_last_channel is None:
return x
else:
return x, cache_last_channel, cache_last_time
class ConformerConvolution(nn.Module):
"""The convolution module for the Conformer model.
Args:
d_model (int): hidden dimension
kernel_size (int): kernel size for depthwise convolution
pointwise_activation (str): name of the activation function to be used for the pointwise conv.
Note that Conformer uses a special key `glu_` which is treated as the original default from
the paper.
use_bias (bool): Use bias in all Linear and Conv1d layers to improve activation flow and stabilize training of huge models.
Defaults to True.
"""
def __init__(
self,
d_model,
kernel_size,
norm_type='batch_norm',
conv_context_size=None,
pointwise_activation='glu_',
use_bias=True,
):
super(ConformerConvolution, self).__init__()
assert (kernel_size - 1) % 2 == 0
self.d_model = d_model
self.kernel_size = kernel_size
self.norm_type = norm_type
self.use_bias = use_bias
if conv_context_size is None:
conv_context_size = (kernel_size - 1) // 2
# if pointwise_activation in activation_registry:
# self.pointwise_activation = activation_registry[pointwise_activation]()
# dw_conv_input_dim = d_model * 2
# if hasattr(self.pointwise_activation, 'inplace'):
# self.pointwise_activation.inplace = True
# else:
assert pointwise_activation == 'glu_'
self.pointwise_activation = pointwise_activation
dw_conv_input_dim = d_model
self.pointwise_conv1 = nn.Conv1d(
in_channels=d_model,
out_channels=d_model * 2,
kernel_size=1,
stride=1,
padding=0,
bias=self.use_bias,
)
self.depthwise_conv = CausalConv1D(
in_channels=dw_conv_input_dim,
out_channels=dw_conv_input_dim,
kernel_size=kernel_size,
stride=1,
padding=conv_context_size,
groups=dw_conv_input_dim,
bias=self.use_bias,
)
if norm_type == 'batch_norm':
self.batch_norm = nn.BatchNorm1d(dw_conv_input_dim)
elif norm_type == 'instance_norm':
self.batch_norm = nn.InstanceNorm1d(dw_conv_input_dim)
elif norm_type == 'layer_norm':
self.batch_norm = nn.LayerNorm(dw_conv_input_dim)
elif norm_type == 'fused_batch_norm':
self.batch_norm = FusedBatchNorm1d(dw_conv_input_dim)
elif norm_type.startswith('group_norm'):
num_groups = int(norm_type.replace("group_norm", ""))
self.batch_norm = nn.GroupNorm(num_groups=num_groups, num_channels=d_model)
else:
raise ValueError(f"conv_norm_type={norm_type} is not valid!")
self.activation = nn.SiLU()
self.pointwise_conv2 = nn.Conv1d(
in_channels=dw_conv_input_dim,
out_channels=d_model,
kernel_size=1,
stride=1,
padding=0,
bias=self.use_bias,
)
def forward(self, x, pad_mask=None, cache=None):
x = x.transpose(1, 2)
x = self.pointwise_conv1(x)
# Compute the activation function or use GLU for original Conformer
if self.pointwise_activation == 'glu_':
x = nn.functional.glu(x, dim=1)
else:
x = self.pointwise_activation(x)
if pad_mask is not None:
x = x.masked_fill(pad_mask.unsqueeze(1), 0.0)
x = self.depthwise_conv(x, cache=cache)
if cache is not None:
x, cache = x
if self.norm_type == "layer_norm":
x = x.transpose(1, 2)
x = self.batch_norm(x)
x = x.transpose(1, 2)
else:
x = self.batch_norm(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
x = x.transpose(1, 2)
if cache is None:
return x
else:
return x, cache
def reset_parameters_conv(self):
pw1_max = pw2_max = self.d_model**-0.5
dw_max = self.kernel_size**-0.5
with torch.no_grad():
nn.init.uniform_(self.pointwise_conv1.weight, -pw1_max, pw1_max)
nn.init.uniform_(self.pointwise_conv2.weight, -pw2_max, pw2_max)
nn.init.uniform_(self.depthwise_conv.weight, -dw_max, dw_max)
if self.use_bias:
nn.init.uniform_(self.pointwise_conv1.bias, -pw1_max, pw1_max)
nn.init.uniform_(self.pointwise_conv2.bias, -pw2_max, pw2_max)
nn.init.uniform_(self.depthwise_conv.bias, -dw_max, dw_max)
class ConformerFeedForward(nn.Module):
"""
feed-forward module of Conformer model.
use_bias (bool): Apply bias to all Linear and Conv1d layers to improve activation flow and stabilize training of huge models.
"""
def __init__(self, d_model, d_ff, dropout, activation=nn.SiLU(), use_bias=True):
super().__init__()
self.d_model = d_model
self.d_ff = d_ff
self.use_bias = use_bias
self.linear1 = nn.Linear(d_model, d_ff, bias=self.use_bias)
self.activation = activation
self.dropout = nn.Dropout(p=dropout)
self.linear2 = nn.Linear(d_ff, d_model, bias=self.use_bias)
def forward(self, x):
x = self.linear1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.linear2(x)
return x
def reset_parameters_ff(self):
ffn1_max = self.d_model**-0.5
ffn2_max = self.d_ff**-0.5
with torch.no_grad():
nn.init.uniform_(self.linear1.weight, -ffn1_max, ffn1_max)
nn.init.uniform_(self.linear2.weight, -ffn2_max, ffn2_max)
if self.use_bias:
nn.init.uniform_(self.linear1.bias, -ffn1_max, ffn1_max)
nn.init.uniform_(self.linear2.bias, -ffn2_max, ffn2_max)
class CausalConv1D(nn.Conv1d):
"""
A causal version of nn.Conv1d where each step would have limited access to locations on its right or left
All arguments are the same as nn.Conv1d except padding.
If padding is set to None, paddings are set automatically to make it a causal convolution where each location does not see any steps on its right.
If padding is set as a list (of size 2), then padding[0] is used as left padding and padding[1] as right padding.
This makes it possible to control the number of steps accessible on the right and left.
This mode is not supported when stride > 1. padding[0] + padding[1] should be equal to (kernel_size - 1).
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int | list[int] | None = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None,
) -> None:
self.cache_drop_size = None
if padding is None:
self._left_padding = kernel_size - 1
self._right_padding = stride - 1
else:
if stride != 1 and padding != kernel_size - 1:
raise ValueError("No striding allowed for non-symmetric convolutions!")
if isinstance(padding, int):
self._left_padding = padding
self._right_padding = padding
elif isinstance(padding, list) and len(padding) == 2 and padding[0] + padding[1] == kernel_size - 1:
self._left_padding = padding[0]
self._right_padding = padding[1]
else:
raise ValueError(f"Invalid padding param: {padding}!")
self._max_cache_len = self._left_padding
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def update_cache(self, x, cache=None):
if cache is None:
new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
next_cache = cache
else:
new_x = F.pad(x, pad=(0, self._right_padding))
new_x = torch.cat([cache, new_x], dim=-1)
if self.cache_drop_size > 0:
next_cache = new_x[:, :, : -self.cache_drop_size]
else:
next_cache = new_x
next_cache = next_cache[:, :, -cache.size(-1) :]
return new_x, next_cache
def forward(self, x, cache=None):
x, cache = self.update_cache(x, cache=cache)
x = super().forward(x)
if cache is None:
return x
else:
return x, cache
```
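`CausalConv1D` above achieves causality purely through asymmetric padding: with `padding=None` it pads `kernel_size - 1` zeros on the left and nothing on the right, so output step `t` only sees inputs at `t` and earlier. A minimal single-channel sketch (illustrative only; `causal_conv1d` is a hypothetical helper, not repo API):

```python
def causal_conv1d(signal, kernel):
    """Cross-correlation with left-only zero padding, as in
    CausalConv1D(padding=None): output[t] depends on
    signal[t - k + 1 .. t] only, never on future samples."""
    k = len(kernel)
    padded = [0.0] * (k - 1) + list(signal)  # left pad, nothing on the right
    return [sum(kernel[j] * padded[t + j] for j in range(k))
            for t in range(len(signal))]
```

Because no future samples ever enter the window, a streaming implementation like `update_cache` only needs to carry the last `kernel_size - 1` input frames between chunks to reproduce the full-sequence output.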
## /src/liquid_audio/model/conformer/processor.py
```py path="/src/liquid_audio/model/conformer/processor.py"
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NeMo audio processing
Adapted from https://github.com/NVIDIA/NeMo/blob/09a962536c7a52f1964224e6e687ffb4a34fef79/nemo/collections/asr/modules/audio_preprocessing.py#L61
"""
import logging
from abc import ABC, abstractmethod
import random
import librosa
import torch
from torch import nn
class AudioPreprocessor(nn.Module, ABC):
"""
An interface for Neural Modules that performs audio pre-processing,
transforming the wav files to features.
"""
def __init__(self, win_length, hop_length):
super().__init__()
self.win_length = win_length
self.hop_length = hop_length
self.torch_windows = {
'hann': torch.hann_window,
'hamming': torch.hamming_window,
'blackman': torch.blackman_window,
'bartlett': torch.bartlett_window,
'ones': torch.ones,
None: torch.ones,
}
# Normally, when you call to(dtype) on a torch.nn.Module, all
# floating point parameters and buffers will change to that
# dtype, rather than being float32. The AudioPreprocessor
# classes, uniquely, don't actually have any parameters or
# buffers from what I see. In addition, we want the input to
# the preprocessor to be float32, but need to create the
# output in appropriate precision. We have this empty tensor
# here just to detect which dtype tensor this module should
# output at the end of execution.
self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)
@torch.no_grad()
def forward(self, input_signal, length):
if input_signal.dtype != torch.float32:
logging.warning(
f"AudioPreprocessor received an input signal of dtype {input_signal.dtype}, rather than torch.float32. In sweeps across multiple datasets, we have found that the preprocessor is not robust to low precision mathematics. As such, it runs in float32. Your input will be cast to float32, but this is not necessarily enough to recovery full accuracy. For example, simply casting input_signal from torch.float32 to torch.bfloat16, then back to torch.float32 before running AudioPreprocessor causes drops in absolute WER of up to 0.1%. torch.bfloat16 simply does not have enough mantissa bits to represent enough values in the range [-1.0,+1.0] correctly.",
)
processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
return processed_signal, processed_length
@abstractmethod
def get_features(self, input_signal, length):
# Called by forward(). Subclasses should implement this.
pass
class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
"""Featurizer module that converts wavs to mel spectrograms.
Args:
sample_rate (int): Sample rate of the input audio data.
Defaults to 16000
window_size (float): Size of window for fft in seconds
Defaults to 0.02
window_stride (float): Stride of window for fft in seconds
Defaults to 0.01
n_window_size (int): Size of window for fft in samples
Defaults to None. Use one of window_size or n_window_size.
n_window_stride (int): Stride of window for fft in samples
Defaults to None. Use one of window_stride or n_window_stride.
window (str): Windowing function for fft. can be one of ['hann',
'hamming', 'blackman', 'bartlett']
Defaults to "hann"
normalize (str): Can be one of ['per_feature', 'all_features']; all
other options disable feature normalization. 'all_features'
normalizes the entire spectrogram to be mean 0 with std 1.
            'per_feature' normalizes per channel / freq instead.
Defaults to "per_feature"
n_fft (int): Length of FT window. If None, it uses the smallest power
of 2 that is larger than n_window_size.
Defaults to None
preemph (float): Amount of pre emphasis to add to audio. Can be
disabled by passing None.
Defaults to 0.97
features (int): Number of mel spectrogram freq bins to output.
Defaults to 64
lowfreq (int): Lower bound on mel basis in Hz.
Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
Defaults to None
log (bool): Log features.
Defaults to True
log_zero_guard_type(str): Need to avoid taking the log of zero. There
are two options: "add" or "clamp".
Defaults to "add".
log_zero_guard_value(float, or str): Add or clamp requires the number
to add with or clamp to. log_zero_guard_value can either be a float
or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is
passed.
Defaults to 2**-24.
dither (float): Amount of white-noise dithering.
Defaults to 1e-5
pad_to (int): Ensures that the output size of the time dimension is
a multiple of pad_to.
Defaults to 16
frame_splicing (int): Defaults to 1
exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
// hop_length. Defaults to False.
pad_value (float): The value that shorter mels are padded with.
Defaults to 0
mag_power (float): The power that the linear spectrogram is raised to
prior to multiplication with mel basis.
Defaults to 2 for a power spec
rng : Random number generator
nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to
samples in the batch.
Defaults to 0.0
nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
Defaults to 4000
use_torchaudio: Whether to use the `torchaudio` implementation.
mel_norm: Normalization used for mel filterbank weights.
Defaults to 'slaney' (area normalization)
stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
stft_conv: Deprecated argument, kept for compatibility with older checkpoints.
"""
def save_to(self, save_path: str):
pass
@classmethod
def restore_from(cls, restore_path: str):
pass
@property
    def input_types(self):
        """Returns definitions of module input ports."""
        # NeuralType, AudioSignal, LengthsType, and MelSpectrogramType come from
        # nemo.core.neural_types and are not imported in this adaptation; the
        # input/output port definitions are kept for NeMo compatibility only.
        return {
"input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),
"length": NeuralType(
tuple('B'), LengthsType()
), # Please note that length should be in samples not seconds.
}
@property
def output_types(self):
"""Returns definitions of module output ports.
processed_signal:
0: AxisType(BatchTag)
1: AxisType(MelSpectrogramSignalTag)
2: AxisType(ProcessedTimeTag)
processed_length:
0: AxisType(BatchTag)
"""
return {
"processed_signal": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"processed_length": NeuralType(tuple('B'), LengthsType()),
}
def __init__(
self,
sample_rate=16000,
window_size=0.02,
window_stride=0.01,
n_window_size=None,
n_window_stride=None,
window="hann",
normalize="per_feature",
n_fft=None,
preemph=0.97,
features=64,
lowfreq=0,
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2**-24,
dither=1e-5,
pad_to=16,
frame_splicing=1,
exact_pad=False,
pad_value=0,
mag_power=2.0,
rng=None,
nb_augmentation_prob=0.0,
nb_max_freq=4000,
use_torchaudio: bool = False,
mel_norm="slaney",
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
stft_conv=False, # Deprecated arguments; kept for config compatibility
):
self._sample_rate = sample_rate
if window_size and n_window_size:
raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
if window_stride and n_window_stride:
raise ValueError(
f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
)
if window_size:
n_window_size = int(window_size * self._sample_rate)
if window_stride:
n_window_stride = int(window_stride * self._sample_rate)
super().__init__(n_window_size, n_window_stride)
# Given the long and similar argument list, point to the class and instantiate it by reference
if not use_torchaudio:
featurizer_class = FilterbankFeatures
        else:
            # FilterbankFeaturesTA (the torchaudio implementation) is not included in
            # this adaptation, so use_torchaudio=True will raise a NameError here.
            featurizer_class = FilterbankFeaturesTA
self.featurizer = featurizer_class(
sample_rate=self._sample_rate,
n_window_size=n_window_size,
n_window_stride=n_window_stride,
window=window,
normalize=normalize,
n_fft=n_fft,
preemph=preemph,
nfilt=features,
lowfreq=lowfreq,
highfreq=highfreq,
log=log,
log_zero_guard_type=log_zero_guard_type,
log_zero_guard_value=log_zero_guard_value,
dither=dither,
pad_to=pad_to,
frame_splicing=frame_splicing,
exact_pad=exact_pad,
pad_value=pad_value,
mag_power=mag_power,
rng=rng,
nb_augmentation_prob=nb_augmentation_prob,
nb_max_freq=nb_max_freq,
mel_norm=mel_norm,
stft_exact_pad=stft_exact_pad, # Deprecated arguments; kept for config compatibility
stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility
)
def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200):
dev = self.filter_banks.device
signals = torch.randn(size=[max_batch, max_dim], device=dev)
lengths = torch.randint(low=min_length, high=max_dim, size=[max_batch], device=dev)
lengths[0] = max_dim
return signals, lengths
def get_features(self, input_signal, length):
return self.featurizer(input_signal, length)
@property
def filter_banks(self):
return self.featurizer.filter_banks
CONSTANT = 1e-5
class FilterbankFeatures(nn.Module):
"""Featurizer that converts wavs to Mel Spectrograms.
See AudioToMelSpectrogramPreprocessor for args.
"""
def __init__(
self,
sample_rate=16000,
n_window_size=320,
n_window_stride=160,
window="hann",
normalize="per_feature",
n_fft=None,
preemph=0.97,
nfilt=64,
lowfreq=0,
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2**-24,
dither=CONSTANT,
pad_to=16,
max_duration=16.7,
frame_splicing=1,
exact_pad=False,
pad_value=0,
mag_power=2.0,
use_grads=False,
rng=None,
nb_augmentation_prob=0.0,
nb_max_freq=4000,
mel_norm="slaney",
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
stft_conv=False, # Deprecated arguments; kept for config compatibility
):
super().__init__()
if stft_conv or stft_exact_pad:
logging.warning(
"Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
"for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
"as needed."
)
if exact_pad and n_window_stride % 2 == 1:
raise NotImplementedError(
f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
"returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
)
self.log_zero_guard_value = log_zero_guard_value
if (
n_window_size is None
or n_window_stride is None
or not isinstance(n_window_size, int)
or not isinstance(n_window_stride, int)
or n_window_size <= 0
or n_window_stride <= 0
):
raise ValueError(
f"{self} got an invalid value for either n_window_size or "
f"n_window_stride. Both must be positive ints."
)
logging.info(f"PADDING: {pad_to}")
self.sample_rate = sample_rate
self.win_length = n_window_size
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
self.exact_pad = exact_pad
self.sample_rate = sample_rate
if exact_pad:
logging.info("STFT using exact pad")
torch_windows = {
'hann': torch.hann_window,
'hamming': torch.hamming_window,
'blackman': torch.blackman_window,
'bartlett': torch.bartlett_window,
'none': None,
}
window_fn = torch_windows.get(window, None)
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
self.register_buffer("window", window_tensor)
self.normalize = normalize
self.log = log
self.dither = dither
self.frame_splicing = frame_splicing
self.nfilt = nfilt
self.preemph = preemph
self.pad_to = pad_to
highfreq = highfreq or sample_rate / 2
filterbanks = torch.tensor(
librosa.filters.mel(
sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm
),
dtype=torch.float,
).unsqueeze(0)
self.register_buffer("fb", filterbanks)
# Calculate maximum sequence length
max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
self.max_length = max_length + max_pad
self.pad_value = pad_value
self.mag_power = mag_power
# We want to avoid taking the log of zero
# There are two options: either adding or clamping to a small value
if log_zero_guard_type not in ["add", "clamp"]:
raise ValueError(
f"{self} received {log_zero_guard_type} for the "
f"log_zero_guard_type parameter. It must be either 'add' or "
f"'clamp'."
)
self.use_grads = use_grads
if not use_grads:
self.forward = torch.no_grad()(self.forward)
self._rng = random.Random() if rng is None else rng
self.nb_augmentation_prob = nb_augmentation_prob
if self.nb_augmentation_prob > 0.0:
if nb_max_freq >= sample_rate / 2:
self.nb_augmentation_prob = 0.0
else:
self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
        # log_zero_guard_value is the small value we want to use; we support
        # an actual number, or "tiny", or "eps"
self.log_zero_guard_type = log_zero_guard_type
logging.debug(f"sr: {sample_rate}")
logging.debug(f"n_fft: {self.n_fft}")
logging.debug(f"win_length: {self.win_length}")
logging.debug(f"hop_length: {self.hop_length}")
logging.debug(f"n_mels: {nfilt}")
logging.debug(f"fmin: {lowfreq}")
logging.debug(f"fmax: {highfreq}")
logging.debug(f"using grads: {use_grads}")
logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")
def stft(self, x):
return torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=False if self.exact_pad else True,
window=self.window.to(dtype=torch.float, device=x.device),
return_complex=True,
pad_mode="constant",
)
def log_zero_guard_value_fn(self, x):
if isinstance(self.log_zero_guard_value, str):
if self.log_zero_guard_value == "tiny":
return torch.finfo(x.dtype).tiny
elif self.log_zero_guard_value == "eps":
return torch.finfo(x.dtype).eps
else:
raise ValueError(
f"{self} received {self.log_zero_guard_value} for the "
f"log_zero_guard_type parameter. It must be either a "
f"number, 'tiny', or 'eps'"
)
else:
return self.log_zero_guard_value
def get_seq_len(self, seq_len):
        # When stft_pad_amount is None, center=True and torch.stft pads n_fft // 2 samples on each side
pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length)
return seq_len.to(dtype=torch.long)
@property
def filter_banks(self):
return self.fb
def forward(self, x, seq_len, linear_spec=False):
seq_len_time = seq_len
seq_len_unfixed = self.get_seq_len(seq_len)
# fix for seq_len = 0 for streaming; if size was 0, it is always padded to 1, and normalizer fails
seq_len = torch.where(seq_len == 0, torch.zeros_like(seq_len_unfixed), seq_len_unfixed)
if self.stft_pad_amount is not None:
x = torch.nn.functional.pad(
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "constant"
).squeeze(1)
# dither (only in training mode for eval determinism)
if self.training and self.dither > 0:
x += self.dither * torch.randn_like(x)
# do preemphasis
if self.preemph is not None:
timemask = torch.arange(x.shape[1], device=x.device).unsqueeze(0) < seq_len_time.unsqueeze(1)
x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
x = x.masked_fill(~timemask, 0.0)
# disable autocast to get full range of stft values
with torch.amp.autocast(x.device.type, enabled=False):
x = self.stft(x)
# torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
# guard is needed for sqrt if grads are passed through
guard = 0 if not self.use_grads else CONSTANT
x = torch.view_as_real(x)
x = torch.sqrt(x.pow(2).sum(-1) + guard)
if self.training and self.nb_augmentation_prob > 0.0:
for idx in range(x.shape[0]):
if self._rng.random() < self.nb_augmentation_prob:
x[idx, self._nb_max_fft_bin :, :] = 0.0
# get power spectrum
if self.mag_power != 1.0:
x = x.pow(self.mag_power)
# return plain spectrogram if required
if linear_spec:
return x, seq_len
# disable autocast, otherwise it might be automatically casted to fp16
# on fp16 compatible GPUs and get NaN values for input value of 65520
with torch.amp.autocast(x.device.type, enabled=False):
# dot with filterbank energies
x = torch.matmul(self.fb.to(x.dtype), x)
# log features if required
if self.log:
if self.log_zero_guard_type == "add":
x = torch.log(x + self.log_zero_guard_value_fn(x))
elif self.log_zero_guard_type == "clamp":
x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
else:
raise ValueError("log_zero_guard_type was not understood")
# frame splicing if required
if self.frame_splicing > 1:
x = splice_frames(x, self.frame_splicing)
# normalize if required
if self.normalize:
x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)
# mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
max_len = x.size(-1)
mask = torch.arange(max_len, device=x.device)
mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
del mask
pad_to = self.pad_to
if pad_to == "max":
x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
elif pad_to > 0:
pad_amt = x.size(-1) % pad_to
if pad_amt != 0:
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
return x, seq_len
def normalize_batch(x, seq_len, normalize_type):
x_mean = None
x_std = None
if normalize_type == "per_feature":
batch_size = x.shape[0]
max_time = x.shape[2]
# When doing stream capture to a graph, item() is not allowed
        # because it calls cudaStreamSynchronize(). Therefore, we are
# sacrificing some error checking when running with cuda graphs.
if (
torch.cuda.is_available()
and not torch.cuda.is_current_stream_capturing()
and torch.any(seq_len == 1).item()
):
raise ValueError(
"normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
"in torch.std() returning nan. Make sure your audio length has enough samples for a single "
"feature (ex. at least `hop_length` for Mel Spectrograms)."
)
time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
valid_mask = time_steps < seq_len.unsqueeze(1)
x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2)
x_mean_denominator = valid_mask.sum(axis=1)
x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1)
# Subtract 1 in the denominator to correct for the bias.
x_std = torch.sqrt(
torch.sum(torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2, axis=2)
/ (x_mean_denominator.unsqueeze(1) - 1.0)
)
x_std = x_std.masked_fill(x_std.isnan(), 0.0) # edge case: only 1 frame in denominator
# make sure x_std is not zero
x_std += CONSTANT
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
elif normalize_type == "all_features":
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
for i in range(x.shape[0]):
x_mean[i] = x[i, :, : seq_len[i].item()].mean()
x_std[i] = x[i, :, : seq_len[i].item()].std()
# make sure x_std is not zero
x_std += CONSTANT
return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
return (
(x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
x_mean,
x_std,
)
else:
return x, x_mean, x_std
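# --- Editor's illustrative sketch (not part of the original NeMo code) ---
# What "per_feature" normalization computes for one (batch, feature) row, in
# plain Python: mean/std over the valid timesteps (Bessel-corrected, matching
# the `- 1.0` denominator above), then (x - mean) / (std + CONSTANT).
def _per_feature_normalize_row(row, valid_len, eps=1e-5):
    valid = row[:valid_len]
    mean = sum(valid) / valid_len
    var = sum((v - mean) ** 2 for v in valid) / (valid_len - 1)
    std = var ** 0.5 + eps
    return [(v - mean) / std for v in row]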
```
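As a quick sanity check on the frame-count arithmetic in `FilterbankFeatures.get_seq_len` above: when `stft_pad_amount` is `None` (the `center=True` case), the padding term is `(n_fft // 2) * 2`, so the frame count reduces to `(num_samples + 2 * (n_fft // 2) - n_fft) // hop_length`. A minimal sketch of that formula (editor's addition; the `n_fft` and `hop_length` values below are illustrative, not taken from this repo's configs):

```python
def stft_num_frames(num_samples: int, n_fft: int, hop_length: int) -> int:
    """Frame count matching FilterbankFeatures.get_seq_len for center=True."""
    pad_amount = (n_fft // 2) * 2  # n_fft // 2 samples padded on each side
    return (num_samples + pad_amount - n_fft) // hop_length

# 1 s of 16 kHz audio with n_fft=512 and a 10 ms hop (160 samples)
frames = stft_num_frames(16000, 512, 160)  # 100 frames
```

For even `n_fft` this simplifies to `num_samples // hop_length`, which is why a 10 ms hop yields exactly 100 frames per second of 16 kHz audio.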
## /src/liquid_audio/model/conformer/subsampling.py
```py path="/src/liquid_audio/model/conformer/subsampling.py"
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NeMo subsampling
adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/collections/asr/parts/submodules/subsampling.py
"""
import logging
import math
import torch
from torch import nn
logger = logging.getLogger(__name__)
class ConvSubsampling(torch.nn.Module):
"""Convolutional subsampling which supports VGGNet and striding approach introduced in:
VGGNet Subsampling: Transformer-transducer: end-to-end speech recognition with self-attention (https://arxiv.org/pdf/1910.12977.pdf)
Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506)
Args:
subsampling (str): The subsampling technique from {"vggnet", "striding", "dw-striding"}
subsampling_factor (int): The subsampling factor which should be a power of 2
subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking)
1 (auto) or a power of 2. Default is 1
feat_in (int): size of the input features
feat_out (int): size of the output features
conv_channels (int): Number of channels for the convolution layers.
activation (Module): activation function, default is nn.ReLU()
"""
def __init__(
self,
subsampling,
subsampling_factor,
feat_in,
feat_out,
conv_channels,
subsampling_conv_chunking_factor=1,
activation=nn.ReLU(),
is_causal=False,
):
super().__init__()
self._subsampling = subsampling
self._conv_channels = conv_channels
self._feat_in = feat_in
self._feat_out = feat_out
if subsampling_factor % 2 != 0:
raise ValueError("Sampling factor should be a multiply of 2!")
self._sampling_num = int(math.log2(subsampling_factor))
self.subsampling_factor = subsampling_factor
self.is_causal = is_causal
if (
subsampling_conv_chunking_factor != -1
and subsampling_conv_chunking_factor != 1
and subsampling_conv_chunking_factor % 2 != 0
):
raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2")
self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
in_channels = 1
layers = []
if subsampling == 'vggnet':
self._stride = 2
self._kernel_size = 2
self._ceil_mode = True
self._left_padding = 0
self._right_padding = 0
for i in range(self._sampling_num):
layers.append(
torch.nn.Conv2d(
in_channels=in_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1
)
)
layers.append(activation)
layers.append(
torch.nn.Conv2d(
in_channels=conv_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1
)
)
layers.append(activation)
layers.append(
torch.nn.MaxPool2d(
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
ceil_mode=self._ceil_mode,
)
)
in_channels = conv_channels
elif subsampling == 'dw_striding':
self._stride = 2
self._kernel_size = 3
self._ceil_mode = False
if self.is_causal:
self._left_padding = self._kernel_size - 1
self._right_padding = self._stride - 1
self._max_cache_len = subsampling_factor + 1
else:
self._left_padding = (self._kernel_size - 1) // 2
self._right_padding = (self._kernel_size - 1) // 2
self._max_cache_len = 0
# Layer 1
if self.is_causal:
layers.append(
CausalConv2D(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=None,
)
)
else:
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
)
)
in_channels = conv_channels
layers.append(activation)
for i in range(self._sampling_num - 1):
if self.is_causal:
layers.append(
CausalConv2D(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=None,
groups=in_channels,
)
)
else:
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
groups=in_channels,
)
)
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=1,
stride=1,
padding=0,
groups=1,
)
)
layers.append(activation)
in_channels = conv_channels
elif subsampling == 'striding':
self._stride = 2
self._kernel_size = 3
self._ceil_mode = False
if self.is_causal:
self._left_padding = self._kernel_size - 1
self._right_padding = self._stride - 1
self._max_cache_len = subsampling_factor + 1
else:
self._left_padding = (self._kernel_size - 1) // 2
self._right_padding = (self._kernel_size - 1) // 2
self._max_cache_len = 0
for i in range(self._sampling_num):
if self.is_causal:
layers.append(
CausalConv2D(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=None,
)
)
else:
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
)
)
layers.append(activation)
in_channels = conv_channels
elif subsampling == 'striding_conv1d':
in_channels = feat_in
self._stride = 2
self._kernel_size = 5
self._ceil_mode = False
if self.is_causal:
self._left_padding = self._kernel_size - 1
self._right_padding = self._stride - 1
self._max_cache_len = subsampling_factor + 1
else:
self._left_padding = (self._kernel_size - 1) // 2
self._right_padding = (self._kernel_size - 1) // 2
self._max_cache_len = 0
for i in range(self._sampling_num):
if self.is_causal:
layers.append(
CausalConv1D(
in_channels=in_channels,
out_channels=feat_out if self._sampling_num == i + 1 else conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=None,
)
)
else:
layers.append(
torch.nn.Conv1d(
in_channels=in_channels,
out_channels=feat_out if self._sampling_num == i + 1 else conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
)
)
layers.append(activation)
in_channels = conv_channels
elif subsampling == 'dw_striding_conv1d':
in_channels = feat_in
self._stride = 2
self._kernel_size = 5
self._ceil_mode = False
self._left_padding = (self._kernel_size - 1) // 2
self._right_padding = (self._kernel_size - 1) // 2
# Layer 1
layers.extend(
[
torch.nn.Conv1d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
groups=in_channels,
),
torch.nn.Conv1d(
in_channels=in_channels,
out_channels=feat_out if self._sampling_num == 1 else conv_channels,
kernel_size=1,
stride=1,
padding=0,
groups=1,
),
]
)
in_channels = conv_channels
layers.append(activation)
for i in range(self._sampling_num - 1):
layers.extend(
[
torch.nn.Conv1d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
groups=in_channels,
),
torch.nn.Conv1d(
in_channels=in_channels,
out_channels=feat_out if self._sampling_num == i + 2 else conv_channels,
kernel_size=1,
stride=1,
padding=0,
groups=1,
),
]
)
layers.append(activation)
in_channels = conv_channels
else:
raise ValueError(f"Not valid sub-sampling: {subsampling}!")
if subsampling in ["vggnet", "dw_striding", "striding"]:
in_length = torch.tensor(feat_in, dtype=torch.float)
out_length = calc_length(
lengths=in_length,
all_paddings=self._left_padding + self._right_padding,
kernel_size=self._kernel_size,
stride=self._stride,
ceil_mode=self._ceil_mode,
repeat_num=self._sampling_num,
)
self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
self.conv2d_subsampling = True
elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
self.out = None
self.conv2d_subsampling = False
else:
raise ValueError(f"Not valid sub-sampling: {subsampling}!")
self.conv = MaskedConvSequential(*layers)
def get_sampling_frames(self):
return [1, self.subsampling_factor]
def get_streaming_cache_size(self):
return [0, self.subsampling_factor + 1]
def forward(self, x, lengths):
out_lengths = calc_length(
lengths,
all_paddings=self._left_padding + self._right_padding,
kernel_size=self._kernel_size,
stride=self._stride,
ceil_mode=self._ceil_mode,
repeat_num=self._sampling_num,
)
# Transpose to Channel First mode
if not self.conv2d_subsampling:
x = x.transpose(1, 2)
# split inputs if chunking_factor is set
if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling:
if self.subsampling_conv_chunking_factor == 1:
# if subsampling_conv_chunking_factor is 1, we split only if needed
# avoiding a bug / feature limiting indexing of tensors to 2**31
# see https://github.com/pytorch/pytorch/issues/80020
# Fix NeMo bug: compute exactly the largest output size
x_ceil = 2 ** 31
out_size = x.shape[0] * self._conv_channels * ((x.shape[1] + 1) // self._stride) * ((x.shape[2] + 1) // self._stride)
if out_size >= x_ceil:
need_to_split = True
else:
need_to_split = False
else:
# if subsampling_conv_chunking_factor > 1 we always split
need_to_split = True
if need_to_split:
x, lengths, success = self.conv_split_by_batch(x, lengths)
if not success: # if unable to split by batch, try by channel
if self._subsampling == 'dw_striding':
# TODO: implement lengths inside conv_split_by_channel
x = self.conv_split_by_channel(x)
lengths = out_lengths
else:
x, lengths = self.conv(x, lengths) # try anyway
else:
x, lengths = self.conv(x, lengths)
else:
            x, lengths = self.conv(x, lengths)
# Flatten Channel and Frequency Axes
if self.conv2d_subsampling:
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).reshape(b, t, c*f))
# Transpose to Channel Last mode
else:
x = x.transpose(1, 2)
return x, lengths
def reset_parameters(self):
# initialize weights
if self._subsampling == 'dw_striding':
with torch.no_grad():
# init conv
scale = 1.0 / self._kernel_size
dw_max = (self._kernel_size**2) ** -0.5
pw_max = self._conv_channels**-0.5
torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)
for idx in range(2, len(self.conv), 3):
torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max)
torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max)
torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max)
torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max)
                # init fc (80 * 64 = 5120, from https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487)
fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5
torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)
def conv_split_by_batch(self, x, lengths):
"""Tries to split input by batch, run conv and concat results"""
b, *_ = x.size()
if b == 1: # can't split if batch size is 1
return x, lengths, False
if self.subsampling_conv_chunking_factor > 1:
cf = self.subsampling_conv_chunking_factor
logger.debug(f'using manually set chunking factor: {cf}')
else:
# avoiding a bug / feature limiting indexing of tensors to 2**31
# see https://github.com/pytorch/pytorch/issues/80020
x_ceil = 2 ** 31
out_size = x.shape[0] * self._conv_channels * ((x.shape[1] + 1) // self._stride) * ((x.shape[2] + 1) // self._stride)
p = math.ceil(math.log2((out_size+1) / x_ceil))
cf = 2**p
logger.debug(f'using auto set chunking factor: {cf}')
new_batch_size = b // cf
if new_batch_size == 0: # input is too big
return x, lengths, False
logger.debug(f'conv subsampling: using split batch size {new_batch_size}')
ans = [
self.conv(chunk, ln)
for chunk, ln in zip(
torch.split(x, new_batch_size, 0),
torch.split(lengths, new_batch_size, 0),
)
]
return torch.cat([a[0] for a in ans]), torch.cat([a[1] for a in ans]), True
def conv_split_by_channel(self, x):
"""For dw convs, tries to split input by time, run conv and concat results"""
        # Note: this method doesn't use the convolution masking implemented in MaskedConvSequential
x = x.unsqueeze(0)
x = self.conv[0](x) # full conv2D
x = self.conv[1](x) # activation
for i in range(self._sampling_num - 1):
_, c, t, _ = x.size()
if self.subsampling_conv_chunking_factor > 1:
cf = self.subsampling_conv_chunking_factor
logger.debug(f'using manually set chunking factor: {cf}')
else:
# avoiding a bug / feature limiting indexing of tensors to 2**31
# see https://github.com/pytorch/pytorch/issues/80020
p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
cf = 2**p
logger.debug(f'using auto set chunking factor: {cf}')
new_c = int(c // cf)
if new_c == 0:
logger.warning(f'chunking factor {cf} is too high; splitting down to one channel.')
new_c = 1
new_t = int(t // cf)
if new_t == 0:
logger.warning(f'chunking factor {cf} is too high; splitting down to one timestep.')
new_t = 1
logger.debug(f'conv dw subsampling: using split C size {new_c} and split T size {new_t}')
x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x) # conv2D, depthwise
# splitting pointwise convs by time
x = torch.cat([self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2) # conv2D, pointwise
x = self.conv[i * 3 + 4](x) # activation
return x
def channel_chunked_conv(self, conv, chunk_size, x):
"""Performs channel chunked convolution"""
ind = 0
out_chunks = []
for chunk in torch.split(x, chunk_size, 1):
step = chunk.size()[1]
if self.is_causal:
chunk = nn.functional.pad(
chunk, pad=(self._kernel_size - 1, self._stride - 1, self._kernel_size - 1, self._stride - 1)
)
ch_out = nn.functional.conv2d(
chunk,
conv.weight[ind : ind + step, :, :, :],
bias=conv.bias[ind : ind + step],
stride=self._stride,
padding=0,
groups=step,
)
else:
ch_out = nn.functional.conv2d(
chunk,
conv.weight[ind : ind + step, :, :, :],
bias=conv.bias[ind : ind + step],
stride=self._stride,
padding=self._left_padding,
groups=step,
)
out_chunks.append(ch_out)
ind += step
return torch.cat(out_chunks, 1)
def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
if (
subsampling_conv_chunking_factor != -1
and subsampling_conv_chunking_factor != 1
and subsampling_conv_chunking_factor % 2 != 0
):
raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2")
self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1):
"""Calculates the output length of a Tensor passed through a convolution or max pooling layer"""
add_pad: float = all_paddings - kernel_size
one: float = 1.0
for i in range(repeat_num):
lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
if ceil_mode:
lengths = torch.ceil(lengths)
else:
lengths = torch.floor(lengths)
return lengths.to(dtype=torch.int)
class MaskedConvSequential(nn.Sequential):
def forward(self, x, lengths):
# Convert input (batch, time, features) to conv format
x = x.unsqueeze(1) # (batch, 1, time, features)
current_lengths = lengths.clone().float()
mask = self._create_mask(x, current_lengths.long())
# Process through each layer with mask propagation
for i, layer in enumerate(self):
# Apply current mask before layer
x = apply_channel_mask(x, mask)
# Apply layer
x = layer(x)
# Update lengths for stride operations with proper padding
if hasattr(layer, 'stride') and layer.stride != (1, 1):
if hasattr(layer, "_left_padding"):
padding = (layer._left_padding, layer._right_padding) # CausalConv2D
else:
padding = layer.padding
current_lengths = calculate_conv_output_size(
current_lengths, layer.kernel_size[0], layer.stride[0], padding
)
mask = self._create_mask(x, current_lengths.long())
# Final masking
x = apply_channel_mask(x, mask)
return x, current_lengths.long()
def _create_mask(self, tensor, lengths):
"""Create mask matching tensor dimensions."""
batch_size, channels, time, features = tensor.shape
time_mask = torch.arange(time, device=tensor.device).expand(batch_size, time) < lengths.unsqueeze(1)
return time_mask.unsqueeze(-1).expand(batch_size, time, features).to(tensor.dtype)
def apply_channel_mask(tensor, mask):
"""Apply mask to tensor with channel dimension."""
# tensor: (batch, channels, time, features)
# mask: (batch, time, features)
batch_size, channels, time, features = tensor.shape
expanded_mask = mask.unsqueeze(1).expand(batch_size, channels, time, features)
return tensor * expanded_mask
def calculate_conv_output_size(input_size: torch.Tensor, kernel_size: int, stride: int, padding: tuple[int, int]):
"""Calculate exact output size after convolution."""
return (input_size + padding[0] + padding[1] - kernel_size) // stride + 1
```
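The length bookkeeping in `calc_length` and `calculate_conv_output_size` above both reduce to the standard conv output-size formula. A torch-free sketch (names local to this example) that makes the arithmetic concrete:

```python
import math

def conv_out_len(length: int, kernel_size: int, stride: int,
                 all_paddings: int, ceil_mode: bool = False) -> int:
    """Output length of a conv/pool layer: (L + total_pad - k) / s + 1."""
    out = (length + all_paddings - kernel_size) / stride + 1
    return math.ceil(out) if ceil_mode else math.floor(out)

# A stride-2 3x3 conv with padding 1 on each side roughly halves the time axis.
print(conv_out_len(100, kernel_size=3, stride=2, all_paddings=2))  # 50
print(conv_out_len(100, 3, 2, 2, ceil_mode=True))  # 51
```

`calc_length` applies the same step `repeat_num` times for stacked subsampling layers, which is why a factor-4 subsampler calls it with `repeat_num=2`.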
## /src/liquid_audio/model/conformer/utils.py
```py path="/src/liquid_audio/model/conformer/utils.py"
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NeMo casting utils
adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/utils/cast_utils.py
"""
from contextlib import nullcontext
from dataclasses import dataclass
import torch
def avoid_float16_autocast_context():
"""
If the current autocast context is float16, cast it to bfloat16
if available (unless we're in jit) or float32
"""
if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
if torch.jit.is_scripting() or torch.jit.is_tracing():
return torch.amp.autocast('cuda', dtype=torch.float32)
if torch.cuda.is_bf16_supported():
return torch.amp.autocast('cuda', dtype=torch.bfloat16)
else:
return torch.amp.autocast('cuda', dtype=torch.float32)
else:
return nullcontext()
@dataclass
class CacheAwareStreamingConfig:
chunk_size: int = (
0 # the size of each chunk at each step, it can be a list of two integers to specify different chunk sizes for the first step and others
)
shift_size: int = (
0 # the size of the shift in each step, it can be a list of two integers to specify different shift sizes for the first step and others
)
cache_drop_size: int = 0 # the number of steps to drop from the cache
last_channel_cache_size: int = 0 # the size of the needed cache for last channel layers
valid_out_len: int = (
0 # the number of the steps in the final output which are valid (have the same value as in the offline mode)
)
pre_encode_cache_size: int = (
0 # the size of the needed cache for the pre-encoding part of the model to avoid caching inside the pre-encoding layers
)
drop_extra_pre_encoded: int = 0 # the number of steps to get dropped after the pre-encoding layer
last_channel_num: int = 0 # number of the last channel layers (like MHA layers) which need caching in the model
last_time_num: int = 0 # number of the last time layers (like convolutions) which need caching in the model
def compute_stochastic_depth_drop_probs(
num_layers: int,
stochastic_depth_drop_prob: float = 0.0,
stochastic_depth_mode: str = "linear",
stochastic_depth_start_layer: int = 1,
) -> list[float]:
"""Computes drop probabilities for stochastic depth regularization technique.
The first layer is never dropped and the starting layer needs to be greater
or equal to 1.
Args:
num_layers (int): number of layers in the network.
stochastic_depth_drop_prob (float): if non-zero, will randomly drop
layers during training. The higher this value, the more often layers
are dropped. Defaults to 0.0.
stochastic_depth_mode (str): can be either "linear" or "uniform". If
set to "uniform", all layers have the same probability of drop. If
set to "linear", the drop probability grows linearly from 0 for the
first layer to the desired value for the final layer. Defaults to
"linear".
stochastic_depth_start_layer (int): starting layer for stochastic depth.
All layers before this will never be dropped. Note that drop
probability will be adjusted accordingly if mode is "linear" when
start layer is > 1. Defaults to 1.
Returns:
List[float]: list of drop probabilities for all layers
"""
if not (0 <= stochastic_depth_drop_prob < 1.0):
raise ValueError("stochastic_depth_drop_prob has to be in [0, 1).")
if not (1 <= stochastic_depth_start_layer <= num_layers):
raise ValueError("stochastic_depth_start_layer has to be in [1, num layers].")
# Layers before `stochastic_depth_start_layer` are never dropped
layer_drop_probs = [0.0] * stochastic_depth_start_layer
# Layers starting with `stochastic_depth_start_layer` may be dropped
if (L := num_layers - stochastic_depth_start_layer) > 0:
if stochastic_depth_mode == "linear":
# we start with 1/L * drop_prob and end with the desired drop probability.
layer_drop_probs += [l / L * stochastic_depth_drop_prob for l in range(1, L + 1)]
elif stochastic_depth_mode == "uniform":
layer_drop_probs += [stochastic_depth_drop_prob] * L
else:
raise ValueError(
f'stochastic_depth_mode has to be one of ["linear", "uniform"]. Current value: {stochastic_depth_mode}'
)
return layer_drop_probs
```
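To see what the `"linear"` schedule in `compute_stochastic_depth_drop_probs` produces, here is a pure-Python restatement of just that branch (a standalone sketch; the function name is local to this example):

```python
def linear_drop_probs(num_layers: int, drop_prob: float, start_layer: int = 1):
    """Linear stochastic-depth schedule: 0 before start_layer, then i/L * drop_prob."""
    probs = [0.0] * start_layer
    L = num_layers - start_layer
    if L > 0:
        probs += [i / L * drop_prob for i in range(1, L + 1)]
    return probs

# 6 layers, target 0.5: the first layer is never dropped, and the
# probability ramps linearly up to the target for the final layer.
print(linear_drop_probs(6, 0.5))  # [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
```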
## /src/liquid_audio/model/lfm2_audio.py
```py path="/src/liquid_audio/model/lfm2_audio.py"
from __future__ import annotations
import json
import math
from collections.abc import Generator
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import ClassVar, Literal, Self, TypedDict
import torch
from accelerate import init_on_device, load_checkpoint_in_model
from einops import rearrange
from torch import nn
from transformers import Lfm2Config, Lfm2Model
from transformers.models.lfm2.modeling_lfm2 import Lfm2HybridConvCache
from liquid_audio.model.conformer.encoder import ConformerEncoder, ConformerEncoderConfig
from liquid_audio.model.mlp import MLP
from liquid_audio.model.transformer import MHA, RawLMBackbone, SharedEmbedding, StandardBlock
from liquid_audio.processor import PreprocessorConfig
from liquid_audio.utils import LFMModality, get_model_dir, mel2emb_len, module_exists
class LFM2_HFConfig(TypedDict):
pretrained_model_name_or_path: str
revision: str
@dataclass(kw_only=True)
class LFM2AudioConfig:
architectures: list[str] # for huggingface compatibility
codebooks: int
tie_audio_embeddings: bool
semantic_codebook_factor: float
codebook_weight: Literal["log", "linear"]
interleaved_n_text: int
interleaved_n_audio: int
preprocessor: PreprocessorConfig
encoder: ConformerEncoderConfig
lfm: Lfm2Config
depthformer: DepthformerConfig
@dataclass(kw_only=True)
class DepthformerConfig:
layers: int
dim: int
tie: bool
class LFM2AudioModel(nn.Module):
audio_vocab_size: ClassVar[int] = 2048 + 1 # Includes +1 for EOAudio
def __init__(
self,
conf: LFM2AudioConfig,
):
super().__init__()
self.conf = conf
self.codebooks = conf.codebooks
## LFM2 ##
self.lfm = Lfm2Model(conf.lfm)
## Audio encoder ##
self.conformer = ConformerEncoder(**asdict(conf.encoder))
self.audio_adapter = MLP(self.conformer._feat_out, self.lfm.config.hidden_size, [self.lfm.config.hidden_size])
## Depthformer ##
self.depthformer_layers = conf.depthformer.layers
self.depthformer_dim = conf.depthformer.dim
self.depthformer_tie = conf.depthformer.tie
self.audio_embedding = SharedEmbedding(
dim=self.lfm.config.hidden_size,
vocab_size=self.audio_vocab_size * self.conf.codebooks,
embed_init_scale=1.0,
norm_eps=0.00001,
tie_embedding=conf.tie_audio_embeddings,
)
self.codebook_offsets: torch.Tensor
self.register_buffer("codebook_offsets", torch.arange(self.conf.codebooks) * self.audio_vocab_size)
self.audio_loss_weights: torch.Tensor
if conf.codebook_weight == "log":
weights = (torch.linspace(1, 0, self.codebooks) * math.log(conf.semantic_codebook_factor)).exp()
else:
weights = torch.ones((self.codebooks,))
weights[0] *= conf.semantic_codebook_factor
self.register_buffer(
"audio_loss_weights",
weights,
)
scale = 1 / math.sqrt(2 * self.depthformer_layers)
layers = [
StandardBlock(MHA(self.depthformer_dim, out_init_scale=scale), out_init_scale=scale)
for _ in range(self.depthformer_layers)
]
self.depthformer = RawLMBackbone(layers, has_embedding=False)
self.depth_linear = nn.Linear(self.lfm.config.hidden_size, self.depthformer_dim * self.codebooks)
self.depth_embeddings = nn.ModuleList(
[
SharedEmbedding(
dim=self.depthformer_dim,
vocab_size=self.audio_vocab_size,
tie_embedding=self.depthformer_tie,
)
for _ in range(self.codebooks)
]
)
@classmethod
def from_pretrained(
cls,
repo_id: str | Path,
*,
revision: str | None = None,
dtype: torch.dtype = torch.bfloat16,
device: torch.device | str = "cuda",
) -> Self:
cache_path = get_model_dir(repo_id, revision=revision)
with (cache_path / "config.json").open() as f:
config = json.load(f)
conf = LFM2AudioConfig(
lfm=Lfm2Config(**config.pop("lfm")),
encoder=ConformerEncoderConfig(**config.pop("encoder")),
depthformer=DepthformerConfig(**config.pop("depthformer")),
**config,
)
if isinstance(device, str):
device = torch.device(device)
with init_on_device(device, include_buffers=True):
model = cls(conf).to(device=device, dtype=dtype)
if module_exists("flash_attn"):
model.lfm.set_attn_implementation("flash_attention_2")
else:
model.lfm.set_attn_implementation("sdpa")
load_checkpoint_in_model(model, cache_path)
return model
@torch.no_grad()
def generate_sequential(
self,
*,
text: torch.Tensor,
audio_in: torch.Tensor,
audio_in_lens: torch.Tensor,
audio_out: torch.Tensor,
modality_flag: torch.Tensor,
max_new_tokens: int = 20,
text_temperature: float | None = None,
text_top_k: int | None = None,
audio_temperature: float | None = None,
audio_top_k: int | None = None,
) -> Generator[torch.Tensor, None, None]:
in_emb = self._prefill(
text=text,
audio_in=audio_in,
audio_in_lens=audio_in_lens,
audio_out=audio_out,
modality_flag=modality_flag,
)
current_modality: LFMModality = LFMModality.TEXT
cache: Lfm2HybridConvCache | None = None
for _ in range(max_new_tokens):
lfm_out = self.lfm(
inputs_embeds=in_emb,
past_key_values=cache,
use_cache=True,
)
output_embeddings = lfm_out.last_hidden_state
cache = lfm_out.past_key_values
if current_modality == LFMModality.TEXT:
text_logits = nn.functional.linear(output_embeddings[0, -1], self.lfm.embed_tokens.weight)
next_token = self._sample_text_token(text_logits, temperature=text_temperature, top_k=text_top_k)
yield next_token
if next_token == 128: # <|audio_start|>
current_modality = LFMModality.AUDIO_OUT
if next_token == 7: # <|im_end|>
break
in_emb = self.lfm.embed_tokens(next_token)[None, :]
elif current_modality == LFMModality.AUDIO_OUT:
next_token = self._sample_audio_frame(
output_embeddings[0, -1],
temperature=audio_temperature,
top_k=audio_top_k,
)
if next_token[0] == 2048:
next_token[:] = 2048
current_modality = LFMModality.TEXT
yield next_token
in_emb = self.audio_embedding(next_token + self.codebook_offsets).sum(0)[None, None, :]
@torch.no_grad()
def generate_interleaved(
self,
*,
text: torch.Tensor,
audio_in: torch.Tensor,
audio_in_lens: torch.Tensor,
audio_out: torch.Tensor,
modality_flag: torch.Tensor,
max_new_tokens: int = 20,
text_temperature: float | None = None,
text_top_k: int | None = None,
audio_temperature: float | None = None,
audio_top_k: int | None = None,
) -> Generator[torch.Tensor, None, None]:
in_emb = self._prefill(
text=text,
audio_in=audio_in,
audio_in_lens=audio_in_lens,
audio_out=audio_out,
modality_flag=modality_flag,
)
current_modality: LFMModality = LFMModality.TEXT
modality_left: int = self.conf.interleaved_n_text
cache: Lfm2HybridConvCache | None = None
text_done: bool = False
for _ in range(max_new_tokens):
modality_left -= 1
lfm_out = self.lfm(
inputs_embeds=in_emb,
past_key_values=cache,
use_cache=True,
)
output_embeddings = lfm_out.last_hidden_state
cache = lfm_out.past_key_values
if current_modality == LFMModality.TEXT:
text_logits = nn.functional.linear(output_embeddings[0, -1], self.lfm.embed_tokens.weight)
next_token = self._sample_text_token(text_logits, temperature=text_temperature, top_k=text_top_k)
if next_token == 7: # <|im_end|>
break
yield next_token
if next_token == 130: # <|text_end|>
text_done = True
if not modality_left or text_done:
current_modality = LFMModality.AUDIO_OUT
modality_left = self.conf.interleaved_n_audio
in_emb = self.lfm.embed_tokens(next_token)[None, :]
elif current_modality == LFMModality.AUDIO_OUT:
next_token = self._sample_audio_frame(
output_embeddings[0, -1],
temperature=audio_temperature,
top_k=audio_top_k,
)
if not modality_left and not text_done:
current_modality = LFMModality.TEXT
modality_left = self.conf.interleaved_n_text
if next_token[0] == 2048:
next_token[:] = 2048
current_modality = LFMModality.TEXT
yield next_token
in_emb = self.audio_embedding(next_token + self.codebook_offsets).sum(0)[None, None, :]
def _prefill(
self,
*,
text: torch.Tensor,
audio_in: torch.Tensor,
audio_in_lens: torch.Tensor,
audio_out: torch.Tensor,
modality_flag: torch.Tensor,
) -> torch.Tensor:
## Sanity check
assert len(text.shape) == 2
assert len(audio_in.shape) == 2
assert len(audio_in_lens.shape) == 1
assert len(audio_out.shape) == 2
assert len(modality_flag.shape) == 2
assert text.shape[0] == 1
assert audio_in.shape[0] == 128
assert audio_out.shape[0] >= self.codebooks
assert modality_flag.shape[0] == 1
assert (modality_flag == LFMModality.TEXT).sum() == text.shape[1]
assert (modality_flag == LFMModality.AUDIO_OUT).sum() == audio_out.shape[1]
assert (modality_flag == LFMModality.AUDIO_IN).sum() == mel2emb_len(audio_in_lens).sum()
assert audio_in.shape[1] == audio_in_lens.sum()
# Text embeddings
text_emb = self.lfm.embed_tokens(text[0])
text_mask = modality_flag == LFMModality.TEXT
# Audio-in embeddings
## Batch and pad
audio_in_list = audio_in.mT.split(audio_in_lens.tolist())
if audio_in_list:
padded_audio_in = nn.utils.rnn.pad_sequence(audio_in_list, batch_first=True)
else:
padded_audio_in = text_emb.new_empty((0, 8 + 1, 128))
## Encode
audio_enc, audio_in_len = self.conformer(padded_audio_in.mT, audio_in_lens)
## Unbatch, unpad
len_mask = torch.arange(audio_enc.shape[-1], device=audio_enc.device).unsqueeze(0) < audio_in_len.unsqueeze(1)
audio_enc_concatenated = audio_enc.mT[len_mask]
## Adapt
audio_in_emb = self.audio_adapter(audio_enc_concatenated)
audio_in_mask = modality_flag == LFMModality.AUDIO_IN
assert audio_in_emb.shape[0] == audio_in_mask.sum()
# Audio-out embeddings
offset_audio_tokens = audio_out[: self.codebooks] + self.codebook_offsets.unsqueeze(1)
audio_out_emb = self.audio_embedding(offset_audio_tokens).sum(0)
audio_out_mask = modality_flag == LFMModality.AUDIO_OUT
assert audio_out_emb.shape[0] == audio_out_mask.sum()
# Assemble LFM input
B, L, D = *modality_flag.shape, self.lfm.config.hidden_size
in_emb = text_emb.new_empty((B, L, D))
in_emb[text_mask] = text_emb
in_emb[audio_in_mask] = audio_in_emb
in_emb[audio_out_mask] = audio_out_emb
return in_emb
def _sample_text_token(
self, logits: torch.Tensor, *, temperature: float | None = None, top_k: int | None = None
) -> torch.Tensor:
greedy = temperature is None or temperature <= 0 or top_k == 1
if greedy:
next_token = logits.argmax(keepdim=True)
else:
assert isinstance(temperature, float) and temperature > 0
logits /= temperature
if top_k is not None:
min_score = torch.topk(logits, top_k).values[-1]
to_remove = logits < min_score
logits = torch.masked_fill(logits, to_remove, -float("inf"))
probs = logits.softmax(0)
next_token = torch.multinomial(probs, 1)
return next_token
def _sample_audio_frame(
self,
embedding: torch.Tensor, # lfm_dim-sized vector
*,
temperature: float | None = None,
top_k: int | None = None,
) -> torch.Tensor:
greedy = temperature is None or temperature <= 0 or top_k == 1
depthformer_in = rearrange(self.depth_linear(embedding), "(C D) -> C D", C=self.codebooks, D=self.depthformer_dim)
depthformer_token = torch.zeros_like(depthformer_in[0])
cache = None
out_tokens: list[torch.Tensor] = []
for i in range(self.codebooks):
cur_depthformer_input = depthformer_in[i] + depthformer_token
depthformer_out, cache = self.depthformer.forward_cached(cur_depthformer_input[None, None, :], cache)
depthformer_logits = self.depth_embeddings[i].get_logits(depthformer_out.squeeze()) # type: ignore[operator]
if greedy:
next_token = depthformer_logits.argmax(keepdim=True)
else:
assert isinstance(temperature, float) and temperature > 0
depthformer_logits /= temperature
if top_k is not None:
min_score = torch.topk(depthformer_logits, top_k).values[-1]
to_remove = depthformer_logits < min_score
depthformer_logits = torch.masked_fill(depthformer_logits, to_remove, -float("inf"))
probs = depthformer_logits.softmax(0)
next_token = torch.multinomial(probs, 1)
out_tokens.append(next_token)
depthformer_token = self.depth_embeddings[i](next_token).squeeze()
return torch.cat(out_tokens)
```
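Both `_sample_text_token` and `_sample_audio_frame` share the same temperature + top-k filtering step. A torch-free sketch of that step (the function name and inputs here are illustrative, not part of the model API):

```python
import math

def top_k_probs(logits: list[float], temperature: float, top_k: int) -> list[float]:
    """Scale logits by temperature, keep only the top_k scores, softmax the rest."""
    scaled = [x / temperature for x in logits]
    # Everything below the top_k-th score is masked to -inf, as in the
    # torch.masked_fill call above.
    min_score = sorted(scaled, reverse=True)[top_k - 1]
    masked = [x if x >= min_score else -math.inf for x in scaled]
    # Numerically stable softmax over the surviving scores.
    mx = max(masked)
    exps = [math.exp(x - mx) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

probs = top_k_probs([2.0, 1.0, 0.5, -1.0], temperature=1.0, top_k=2)
print(probs[2], probs[3])  # 0.0 0.0 -- only the two best tokens keep mass
```

With `temperature <= 0`, `None`, or `top_k == 1`, the model code skips this entirely and takes the argmax.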
## /src/liquid_audio/model/mlp.py
```py path="/src/liquid_audio/model/mlp.py"
import torch
from torch import nn
class MLP(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_dim: list[int],
bias: bool = True,
use_layer_norm: bool = True,
dropout: float = 0.0,
):
super().__init__()
channels = [in_channels, *hidden_dim, out_channels]
layers: list[nn.Module] = []
if use_layer_norm:
layers.append(nn.LayerNorm(channels[0]))
for i in range(len(channels) - 1):
layers.append(
nn.Linear(
in_features=channels[i],
out_features=channels[i + 1],
bias=bias,
)
)
if i != (len(channels) - 2):
layers.append(nn.GELU())
if dropout > 0:
layers.append(nn.Dropout(p=dropout))
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)
```
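The layer stack built in `MLP.__init__` follows a simple pattern: optional input `LayerNorm`, then a `Linear` over each consecutive channel pair, with `GELU` between hidden layers but not after the final projection. A torch-free sketch of that plan (dropout omitted; the helper name is local to this example):

```python
def mlp_layer_plan(in_channels, out_channels, hidden_dim, use_layer_norm=True):
    """List the layer types the MLP would instantiate, in order."""
    channels = [in_channels, *hidden_dim, out_channels]
    plan = ["LayerNorm"] if use_layer_norm else []
    for i in range(len(channels) - 1):
        plan.append(f"Linear({channels[i]} -> {channels[i + 1]})")
        if i != len(channels) - 2:  # no activation after the final Linear
            plan.append("GELU")
    return plan

# Shape of the audio adapter in lfm2_audio.py: MLP(feat_out, hidden, [hidden]).
print(mlp_layer_plan(512, 1024, [1024]))
# ['LayerNorm', 'Linear(512 -> 1024)', 'GELU', 'Linear(1024 -> 1024)']
```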
## /src/liquid_audio/model/transformer.py
```py path="/src/liquid_audio/model/transformer.py"
import math
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from functools import partial
from typing import Any, Literal, TypeGuard, cast
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn.attention.bias import causal_lower_right
type CacheType = torch.Tensor | None | Sequence["CacheType"]
class SequenceModel(nn.Module, ABC):
"""Models operating on sequences
Assumptions:
input shape [N, T, dim]
output shape [N, T', dim_out]
"""
dim: int
dim_out: int
@abstractmethod
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
@abstractmethod
def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor: ...
@abstractmethod
def forward_cached(self, x: torch.Tensor, cache: CacheType = None) -> tuple[torch.Tensor, CacheType]: ...
class LayerKVCache:
"""
Assumes input cache is a tuple of two tensors (key, value)
"""
def __init__(self, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> None:
assert cache is None or len(cache) == 2
self.key_cache = cache[0] if cache is not None else None
self.value_cache = cache[1] if cache is not None else None
def update(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.key_cache is None or self.value_cache is None:
# initialize
self.key_cache = k
self.value_cache = v
else:
self.key_cache = torch.cat([self.key_cache, k], dim=1)
self.value_cache = torch.cat([self.value_cache, v], dim=1)
return self.key_cache, self.value_cache
def get_cache_size(self) -> int:
if self.key_cache is None:
return 0
else:
return self.key_cache.shape[1]
class RMSNorm(SequenceModel):
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor:
assert cache is None, "RMSNorm expects None cache"
output = self._norm(x.float())
return (output * self.weight).type_as(x)
def forward_cached(self, x: torch.Tensor, cache: CacheType = None) -> tuple[torch.Tensor, CacheType]:
return self(x, cache), None
class GLU(SequenceModel):
def __init__(
self,
dim: int,
ff_dim: int | None = None,
mlp_init_scale: float = 1.0,
out_init_scale: float = 0.14434,
use_swiglu: bool = True,
multiple_of: int = 256,
ffn_dim_multiplier: float = 1.0,
):
super().__init__()
self.dim = dim
self.dim_out = dim
self.use_swiglu = use_swiglu
self.num_params = 0
if ff_dim is None: # from LFMv1Block
ff_dim = 4 * dim
if use_swiglu:
ff_dim = int(2 * ff_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
ff_dim = int(ffn_dim_multiplier * ff_dim)
ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, ff_dim, bias=False)
self.num_params += dim * ff_dim
std = mlp_init_scale / math.sqrt(dim)
torch.nn.init.normal_(self.w1.weight, mean=0.0, std=std)
if use_swiglu:
self.w3 = nn.Linear(dim, ff_dim, bias=False)
self.num_params += dim * ff_dim
std = mlp_init_scale / math.sqrt(dim)
torch.nn.init.normal_(self.w3.weight, mean=0.0, std=std)
self.w2 = nn.Linear(ff_dim, dim, bias=False)
self.num_params += ff_dim * dim
std = out_init_scale * mlp_init_scale / math.sqrt(ff_dim)
torch.nn.init.normal_(self.w2.weight, mean=0.0, std=std)
def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor:
assert cache is None, "expected None cache for GLU"
if self.use_swiglu:
return cast(torch.Tensor, self.w2(F.silu(self.w1(x)) * self.w3(x)))
else:
return cast(torch.Tensor, self.w2(F.gelu(self.w1(x))))
def forward_cached(self, x: torch.Tensor, cache: CacheType = None) -> tuple[torch.Tensor, CacheType]:
return self(x, cache), None
class BoundedAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 32,
head_style: Literal["mha", "gqa", "mqa"] = "mha",
gqa_dim: int | None = None,
qk_layernorm: bool = False,
norm_eps: float = 1e-5,
) -> None:
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.head_style = head_style
if self.head_style == "gqa":
# only access attribute if using gqa head style
assert gqa_dim is not None
self.gqa_dim = gqa_dim
assert self.num_heads % self.gqa_dim == 0, f"{self.num_heads} % {self.gqa_dim} != 0"
self.qk_layernorm = qk_layernorm
if self.qk_layernorm:
self.q_layernorm = RMSNorm(self.head_dim, eps=norm_eps)
self.k_layernorm = RMSNorm(self.head_dim, eps=norm_eps)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
freqs_cis: torch.Tensor | None = None,
cache: LayerKVCache | None = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
bsz, seqlen = q.shape[0], q.shape[1]
if self.head_style == "mqa":
q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim)
k = k.reshape(bsz, seqlen, 1, self.head_dim)
v = v.reshape(bsz, seqlen, 1, self.head_dim)
elif self.head_style == "mha":
q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim)
k = k.reshape(bsz, seqlen, self.num_heads, self.head_dim)
v = v.reshape(bsz, seqlen, self.num_heads, self.head_dim)
elif self.head_style == "gqa":
q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim)
k = k.reshape(bsz, seqlen, self.gqa_dim, self.head_dim)
v = v.reshape(bsz, seqlen, self.gqa_dim, self.head_dim)
if self.qk_layernorm:
q = self.q_layernorm(q)
k = self.k_layernorm(k)
if freqs_cis is not None:
q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
if cache is not None:
k, v = cache.update(k, v)
q_len = q.shape[1]
kv_len = k.shape[1]
query = q.transpose(1, 2)
key = k.transpose(1, 2)
value = v.transpose(1, 2)
enable_gqa = self.head_style in ("mqa", "gqa")
if q_len == kv_len:
output = nn.functional.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=enable_gqa)
else:
if q_len == 1:
attn_mask = None
else:
attn_mask = causal_lower_right(q_len, kv_len)
output = nn.functional.scaled_dot_product_attention(
query, key, value, is_causal=False, enable_gqa=enable_gqa, attn_mask=attn_mask
)
output = output.transpose(1, 2)
output = output.reshape(bsz, seqlen, self.num_heads * self.head_dim)
return output, (k, v)
class MHA(SequenceModel):
def __init__(
self,
dim: int,
num_heads: int = 32,
head_style: Literal["mha", "gqa", "mqa"] = "gqa",
out_init_scale: float = 0.125, # 1/sqrt(2*n_layers) (from gpt3)
proj_init_scale: float = 1.0,
qk_layernorm: bool = True,
norm_eps: float = 0.00001,
gqa_dim: int = 8, # Only used when head_style == "gqa"
freqs_cis: torch.Tensor | None = None,
max_seq_len: int = 128_000, # Stored positional encodings. Optional if freqs_cis is given
theta: float = 1_000_000.0, # Positional encoding theta. Optional if freqs_cis is given
):
super().__init__()
self.dim = self.dim_out = dim
self.num_heads = num_heads
assert self.dim % self.num_heads == 0, "expected dim to be divisible by num_heads"
self.head_dim = self.dim // self.num_heads
self.head_style = head_style
if self.head_style == "gqa":
# only access attribute if using gqa head style
assert gqa_dim is not None
self.gqa_dim = gqa_dim
# q, k, v + optional w and z projections
if self.head_style == "mha":
self.total_width = 3 * self.dim
elif self.head_style == "mqa":
self.total_width = self.dim + 2 * self.head_dim
elif self.head_style == "gqa":
assert self.gqa_dim is not None
self.total_width = self.dim + 2 * self.head_dim * self.gqa_dim
else:
raise NotImplementedError(f"head style {self.head_style} not implemented")
self.qkv_proj = nn.Linear(
self.dim,
self.total_width,
bias=False,
)
std = proj_init_scale / math.sqrt(self.dim)
torch.nn.init.normal_(self.qkv_proj.weight, mean=0.0, std=std)
self.out_proj = nn.Linear(self.dim, self.dim, bias=False)
std = out_init_scale * proj_init_scale / math.sqrt(self.dim)
torch.nn.init.normal_(self.out_proj.weight, mean=0.0, std=std)
self.bounded_attention = BoundedAttention(
dim=dim,
num_heads=num_heads,
head_style=head_style,
gqa_dim=gqa_dim,
qk_layernorm=qk_layernorm,
norm_eps=norm_eps,
)
if freqs_cis is not None:
self.freqs_cis = freqs_cis
else:
self.freqs_cis = precompute_freqs_cis(self.head_dim, max_seq_len, theta)
def _validate_cache(self, cache: CacheType) -> TypeGuard[tuple[torch.Tensor, torch.Tensor]]:
return (
isinstance(cache, tuple)
and len(cache) == 2
and isinstance(cache[0], torch.Tensor)
and isinstance(cache[1], torch.Tensor)
)
def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor:
return self.forward_cached(x, cache)[0]
def forward_cached(self, x: torch.Tensor, cache: CacheType = None) -> tuple[torch.Tensor, CacheType]:
if cache is not None:
assert self._validate_cache(cache)
kv_cache = LayerKVCache(cache)
else:
kv_cache = None
# x is (bsz, seqlen, d_model)
seq_len = x.shape[1]
x = self.qkv_proj(x)
if self.head_style == "mha":
xq, xk, xv = x.split(self.dim, dim=-1)
elif self.head_style == "mqa":
xq, xk, xv = x.split([self.dim, self.head_dim, self.head_dim], dim=-1)
elif self.head_style == "gqa":
xq, xk, xv = x.split(
[self.dim, self.head_dim * self.gqa_dim, self.head_dim * self.gqa_dim],
dim=-1,
)
# TODO: Need to clean up, hack for now to allow rpes in grafted model if using e.g. mqa for grafting
self.freqs_cis = self.freqs_cis.to(xq.device)
if kv_cache is not None:
# If using cache, get freqs for all new tokens starting from cache size
cache_size = kv_cache.get_cache_size()
freqs_cis = self.freqs_cis[cache_size : cache_size + seq_len]
else:
# Otherwise get freqs for full sequence
freqs_cis = self.freqs_cis[:seq_len]
ys, new_cache = self.bounded_attention(xq, xk, xv, freqs_cis=freqs_cis, cache=kv_cache)
ys = self.out_proj(ys)
return cast(torch.Tensor, ys), new_cache
class StandardBlock(SequenceModel):
"""Block with an operator + norm + skip connection, followed by a GLU + norm + skip connection"""
def __init__(
self,
operator: SequenceModel,
ff_dim: int | None = None,
mlp_init_scale: float = 1.0,
out_init_scale: float = 0.125, # 1/sqrt(2*n_layers) (from gpt3)
use_swiglu: bool = True,
multiple_of: int = 256,
ffn_dim_multiplier: float = 1.0,
norm_eps: float = 0.00001,
):
super().__init__()
self.operator = operator
self.dim = self.dim_out = self.operator.dim
if ff_dim is None:
ff_dim = 4 * self.dim
self.feed_forward = GLU(
dim=self.dim,
ff_dim=ff_dim,
mlp_init_scale=mlp_init_scale,
out_init_scale=out_init_scale,
use_swiglu=use_swiglu,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
self.operator_norm = RMSNorm(self.dim, eps=norm_eps)
self.ffn_norm = RMSNorm(self.dim, eps=norm_eps)
def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor:
h = self.operator(self.operator_norm(x), cache)
h += x
h_glu = self.feed_forward(self.ffn_norm(h))
out = h + h_glu
return cast(torch.Tensor, out)
def forward_cached(self, x: torch.Tensor, cache: CacheType | None = None) -> tuple[torch.Tensor, CacheType]:
h, new_cache = self.operator.forward_cached(self.operator_norm(x), cache)
h += x
h_glu = self.feed_forward.forward(self.ffn_norm(h))
out = h + h_glu
return out, new_cache
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings.
xk (torch.Tensor): Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xq_ = torch.view_as_complex(rearrange(xq.float(), "... (D two) -> ... D two", two=2))
xk_ = torch.view_as_complex(rearrange(xk.float(), "... (D two) -> ... D two", two=2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
end (int): End index for precomputing frequencies.
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
class SharedEmbedding(nn.Module):
def __init__(
self,
dim: int,
vocab_size: int = 65_536,
embed_init_scale: float = 1.0,
norm_eps: float = 0.00001,
*,
tie_embedding: bool = True,
) -> None:
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, dim)
std = embed_init_scale / math.sqrt(dim)
torch.nn.init.normal_(self.embedding.weight, mean=0.0, std=std)
self.embedding_norm = RMSNorm(dim, eps=norm_eps) # Note: this is really the norm before output projection
self.to_logits = nn.Linear(dim, vocab_size, bias=False)
if tie_embedding:
self.to_logits.weight = self.embedding.weight
else:
# If not tying embedding, scale the output weights
std = embed_init_scale / math.sqrt(dim)
torch.nn.init.normal_(self.to_logits.weight, mean=0.0, std=std)
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
return self.embed(tokens)
def embed(self, tokens: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, self.embedding(tokens))
def get_logits(self, embeddings: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, self.to_logits(self.embedding_norm(embeddings)))
class RawLMBackbone(SequenceModel):
"""
"Raw" Backbone for LM models:
- input: continuous embeddings
- output: continuous embeddings
"""
def __init__(
self,
layers: Iterable[SequenceModel],
vocab_size: int = 65_536,
norm_eps: float = 0.00001,
embed_init_scale: float = 1.0,
*,
has_embedding: bool = True,
tie_embedding: bool = True,
) -> None:
super().__init__()
self.layers = cast(Sequence[SequenceModel], nn.ModuleList(layers))
self.dim = self.layers[0].dim
self.dim_out = self.layers[-1].dim_out
assert self.dim == self.dim_out, "expected first layer input dim to be equal to last layer's output dim"
if has_embedding:
# TODO: possibly wrap in wrap_sharded
self.embedding = SharedEmbedding(
self.dim, vocab_size, embed_init_scale=embed_init_scale, norm_eps=norm_eps, tie_embedding=tie_embedding
)
self.has_embedding = has_embedding
self.vocab_size = vocab_size
def forward(self, x: torch.Tensor, cache: CacheType | None = None) -> torch.Tensor:
if cache is not None:
assert isinstance(cache, list)
assert len(cache) == len(self.layers)
else:
cache = [None] * len(self.layers)
for layer, layer_cache in zip(self.layers, cache, strict=True):
x = layer(x, layer_cache)
return x
def forward_cached(self, x: torch.Tensor, cache: CacheType | None = None) -> tuple[torch.Tensor, CacheType]:
if cache is not None:
assert isinstance(cache, list)
assert len(cache) == len(self.layers)
else:
cache = [None] * len(self.layers)
cache_out: list[CacheType] = []
for layer, layer_cache in zip(self.layers, cache, strict=True):
x, new_cache = layer.forward_cached(x, layer_cache)
cache_out.append(new_cache)
return x, cache_out
def wrap_activation_checkpoint[T: nn.Module](mod: T) -> T:
# NOTE: we're using torch.utils.checkpoint.checkpoint here to avoid the error in backward pass when using zero-3 + activation checkpointing (https://github.com/microsoft/DeepSpeed/issues/4595)
checkpoint_fn = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
mod.forward_ = mod.forward # type: ignore[assignment]
def forward(*args, **kwargs):
return checkpoint_fn(mod.forward_, *args, **kwargs)
mod.forward = forward
return mod
```
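The rotary-embedding helpers above (`precompute_freqs_cis`, `apply_rotary_emb`, `reshape_for_broadcast`) treat each pair of channels as one complex number and rotate it by a position-dependent angle. A minimal standalone sketch of the same arithmetic using Python's built-in `complex` type (the `rope_table`/`apply_rotary` names are illustrative, not part of this module):

```python
import cmath
import math

def rope_table(dim: int, end: int, theta: float = 10000.0) -> list[list[complex]]:
    # Unit rotations e^{i * t * freq} per position t and channel pair,
    # mirroring precompute_freqs_cis above.
    freqs = [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]
    return [[cmath.exp(1j * t * f) for f in freqs] for t in range(end)]

def apply_rotary(x: list[float], rot: list[complex]) -> list[float]:
    # Rotate consecutive channel pairs (x[2k], x[2k+1]) by rot[k],
    # mirroring the view_as_complex multiply in apply_rotary_emb.
    out: list[float] = []
    for k, r in enumerate(rot):
        z = complex(x[2 * k], x[2 * k + 1]) * r
        out.extend([z.real, z.imag])
    return out

freqs = rope_table(dim=4, end=8)
q = [1.0, 0.0, 0.5, -0.5]
k = [0.2, 0.7, -0.3, 0.1]

# Rotation by a unit complex number preserves the norm ...
q3 = apply_rotary(q, freqs[3])
assert math.isclose(sum(v * v for v in q), sum(v * v for v in q3))

# ... and the q.k dot product depends only on the relative offset:
def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

s1 = dot(apply_rotary(q, freqs[2]), apply_rotary(k, freqs[5]))
s2 = dot(apply_rotary(q, freqs[4]), apply_rotary(k, freqs[7]))
assert math.isclose(s1, s2)
```

The relative-offset property is why attention scores under RoPE encode relative positions: the real dot product of rotated pairs equals `Re(q * conj(k) * e^{i(m-n)f})`, a function of `m - n` only.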
## /src/liquid_audio/moshi/__init__.py
```py path="/src/liquid_audio/moshi/__init__.py"
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
moshi is the inference codebase for Kyutai audio generation models.
The code has been adapted from Audiocraft, see LICENSE.audiocraft
Copyright (c) Meta Platforms, Inc. and affiliates.
"""
# flake8: noqa
from . import conditioners
from . import models
from . import modules
from . import quantization
from . import utils
__version__ = "0.2.12a3"
```
## /src/liquid_audio/moshi/client.py
```py path="/src/liquid_audio/moshi/client.py"
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Client for the Moshi server."""
import argparse
import asyncio
import queue
import sys
import aiohttp
import numpy as np
import sphn
import sounddevice as sd
from .client_utils import AnyPrinter, Printer, RawPrinter
class Connection:
def __init__(
self,
printer: AnyPrinter,
websocket: aiohttp.ClientWebSocketResponse,
sample_rate: float = 24000,
channels: int = 1,
frame_size: int = 1920,
) -> None:
self.printer = printer
self.websocket = websocket
self.sample_rate = sample_rate
self.frame_size = frame_size
self.channels = channels
self._done = False
self._in_stream = sd.InputStream(
samplerate=sample_rate,
channels=channels,
blocksize=self.frame_size,
callback=self._on_audio_input,
)
self._out_stream = sd.OutputStream(
samplerate=sample_rate,
channels=channels,
blocksize=frame_size,
callback=self._on_audio_output,
)
self._opus_writer = sphn.OpusStreamWriter(sample_rate)
self._opus_reader = sphn.OpusStreamReader(sample_rate)
self._output_queue = queue.Queue()
async def _queue_loop(self) -> None:
while True:
if self._done:
return
await asyncio.sleep(0.001)
msg = self._opus_writer.read_bytes()
if len(msg) > 0:
try:
await self.websocket.send_bytes(b"\x01" + msg)
except Exception as e:
print(e)
self._lost_connection()
return
async def _decoder_loop(self) -> None:
all_pcm_data = None
while True:
if self._done:
return
await asyncio.sleep(0.001)
pcm = self._opus_reader.read_pcm()
if all_pcm_data is None:
all_pcm_data = pcm
else:
all_pcm_data = np.concatenate((all_pcm_data, pcm))
while all_pcm_data.shape[-1] >= self.frame_size:
self._output_queue.put(all_pcm_data[: self.frame_size])
all_pcm_data = np.array(all_pcm_data[self.frame_size :])
async def _recv_loop(self) -> None:
try:
async for message in self.websocket:
if message.type == aiohttp.WSMsgType.CLOSED:
self.printer.log("info", "Connection closed")
break
elif message.type == aiohttp.WSMsgType.ERROR:
self.printer.log("error", f"{self.websocket.exception()}")
break
elif message.type != aiohttp.WSMsgType.BINARY:
self.printer.log("error", f"received from server: {message.type}")
continue
message = message.data
if not isinstance(message, bytes):
self.printer.log(
"warning", f"unsupported message type {type(message)}"
)
continue
if len(message) == 0:
self.printer.log("warning", "empty message")
continue
kind = message[0]
if kind == 1: # audio
payload = message[1:]
self._opus_reader.append_bytes(payload)
self.printer.print_pending()
elif kind == 2: # text
payload = message[1:]
self.printer.print_token(payload.decode())
else:
self.printer.log("warning", f"unknown message kind {kind}")
except Exception as e:
print(e)
self._lost_connection()
return
def _lost_connection(self) -> None:
if not self._done:
self.printer.log("error", "Lost connection with the server!")
self._done = True
def _on_audio_input(self, in_data, frames, time_, status) -> None:
assert in_data.shape == (self.frame_size, self.channels), in_data.shape
self._opus_writer.append_pcm(in_data[:, 0])
def _on_audio_output(self, out_data, frames, time_, status) -> None:
assert out_data.shape == (self.frame_size, self.channels), out_data.shape
try:
pcm_data = self._output_queue.get(block=False)
# TODO: handle other shapes by using some form of fifo/ring buffer.
assert pcm_data.shape == (self.frame_size,), pcm_data.shape
out_data[:, 0] = pcm_data
except queue.Empty:
out_data.fill(0)
self.printer.print_lag()
async def run(self) -> None:
with self._in_stream, self._out_stream:
await asyncio.gather(
self._recv_loop(), self._decoder_loop(), self._queue_loop()
)
async def run(printer: AnyPrinter, args):
if args.url is None:
proto = "ws"
if args.https:
proto += "s"
uri = f"{proto}://{args.host}:{args.port}/api/chat"
else:
proto = "wss"
if '://' in args.url:
proto, without_proto = args.url.split('://', 1)
if proto in ['ws', 'http']:
proto = "ws"
elif proto in ['wss', 'https']:
proto = "wss"
else:
                printer.log("error", f"The provided URL {args.url} seems to contain a protocol but it is unknown.")
sys.exit(1)
else:
without_proto = args.url
uri = f"{proto}://{without_proto}/api/chat"
printer.log("info", f"Connecting to {uri}.")
async with aiohttp.ClientSession() as session:
async with session.ws_connect(uri) as ws:
printer.log("info", "connected!")
printer.print_header()
connection = Connection(printer, ws)
await connection.run()
def main():
parser = argparse.ArgumentParser("client_opus")
parser.add_argument("--host", default="localhost", type=str, help="Hostname to connect to.")
parser.add_argument("--port", default=8998, type=int, help="Port to connect to.")
parser.add_argument("--https", action='store_true',
help="Set this flag for using a https connection.")
parser.add_argument("--url", type=str, help='Provides directly a URL, e.g. to a gradio tunnel.')
args = parser.parse_args()
printer: AnyPrinter
if sys.stdout.isatty():
printer = Printer()
else:
printer = RawPrinter()
try:
asyncio.run(run(printer, args))
except KeyboardInterrupt:
printer.log("warning", "Interrupting, exiting connection.")
printer.log("info", "All done!")
if __name__ == "__main__":
main()
```
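The websocket protocol used by `Connection` above is a one-byte tag followed by the payload: `0x01` marks Opus audio and `0x02` marks UTF-8 text, as seen in `_queue_loop` and `_recv_loop`. A standalone sketch of that framing convention (`encode_frame`/`decode_frame` are illustrative helpers, not part of the client API):

```python
KIND_AUDIO = 1  # payload is Opus-encoded audio bytes
KIND_TEXT = 2   # payload is UTF-8 text tokens

def encode_frame(kind: int, payload: bytes) -> bytes:
    # Prefix the payload with its one-byte kind tag, as _queue_loop does
    # with b"\x01" + msg.
    return bytes([kind]) + payload

def decode_frame(message: bytes) -> tuple[int, bytes]:
    # Split a received message into (kind, payload); empty messages are
    # rejected, matching the "empty message" warning in _recv_loop.
    if not message:
        raise ValueError("empty message")
    return message[0], message[1:]

frame = encode_frame(KIND_TEXT, "hello".encode())
kind, payload = decode_frame(frame)
assert kind == KIND_TEXT and payload.decode() == "hello"
```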
## /src/liquid_audio/moshi/client_gradio.py
```py path="/src/liquid_audio/moshi/client_gradio.py"
import argparse
from typing import Generator, Literal, cast
import numpy as np
import sphn
from numpy.typing import NDArray
try:
import gradio as gr # type: ignore
import websockets.sync.client
from gradio_webrtc import AdditionalOutputs, StreamHandler, WebRTC # type: ignore
except ImportError:
raise ImportError("Please install gradio-webrtc>=0.0.18 to run this script.")
# See https://freddyaboulton.github.io/gradio-webrtc/deployment/ for
# instructions on how to set the rtc_configuration variable for deployment
# on cloud platforms like Heroku, Spaces, etc.
rtc_configuration = None
class MoshiHandler(StreamHandler):
def __init__(
self,
url: str,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 480,
) -> None:
self.url = url
proto, without_proto = self.url.split("://", 1)
if proto in ["ws", "http"]:
proto = "ws"
elif proto in ["wss", "https"]:
proto = "wss"
self._generator = None
self.output_chunk_size = 1920
self.ws = None
self.ws_url = f"{proto}://{without_proto}/api/chat"
self.stream_reader = sphn.OpusStreamReader(output_sample_rate)
self.stream_writer = sphn.OpusStreamWriter(output_sample_rate)
self.all_output_data = None
super().__init__(
expected_layout,
output_sample_rate,
output_frame_size,
input_sample_rate=24000,
)
def receive(self, frame: tuple[int, NDArray]) -> None:
if not self.ws:
self.ws = websockets.sync.client.connect(self.ws_url)
_, array = frame
array = array.squeeze().astype(np.float32) / 32768.0
self.stream_writer.append_pcm(array)
        data = b"\x01" + self.stream_writer.read_bytes()
        self.ws.send(data)
def generator(
self,
) -> Generator[tuple[int, NDArray] | None | AdditionalOutputs, None, None]:
for message in cast(websockets.sync.client.ClientConnection, self.ws):
            if len(message) == 0:
                yield None
                continue
            kind = message[0]
if kind == 1:
payload = message[1:]
self.stream_reader.append_bytes(payload)
pcm = self.stream_reader.read_pcm()
if self.all_output_data is None:
self.all_output_data = pcm
else:
self.all_output_data = np.concatenate((self.all_output_data, pcm))
while self.all_output_data.shape[-1] >= self.output_chunk_size:
yield (
self.output_sample_rate,
self.all_output_data[: self.output_chunk_size].reshape(1, -1),
)
self.all_output_data = np.array(
self.all_output_data[self.output_chunk_size :]
)
elif kind == 2:
payload = cast(bytes, message[1:])
yield AdditionalOutputs(payload.decode())
def emit(self) -> tuple[int, NDArray] | AdditionalOutputs | None:
if not self.ws:
return
if not self._generator:
self._generator = self.generator()
try:
return next(self._generator)
except StopIteration:
self.reset()
def reset(self) -> None:
self._generator = None
self.all_output_data = None
def copy(self) -> StreamHandler:
return MoshiHandler(
self.url,
self.expected_layout, # type: ignore
self.output_sample_rate,
self.output_frame_size,
)
def shutdown(self) -> None:
if self.ws:
self.ws.close()
def main():
parser = argparse.ArgumentParser("client_gradio")
parser.add_argument("--url", type=str, help="URL to moshi server.")
args = parser.parse_args()
with gr.Blocks() as demo:
gr.HTML(
"""
<div style='text-align: center'>
<h1>
Talk To Moshi (Powered by WebRTC ⚡️)
</h1>
<p>
Each conversation is limited to 90 seconds. Once the time limit is up you can rejoin the conversation.
</p>
</div>
"""
)
chatbot = gr.Chatbot(type="messages", value=[])
webrtc = WebRTC(
label="Conversation",
modality="audio",
mode="send-receive",
rtc_configuration=rtc_configuration,
)
webrtc.stream(
MoshiHandler(args.url),
inputs=[webrtc, chatbot],
outputs=[webrtc],
time_limit=90,
)
def add_text(chat_history, response):
if len(chat_history) == 0:
chat_history.append({"role": "assistant", "content": ""})
chat_history[-1]["content"] += response
return chat_history
webrtc.on_additional_outputs(
add_text,
inputs=[chatbot],
outputs=chatbot,
queue=False,
show_progress="hidden",
)
demo.launch()
if __name__ == "__main__":
main()
```
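`MoshiHandler.receive` scales incoming int16 PCM to floats in `[-1, 1)` by dividing by 32768, and `generator` buffers decoded samples until at least `output_chunk_size` are available before emitting a fixed-size chunk. The same buffering logic, sketched without the audio and websocket dependencies (`int16_to_float` and `ChunkBuffer` are illustrative names):

```python
def int16_to_float(samples: list[int]) -> list[float]:
    # Scale 16-bit PCM to float in [-1, 1), as receive() does with / 32768.0.
    return [s / 32768.0 for s in samples]

class ChunkBuffer:
    # Accumulate PCM samples and emit fixed-size chunks, mirroring the
    # all_output_data handling in MoshiHandler.generator().
    def __init__(self, chunk_size: int) -> None:
        self.chunk_size = chunk_size
        self._buf: list[float] = []

    def push(self, pcm: list[float]) -> list[list[float]]:
        self._buf.extend(pcm)
        chunks = []
        while len(self._buf) >= self.chunk_size:
            chunks.append(self._buf[: self.chunk_size])
            self._buf = self._buf[self.chunk_size :]
        return chunks

buf = ChunkBuffer(chunk_size=4)
assert buf.push([0.1, 0.2]) == []        # not enough samples yet
chunks = buf.push([0.3, 0.4, 0.5])
assert chunks == [[0.1, 0.2, 0.3, 0.4]]  # one full chunk; 0.5 stays buffered
assert len(buf._buf) == 1
```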
## /src/liquid_audio/moshi/client_utils.py
```py path="/src/liquid_audio/moshi/client_utils.py"
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities for the command line client, in particular for handling interactions with the terminal.
"""
from dataclasses import dataclass
import sys
def colorize(text, color):
code = f"\033[{color}m"
restore = "\033[0m"
return "".join([code, text, restore])
def make_log(level: str, msg: str) -> str:
if level == "warning":
prefix = colorize("[Warn]", "1;31")
elif level == "info":
prefix = colorize("[Info]", "1;34")
elif level == "error":
prefix = colorize("[Err ]", "1;31")
else:
raise ValueError(f"Unknown level {level}")
return prefix + " " + msg
def log(level: str, msg: str) -> None:
"""Log something with a given level."""
print(make_log(level, msg))
class RawPrinter:
def __init__(self, stream=sys.stdout, err_stream=sys.stderr):
self.stream = stream
self.err_stream = err_stream
def print_header(self):
pass
def print_token(self, token: str):
self.stream.write(token)
self.stream.flush()
def log(self, level: str, msg: str):
print(f"{level.capitalize()}: {msg}", file=self.err_stream)
def print_lag(self):
self.err_stream.write(colorize(" [LAG]", "31"))
self.err_stream.flush()
def print_pending(self):
pass
@dataclass
class LineEntry:
msg: str
color: str | None = None
def render(self):
if self.color is None:
return self.msg
else:
return colorize(self.msg, self.color)
def __len__(self):
return len(self.msg)
class Line:
def __init__(self, stream):
self.stream = stream
self._line: list[LineEntry] = []
self._has_padding: bool = False
self._max_line_length = 0
def __bool__(self):
return bool(self._line)
def __len__(self):
return sum(len(entry) for entry in self._line)
def add(self, msg: str, color: str | None = None) -> int:
entry = LineEntry(msg, color)
return self._add(entry)
def _add(self, entry: LineEntry) -> int:
if self._has_padding:
self.erase(count=0)
self._line.append(entry)
self.stream.write(entry.render())
self._max_line_length = max(self._max_line_length, len(self))
return len(entry)
def erase(self, count: int = 1):
if count:
entries = list(self._line[:-count])
else:
entries = list(self._line)
self._line.clear()
self.stream.write("\r")
for entry in entries:
self._line.append(entry)
self.stream.write(entry.render())
self._has_padding = False
def newline(self):
missing = self._max_line_length - len(self)
if missing > 0:
self.stream.write(" " * missing)
self.stream.write("\n")
self._line.clear()
self._max_line_length = 0
self._has_padding = False
def flush(self):
missing = self._max_line_length - len(self)
if missing > 0:
self.stream.write(" " * missing)
self._has_padding = True
self.stream.flush()
class Printer:
def __init__(self, max_cols: int = 80, stream=sys.stdout, err_stream=sys.stderr):
self.max_cols = max_cols
self.line = Line(stream)
self.stream = stream
self.err_stream = err_stream
self._pending_count = 0
self._pending_printed = False
def print_header(self):
self.line.add(" " + "-" * (self.max_cols) + " ")
self.line.newline()
self.line.flush()
self.line.add("| ")
def _remove_pending(self) -> bool:
if self._pending_printed:
self._pending_printed = False
self.line.erase(1)
return True
return False
def print_token(self, token: str, color: str | None = None):
self._remove_pending()
remaining = self.max_cols - len(self.line)
if len(token) <= remaining:
self.line.add(token, color)
else:
end = " " * remaining + " |"
if token.startswith(" "):
token = token.lstrip()
self.line.add(end)
self.line.newline()
self.line.add("| ")
self.line.add(token, color)
else:
assert color is None
erase_count = None
cumulated = ""
for idx, entry in enumerate(self.line._line[::-1]):
if entry.color:
# probably a LAG message
erase_count = idx
break
if entry.msg.startswith(" "):
erase_count = idx + 1
cumulated = entry.msg + cumulated
break
if erase_count is not None:
if erase_count > 0:
self.line.erase(erase_count)
remaining = self.max_cols - len(self.line)
end = " " * remaining + " |"
self.line.add(end)
self.line.newline()
self.line.add("| ")
token = cumulated.lstrip() + token
self.line.add(token)
else:
self.line.add(token[:remaining])
self.line.add(" |")
self.line.newline()
self.line.add("| ")
self.line.add(token[remaining:])
self.line.flush()
def log(self, level: str, msg: str):
msg = make_log(level, msg)
self._remove_pending()
if self.line:
self.line.newline()
self.line.flush()
print(msg, file=self.err_stream)
self.err_stream.flush()
def print_lag(self):
self.print_token(" [LAG]", "31")
def print_pending(self):
chars = ["|", "/", "-", "\\"]
count = int(self._pending_count / 5)
char = chars[count % len(chars)]
colors = ["32", "33", "31"]
self._remove_pending()
self.line.add(char, colors[count % len(colors)])
self._pending_printed = True
self._pending_count += 1
AnyPrinter = Printer | RawPrinter
```
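`colorize` above wraps text in ANSI SGR escape sequences: `ESC[<code>m` enables a style and `ESC[0m` resets it, so `"1;31"` means bold red and `"1;34"` bold blue. The same construction, shown standalone so the resulting byte sequence is explicit:

```python
def colorize(text: str, color: str) -> str:
    # Wrap text in an ANSI SGR color sequence and reset the style afterwards.
    return f"\033[{color}m{text}\033[0m"

# "1;34" = bold + blue foreground, matching the [Info] prefix in make_log.
prefix = colorize("[Info]", "1;34")
assert prefix == "\x1b[1;34m[Info]\x1b[0m"
assert prefix.endswith("\x1b[0m")  # terminal state is restored after the prefix
```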
## /src/liquid_audio/moshi/conditioners/__init__.py
```py path="/src/liquid_audio/moshi/conditioners/__init__.py"
# flake8: noqa
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Modules to help doing generations under some fixed conditions.
"""
from .base import (ConditionType, ConditionAttributes, ConditionFuser, ConditionProvider,
BaseConditioner, TensorCondition, ConditionTensors, dropout_all_conditions)
```
## /src/liquid_audio/moshi/conditioners/base.py
```py path="/src/liquid_audio/moshi/conditioners/base.py"
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Adapted from
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import logging
import typing as tp
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import chain
import torch
from torch import nn
from ..modules.transformer import create_sin_embedding
logger = logging.getLogger(__name__)
TextCondition = tp.Optional[str]  # a text condition is a string, or None if it doesn't exist
ConditionTensors = dict[str, 'ConditionType']
class ConditionType(tp.NamedTuple):
"""Return type for a conditioner: both a condition tensor, and a mask indicating valid positions.
"""
condition: torch.Tensor
mask: torch.Tensor
@dataclass(frozen=True)
class TensorCondition:
"""Looks quite similar to ConditionType, but represents the input to TensorConditioners.
`tensor` should be [B | 1, T, D], and `mask` should be `[B | 1, T]`.
"""
tensor: torch.Tensor
mask: torch.Tensor
@staticmethod
def from_tensor(tensor: torch.Tensor):
B, T, _ = tensor.shape
mask = torch.ones(B, T, dtype=torch.bool, device=tensor.device)
return TensorCondition(tensor, mask)
@staticmethod
def cat(conditions: tp.Sequence['TensorCondition']) -> 'TensorCondition':
assert conditions, "Cannot cat empty list."
ref_tensor = conditions[0].tensor
B, _, D = ref_tensor.shape
assert B == 1
B = len(conditions)
T = max(condition.tensor.shape[1] for condition in conditions)
mask = torch.zeros(B, T, dtype=torch.bool, device=ref_tensor.device)
tensor = torch.zeros(B, T, D, dtype=ref_tensor.dtype, device=ref_tensor.device)
for b, condition in enumerate(conditions):
tensor[b, :condition.tensor.shape[1], :] = condition.tensor[0]
mask[b, :condition.mask.shape[1]] = condition.mask[0]
return TensorCondition(tensor, mask)
@dataclass
class ConditionAttributes:
"""Standard class for representing the set of potential inputs to the conditioners.
Typically, `audiocraft.data.audio_dataset.SegmentInfo` will convert
to this class to make conditioning agnostic to the type of dataset.
There are two kinds of conditionings: text (or None), or raw torch tensors (with a mask).
"""
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
tensor: tp.Dict[str, TensorCondition] = field(default_factory=dict)
@property
def text_attributes(self) -> tp.Iterable[str]:
return self.text.keys()
@property
def tensor_attributes(self) -> tp.Iterable[str]:
        return self.tensor.keys()
@staticmethod
def condition_types() -> tp.FrozenSet[str]:
return frozenset(["text", "tensor"])
def copy(self) -> 'ConditionAttributes':
return ConditionAttributes(dict(self.text), dict(self.tensor))
Prepared = tp.TypeVar('Prepared') # represents the prepared condition input type.
class BaseConditioner(nn.Module, tp.Generic[Prepared]):
"""Base model for all conditioner modules.
Args:
dim (int): internal dim of the model.
output_dim (int): Output dim of the conditioner.
force_linear (bool, optional): Force linear projection even when `dim == output_dim`.
pad_empty (bool): if True, conditionings of 0 length will be padded to have length 1.
output_bias (bool): if True, the output projection will have a bias.
learn_padding (bool): if True, the padding value will be learnt, zero otherwise.
"""
def __init__(self, dim: int,
output_dim: int,
device: tp.Union[torch.device, str],
force_linear: bool = True,
pad_empty: bool = True,
output_bias: bool = False,
learn_padding: bool = True):
super().__init__()
self.dim = dim
self.output_dim = output_dim
self.pad_empty = pad_empty
self.device = device
self.output_proj: nn.Module
if force_linear or dim != output_dim:
self.output_proj = nn.Linear(dim, output_dim, bias=output_bias, device=device)
assert not output_bias
else:
self.output_proj = nn.Identity()
self.learnt_padding: tp.Optional[torch.Tensor]
if learn_padding:
self.learnt_padding = nn.Parameter(
torch.randn(1, 1, output_dim, device=device), requires_grad=True)
self.learnt_padding.data *= 0.2
else:
self.learnt_padding = None
def prepare(self, *args, **kwargs) -> Prepared:
"""Should be any part of the processing that will lead to a synchronization
point, e.g. BPE tokenization with transfer to the GPU.
        The returned value will be saved and passed back later when calling forward().
"""
raise NotImplementedError()
def _get_condition(self, inputs: Prepared) -> ConditionType:
"""Gets input that should be used as conditioning (e.g, genre, description or a waveform).
Outputs a ConditionType, after the input data was embedded as a dense vector.
Returns:
ConditionType:
- A tensor of size [B, T, dim] where B is the batch size, T is the length of the
output embedding and `dim` is the internal dimension of the embedding.
            - And a mask indicating the positions of the padding tokens, of shape `[B, T]`.
"""
raise NotImplementedError()
def forward(self, inputs: Prepared) -> ConditionType:
cond, mask = self._get_condition(inputs)
B, T, C = cond.shape
if T == 0 and self.pad_empty:
            cond = torch.zeros(B, 1, C, device=cond.device, dtype=cond.dtype)
mask = torch.zeros_like(cond[..., 0], dtype=torch.bool)
cond = self.output_proj(cond)
maskf = mask.float()[..., None]
if self.learnt_padding is not None:
cond = cond * maskf + self.learnt_padding * (1 - maskf)
else:
cond = cond * maskf
return ConditionType(cond, mask)
class _BaseTextConditioner(BaseConditioner[Prepared]):
pass
class _BaseTensorConditioner(BaseConditioner[Prepared]):
pass
def dropout_tensor(condition: TensorCondition) -> TensorCondition:
    """Utility function for nullifying a TensorCondition object.
"""
return TensorCondition(
tensor=torch.zeros_like(condition.tensor),
mask=torch.zeros_like(condition.mask))
def dropout_condition_(sample: ConditionAttributes, condition_type: str, condition: str) -> None:
"""Utility function for nullifying an attribute inside a ConditionAttributes object.
Works in-place.
"""
valid_conditions = ConditionAttributes.condition_types()
if condition_type not in valid_conditions:
raise ValueError(
"dropout_condition got an unexpected condition type!"
f" expected one of {valid_conditions} but got '{condition_type}'")
if condition not in getattr(sample, condition_type):
raise ValueError(
"dropout_condition received an unexpected condition!"
f" expected tensor={sample.tensor.keys()} and text={sample.text.keys()}"
f" but got '{condition}' of type '{condition_type}'!"
)
if condition_type == 'tensor':
tensor_condition = sample.tensor[condition]
sample.tensor[condition] = dropout_tensor(tensor_condition)
elif condition_type == 'text':
sample.text[condition] = None
else:
assert False
def dropout_all_conditions(attributes: tp.Sequence[ConditionAttributes]) -> list[ConditionAttributes]:
"""
Args:
attributes (list[ConditionAttributes]): All conditions attributes.
Returns:
list[ConditionAttributes]: Same with all conditions dropped.
"""
attributes = [attribute.copy() for attribute in attributes]
for condition_type in ConditionAttributes.condition_types():
for attribute in attributes:
for condition in getattr(attribute, condition_type):
dropout_condition_(attribute, condition_type, condition)
return attributes
class ConditionProvider(nn.Module):
"""Prepare and provide conditions given all the supported conditioners.
Args:
conditioners (dict): Dictionary of conditioners.
device (torch.device or str, optional): Device for conditioners and output condition types.
"""
def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
super().__init__()
self.device = device
self.conditioners = nn.ModuleDict(conditioners).to(device)
@property
def text_conditions(self):
return [k for k, v in self.conditioners.items() if isinstance(v, _BaseTextConditioner)]
@property
def tensor_conditions(self):
return [k for k, v in self.conditioners.items() if isinstance(v, _BaseTensorConditioner)]
def _collate_text(self, samples: tp.Sequence[ConditionAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
"""Given a list of ConditionAttributes objects, compile a dictionary where the keys
are the attributes and the values are the aggregated input per attribute.
For example:
Input:
[
ConditionAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
ConditionAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
]
Output:
{
"genre": ["Rock", "Hip-hop"],
"description": ["A rock song with a guitar solo", "A hip-hop verse"]
}
Args:
samples (list of ConditionAttributes): List of ConditionAttributes samples.
Returns:
dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
"""
out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
texts = [x.text for x in samples]
for text in texts:
for condition in self.text_conditions:
out[condition].append(text[condition])
return out
def _collate_tensors(self, samples: tp.Sequence[ConditionAttributes]) -> tp.Dict[str, TensorCondition]:
"""For each tensor attribute, collate the tensor from individual batch items.
Args:
samples (list of ConditionAttributes): List of ConditionAttributes samples.
Returns:
dict[str, TensorCondition]: A dictionary mapping an attribute name to tensor.
"""
per_attribute = defaultdict(list)
out: tp.Dict[str, TensorCondition] = {}
for sample in samples:
for attribute in self.tensor_conditions:
per_attribute[attribute].append(sample.tensor[attribute])
# stack all tensors to a single tensor
for attribute in self.tensor_conditions:
out[attribute] = TensorCondition.cat(per_attribute[attribute])
return out
def prepare(self, inputs: tp.Sequence[ConditionAttributes]) -> tp.Dict[str, tp.Any]:
"""Match attributes/tensors with existing conditioners in self, and call `prepare` for each one.
This should be called before starting any real GPU work to avoid synchronization points.
This will return a dict matching conditioner names to their arbitrary prepared representations.
Args:
inputs (list[ConditionAttributes]): List of ConditionAttributes objects containing
text and tensors conditions.
"""
assert all([isinstance(x, ConditionAttributes) for x in inputs]), (
"Got unexpected types input for conditioner! should be tp.List[ConditionAttributes]",
f" but types were {set([type(x) for x in inputs])}"
)
output = {}
text = self._collate_text(inputs)
tensors = self._collate_tensors(inputs)
assert set(text.keys() | tensors.keys()).issubset(set(self.conditioners.keys())), (
f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
f"got {text.keys(), tensors.keys()}"
)
missing_inputs = set(self.conditioners.keys()) - (set(text.keys()) | set(tensors.keys()))
if missing_inputs:
raise RuntimeError(f"Some conditioners did not receive an input: {missing_inputs}")
for attribute, batch in chain(text.items(), tensors.items()):
conditioner = self.conditioners[attribute]
assert isinstance(conditioner, BaseConditioner)
output[attribute] = conditioner.prepare(batch)
return output
def forward(self, prepared: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
"""Compute pairs of `(embedding, mask)` using the configured conditioners and the prepared representations.
The output is for example:
{
"genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
"description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
...
}
Args:
prepared (dict): Dict of prepared representations as returned by `prepare()`.
"""
output = {}
for name, inputs in prepared.items():
condition, mask = self.conditioners[name](inputs)
output[name] = ConditionType(condition, mask)
return output
def prepare_and_provide(self, inputs: tp.Sequence[ConditionAttributes]):
"""See .prepare() and .forward()."""
prepared = self.prepare(inputs)
return self(prepared)
class ConditionFuser(nn.Module):
"""Condition fuser handles the logic to combine the different conditions
to the actual model input.
Args:
fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that describes how to fuse
each condition. For example:
{
"prepend": ["description"],
"sum": ["genre", "bpm"],
"cross": ["description"],
}
cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
"""
FUSING_METHODS = ["sum", "prepend", "cross"]
def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
cross_attention_pos_emb_scale: float = 1.0):
super().__init__()
assert all(
[k in self.FUSING_METHODS for k in fuse2cond.keys()]
), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
self.cross_attention_pos_emb = cross_attention_pos_emb
self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
self.cond2fuse: tp.Dict[str, str] = {}
for fuse_method, conditions in fuse2cond.items():
for condition in conditions:
self.cond2fuse[condition] = fuse_method
if fuse_method not in ['cross', 'sum']:
raise RuntimeError("only `sum` and `cross` conditionings are supported "
f"for now, got {fuse_method}.")
@property
def has_conditions(self) -> bool:
return bool(self.cond2fuse)
@property
def has_prepend(self) -> bool:
"""Is there a conditioning that needs to be prepending to the Transformer sequence."""
return bool(self.fuse2cond['prepend'])
def get_cross(self, conditions: ConditionTensors) -> torch.Tensor | None:
"""Return the tensor to be provided for the cross attention."""
cross = None
for name in self.fuse2cond['cross']:
cond, _ = conditions[name]
if cross is None:
cross = cond
else:
cross = torch.cat([cross, cond], dim=1)
if self.cross_attention_pos_emb and cross is not None:
positions = torch.arange(
cross.shape[1],
device=cross.device
).view(1, -1, 1)
pos_emb = create_sin_embedding(positions, cross.shape[-1]).to(cross.dtype)
cross = cross + self.cross_attention_pos_emb_scale * pos_emb
return cross
def get_sum(self, conditions: ConditionTensors) -> torch.Tensor | None:
"""Return the tensor to be provided as an extra sum offset shared for each step."""
sum = None
for name in self.fuse2cond['sum']:
cond, _ = conditions[name]
assert cond.shape[1] == 1, cond.shape
if sum is None:
sum = cond
else:
sum = sum + cond
return sum
def get_prepend(self, conditions: ConditionTensors) -> torch.Tensor | None:
"""Return the tensor to be prepended to the transformer."""
prepend = None
for name in self.fuse2cond['prepend']:
cond, _ = conditions[name]
if prepend is None:
prepend = cond
else:
prepend = torch.cat([cond, prepend], dim=1)
if prepend is not None:
sum = self.get_sum(conditions)
if sum is not None:
prepend = prepend + sum
return prepend
```
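The `ConditionFuser` constructor above inverts the `fuse2cond` mapping (fuse method → condition names) into `cond2fuse` (condition name → fuse method). A minimal standalone sketch of that inversion, using hypothetical condition names for illustration:

```python
# Sketch of ConditionFuser's fuse2cond -> cond2fuse inversion.
# The condition names here ("description", "genre", "bpm") are examples,
# not names mandated by the library.
fuse2cond = {
    "prepend": ["description"],
    "sum": ["genre", "bpm"],
    "cross": [],
}

cond2fuse = {}
for fuse_method, conditions in fuse2cond.items():
    for condition in conditions:
        # each condition name maps back to the single method that fuses it
        cond2fuse[condition] = fuse_method

print(cond2fuse)
# {'description': 'prepend', 'genre': 'sum', 'bpm': 'sum'}
```

Note that a condition listed under two fuse methods would silently keep only the last one, which is why the real class restricts the supported methods and validates keys up front.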
## /src/liquid_audio/moshi/conditioners/tensors.py
```py path="/src/liquid_audio/moshi/conditioners/tensors.py"
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .base import _BaseTensorConditioner, TensorCondition, ConditionType
class TensorConditioner(_BaseTensorConditioner[TensorCondition]):
"""Does basically nothing.
"""
def prepare(self, tensor: TensorCondition) -> TensorCondition:
device = next(iter(self.parameters())).device
return TensorCondition(tensor.tensor.to(device=device), tensor.mask.to(device=device))
def _get_condition(self, inputs: TensorCondition) -> ConditionType:
return ConditionType(inputs.tensor, inputs.mask)
```
## /src/liquid_audio/moshi/models/__init__.py
```py path="/src/liquid_audio/moshi/models/__init__.py"
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Models for the compression model Moshi,
"""
# flake8: noqa
from .compression import (
CompressionModel,
MimiModel,
)
from .lm import LMModel, LMGen
from .loaders import get_mimi, get_moshi_lm
```
## /src/liquid_audio/moshi/modules/__init__.py
```py path="/src/liquid_audio/moshi/modules/__init__.py"
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Modules used for building the models."""
# flake8: noqa
from .conv import (
NormConv1d,
NormConvTranspose1d,
StreamingConv1d,
StreamingConvTranspose1d,
pad_for_conv1d,
pad1d,
unpad1d,
)
from .seanet import SEANetEncoder, SEANetDecoder
from .transformer import StreamingTransformer
```
## /src/liquid_audio/moshi/modules/resample.py
```py path="/src/liquid_audio/moshi/modules/resample.py"
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from einops import rearrange
import torch
from torch import nn
from .conv import StreamingConv1d, StreamingConvTranspose1d
class ConvDownsample1d(nn.Module):
"""
Downsampling by some integer amount `stride` using convolutions
with a kernel size of twice the stride.
If `causal` is True, the output uses a causal convolution.
"""
def __init__(
self,
stride: int,
dimension: tp.Optional[int] = None,
causal: bool = False,
learnt: bool = False,
channel_wise: bool = False,
):
super().__init__()
self.learnt = learnt
self.channel_wise = channel_wise
groups = 1
if learnt:
assert dimension is not None, "Dimension required for learnt convolutions."
in_channels = dimension
out_channels = dimension
if channel_wise:
groups = dimension
else:
in_channels = 1
out_channels = 1
self.conv = StreamingConv1d(
in_channels,
out_channels,
kernel_size=2 * stride,
stride=stride,
causal=causal,
groups=groups,
bias=False,
pad_mode="replicate",
)
if not learnt:
actual_conv = self.conv.conv.conv
actual_conv.weight.requires_grad_(False)
actual_conv.weight.data.fill_(1.0 / (2 * stride))
def forward(self, x: torch.Tensor):
batch_size = len(x)
if not self.learnt:
x = rearrange(x, "b c t -> (b c) () t")
y = self.conv(x)
if not self.learnt:
y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
return y
class ConvTrUpsample1d(nn.Module):
"""
Upsample by some integer amount `stride` using transposed convolutions.
"""
def __init__(
self,
stride: int,
dimension: tp.Optional[int] = None,
causal: bool = False,
learnt: bool = False,
channel_wise: bool = False,
):
super().__init__()
self.learnt = learnt
self.channel_wise = channel_wise
groups = 1
if learnt:
assert dimension is not None, "Dimension required for learnt convolutions."
in_channels = dimension
out_channels = dimension
if channel_wise:
groups = dimension
else:
in_channels = 1
out_channels = 1
self.convtr = StreamingConvTranspose1d(
in_channels,
out_channels,
kernel_size=2 * stride,
stride=stride,
causal=causal,
groups=groups,
bias=False,
)
if not learnt:
actual_convtr = self.convtr.convtr.convtr
actual_convtr.weight.requires_grad_(False)
actual_convtr.weight.data.fill_(1.0)
def forward(self, x: torch.Tensor):
batch_size = len(x)
if not self.learnt:
x = rearrange(x, "b c t -> (b c) () t")
y = self.convtr(x)
if not self.learnt:
x_for_normalization = torch.ones_like(x[:1])
normalization = self.convtr(x_for_normalization)
y = y / normalization
y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
return y
```
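The non-learnt paths of the two resamplers above reduce to simple fixed-weight filters: `ConvDownsample1d` fills its kernel with `1 / (2 * stride)`, i.e. a moving average of window `2 * stride`, and `ConvTrUpsample1d` uses all-ones weights then divides by the transposed convolution of an all-ones input so overlapping contributions are averaged rather than summed. A minimal pure-Python sketch of both ideas (ignoring the causal padding the real streaming convolutions apply at the boundaries):

```python
def box_downsample(x, stride):
    # Fixed-weight downsample matching non-learnt ConvDownsample1d:
    # kernel size 2*stride, every weight equal to 1/(2*stride).
    k = 2 * stride
    return [sum(x[i:i + k]) / k for i in range(0, len(x) - k + 1, stride)]

def overlap_add_upsample(x, stride):
    # Non-learnt ConvTrUpsample1d idea: transposed conv with all-ones
    # weights, normalized by the same operator applied to all-ones input.
    k = 2 * stride
    out_len = (len(x) - 1) * stride + k
    y = [0.0] * out_len
    norm = [0.0] * out_len
    for i, v in enumerate(x):
        for j in range(k):
            y[i * stride + j] += v       # overlap-add of the kernel taps
            norm[i * stride + j] += 1.0  # how many taps landed here
    return [a / n for a, n in zip(y, norm)]

print(box_downsample([1, 2, 3, 4, 5, 6, 7, 8], 2))  # [2.5, 4.5, 6.5]
print(overlap_add_upsample([2.0, 4.0], 2))          # [2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
```

The normalization step is why `ConvTrUpsample1d.forward` divides `y` by `self.convtr(torch.ones_like(x[:1]))`: without it, positions covered by more overlapping kernel windows would receive proportionally larger values.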