Liquid4All/liquid-audio/main (113k tokens)
```
├── .github/
   ├── workflows/
      ├── check.yml (200 tokens)
├── .gitignore
├── .python-version
├── LICENSE (omitted)
├── README.md (2.1k tokens)
├── assets/
   ├── asr.wav
   ├── question.wav
├── pyproject.toml (500 tokens)
├── src/
   ├── liquid_audio/
      ├── __init__.py
      ├── demo/
         ├── __init__.py
         ├── chat.py (700 tokens)
         ├── model.py (100 tokens)
      ├── model/
         ├── __init__.py
         ├── conformer/
            ├── __init__.py
            ├── encoder.py (11.2k tokens)
            ├── mha.py (3.9k tokens)
            ├── modules.py (3.7k tokens)
            ├── processor.py (4.9k tokens)
            ├── subsampling.py (4.9k tokens)
            ├── utils.py (1000 tokens)
         ├── lfm2_audio.py (3k tokens)
         ├── mlp.py (200 tokens)
         ├── transformer.py (4.2k tokens)
      ├── moshi/
         ├── __init__.py (100 tokens)
         ├── client.py (1400 tokens)
         ├── client_gradio.py (1100 tokens)
         ├── client_utils.py (1300 tokens)
         ├── conditioners/
            ├── __init__.py (100 tokens)
            ├── base.py (3.5k tokens)
            ├── tensors.py (100 tokens)
            ├── text.py (1000 tokens)
         ├── models/
            ├── __init__.py (100 tokens)
            ├── compression.py (3.5k tokens)
            ├── lm.py (7.4k tokens)
            ├── lm_utils.py (1000 tokens)
            ├── loaders.py (3.4k tokens)
            ├── tts.py (6.3k tokens)
         ├── modules/
            ├── __init__.py (100 tokens)
            ├── conv.py (2.8k tokens)
            ├── conv_test.py (1200 tokens)
            ├── gating.py (700 tokens)
            ├── lora.py (900 tokens)
            ├── resample.py (700 tokens)
            ├── rope.py (500 tokens)
            ├── seanet.py (3.1k tokens)
            ├── seanet_test.py (1200 tokens)
            ├── streaming.py (1700 tokens)
            ├── transformer.py (7.4k tokens)
         ├── py.typed
         ├── quantization/
            ├── __init__.py (100 tokens)
            ├── base.py (1100 tokens)
            ├── core_vq.py (4.3k tokens)
            ├── vq.py (2.7k tokens)
         ├── run_inference.py (2.4k tokens)
         ├── run_tts.py (2.1k tokens)
         ├── server.py (2.5k tokens)
         ├── utils/
            ├── __init__.py (100 tokens)
            ├── autocast.py (300 tokens)
            ├── compile.py (2.2k tokens)
            ├── quantize.py (500 tokens)
            ├── sampling.py (900 tokens)
            ├── utils.py (500 tokens)
      ├── processor.py (1300 tokens)
      ├── py.typed
      ├── utils.py (200 tokens)
├── uv.lock (omitted)
```


## /.github/workflows/check.yml

```yml path="/.github/workflows/check.yml" 
name: Check

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
    types:
      - opened
      - reopened
      - synchronize
      - ready_for_review
      - review_requested
  workflow_dispatch:

permissions:
  contents: read
  pull-requests: read

concurrency:
  group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
  cancel-in-progress: true

jobs:
  check:
    if: >
      github.event_name != 'pull_request' ||
      !github.event.pull_request.draft ||
      github.event.action == 'review_requested'
    runs-on: ubuntu-latest
    timeout-minutes: 10
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          enable-cache: true
          cache-dependency-glob: "uv.lock"

      - name: Install dependencies
        run: uv sync --dev --frozen

      - name: Run ruff check
        run: uv run ruff check

      - name: Run ruff format check
        run: uv run ruff format --check

```

## /.gitignore

```gitignore path="/.gitignore" 
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# MacOS
.DS_Store

```

## /.python-version

```python-version path="/.python-version" 
3.12

```

## /README.md

# Liquid Audio - Speech-to-Speech models

We present LFM2-Audio-1.5B, [Liquid AI](https://www.liquid.ai/)'s first end-to-end audio foundation model. Built with low latency in mind, the lightweight [LFM2](https://huggingface.co/LiquidAI/LFM2-1.2B) backbone enables real-time speech-to-speech conversations without sacrificing quality.

LFM2-Audio supports two generation modes, interleaved and sequential, to maximize performance and quality across different tasks. Interleaved generation outputs text and audio tokens in a fixed interleaved pattern. This approach minimizes the time to first audio output and the number of tokens generated, making it ideal for naturally flowing real-time speech-to-speech interactions on resource-constrained devices. Sequential generation, where the model decides when to switch modalities via special tokens, is suitable for non-conversational tasks such as speech-to-text (ASR) or text-to-speech (TTS).

## Installation
The package can be installed via `pip`:
```bash
pip install liquid-audio
pip install "liquid-audio [demo]" # optional, to install demo dependencies
pip install flash-attn --no-build-isolation  # optional, to use flash attention 2; falls back to torch SDPA if not installed
```

## Usage
Generation is handled by two modes, interleaved and sequential, accessible via the methods `LFM2AudioModel.generate_interleaved` and `LFM2AudioModel.generate_sequential`, respectively. Both are generators that yield `torch.Tensor`s: text tokens are represented by tensors with a single entry, and audio tokens by tensors with 8 entries, corresponding to the 8 [Mimi](https://huggingface.co/docs/transformers/en/model_doc/mimi) codebooks.

The `LFM2AudioModel` class operates on tokens only. The `LFM2AudioProcessor` class is used to convert between tokens and data. For text, this means the conversion from string to tokens and back. For audio inputs, it handles the conversion of waveforms to log-mel features; for audio outputs, it handles the detokenization of audio tokens back into a waveform.

To facilitate the creation of inputs for the generation methods and to apply the correct chat templates, use the `ChatState` helper class. See examples below for usage instructions.
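
The pattern shared by all examples below looks roughly like this (a condensed sketch; the complete, runnable versions follow in the next sections):

```python
import torch
from liquid_audio import ChatState, LFM2AudioModel, LFM2AudioProcessor

processor = LFM2AudioProcessor.from_pretrained("LiquidAI/LFM2-Audio-1.5B").eval()
model = LFM2AudioModel.from_pretrained("LiquidAI/LFM2-Audio-1.5B").eval()

chat = ChatState(processor)  # builds model inputs and applies the chat template
# ... add system / user turns here, then open an assistant turn ...

audio_out: list[torch.Tensor] = []
for t in model.generate_interleaved(**chat, max_new_tokens=512):
    if t.numel() == 1:  # text token
        print(processor.text.decode(t), end="", flush=True)
    else:  # audio token: 8 Mimi codebook entries
        audio_out.append(t)
```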

### Gradio demo
To use the demo interface, make sure to install the extra dependencies in the `[demo]` group, e.g.
```bash
pip install "liquid-audio [demo]"
```
To launch the demo, run the command `liquid-audio-demo` in a terminal. The demo interface will be available at http://localhost:7860.
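
The `liquid-audio-demo` command maps to `liquid_audio.demo.chat:main` (see `pyproject.toml`), so the demo can also be started from Python. Note that the models are loaded (and the Mimi decoder warmed up on the GPU) when the module is imported:

```python
# Equivalent to running `liquid-audio-demo`; requires the [demo] extra.
from liquid_audio.demo.chat import main

main()  # serves the Gradio interface at http://localhost:7860
```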

### Multi-turn, multi-modal chat
For multi-turn chat with text and audio output, we use interleaved generation. The system prompt should be set to `Respond with interleaved text and audio.`. Here, we use audio for the first user turn and text for the second.

<details>

<summary>Conversation transcript</summary>

**User**

https://github.com/user-attachments/assets/e2ffb8c3-c84c-4460-9cb8-f95a13b6eec6

**Assistant**

Sure! How about "Handcrafted Woodworking, Precision Made for You"? Another option could be "Quality Woodworking, Quality Results." If you want something more personal, you might try "Your Woodworking Needs, Our Expertise."

https://github.com/user-attachments/assets/019664b5-3480-4801-b05a-bd62ddcb8d3e

**User**

My business specialized in chairs, can you give me something related to that?

**Assistant**

Sure thing! How about “Comfortable Chairs, Crafted with Care” or “Elegant Seats, Handcrafted for You”? Let me know if you’d like a few more options.

https://github.com/user-attachments/assets/d0d054b2-6d1d-49fb-94df-4aa0b6641990

</details>

```python
import torch
import torchaudio
from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState, LFMModality

# Load models
HF_REPO = "LiquidAI/LFM2-Audio-1.5B"

processor = LFM2AudioProcessor.from_pretrained(HF_REPO).eval()
model = LFM2AudioModel.from_pretrained(HF_REPO).eval()

# Set up inputs for the model
chat = ChatState(processor)

chat.new_turn("system")
chat.add_text("Respond with interleaved text and audio.")
chat.end_turn()

chat.new_turn("user")
wav, sampling_rate = torchaudio.load("assets/question.wav")
chat.add_audio(wav, sampling_rate)
chat.end_turn()

chat.new_turn("assistant")

# Generate text and audio tokens.
text_out: list[torch.Tensor] = []
audio_out: list[torch.Tensor] = []
modality_out: list[LFMModality] = []
for t in model.generate_interleaved(**chat, max_new_tokens=512, audio_temperature=1.0, audio_top_k=4):
    if t.numel() == 1:
        print(processor.text.decode(t), end="", flush=True)
        text_out.append(t)
        modality_out.append(LFMModality.TEXT)
    else:
        audio_out.append(t)
        modality_out.append(LFMModality.AUDIO_OUT)

# output: Sure! How about "Handcrafted Woodworking, Precision Made for You"? Another option could be "Quality Woodworking, Quality Results." If you want something more personal, you might try "Your Woodworking Needs, Our Expertise."

# Detokenize audio, removing the last "end-of-audio" codes
# Mimi returns audio at 24kHz
mimi_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
with torch.no_grad():
    waveform = processor.mimi.decode(mimi_codes)[0]
torchaudio.save("answer1.wav", waveform.cpu(), 24_000)

# Append newly generated tokens to chat history
chat.append(
    text = torch.stack(text_out, 1),
    audio_out = torch.stack(audio_out, 1),
    modality_flag = torch.tensor(modality_out),
)
chat.end_turn()

# Start new turn
chat.new_turn("user")
chat.add_text("My business specialized in chairs, can you give me something related to that?")
chat.end_turn()

chat.new_turn("assistant")

# Generate second turn text and audio tokens.
audio_out: list[torch.Tensor] = []
for t in model.generate_interleaved(**chat, max_new_tokens=512, audio_temperature=1.0, audio_top_k=4):
    if t.numel() == 1:
        print(processor.text.decode(t), end="", flush=True)
    else:
        audio_out.append(t)

# output: Sure thing! How about “Comfortable Chairs, Crafted with Care” or “Elegant Seats, Handcrafted for You”? Let me know if you’d like a few more options.

# Detokenize second turn audio, removing the last "end-of-audio" codes
mimi_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
with torch.no_grad():
    waveform = processor.mimi.decode(mimi_codes)[0]
torchaudio.save("answer2.wav", waveform.cpu(), 24_000)
```


### ASR
For ASR, we use sequential generation, with the fixed system prompt `Perform ASR.`. The output is capitalized and punctuated.

<details>

<summary>Input audio snippet</summary>

https://github.com/user-attachments/assets/b3cc017f-363d-49f3-8e7d-f6db9556900e

**Model output**: The stale smell of old beer lingers. It takes heat to bring out the odor. A cold dip restores health and zest. A salt pickle tastes fine with ham. Tacos al pastor are my favorite. A zestful food is the hot cross bun.

</details>

```python
import torch
import torchaudio
from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState, LFMModality

# Load models
HF_REPO = "LiquidAI/LFM2-Audio-1.5B"

processor = LFM2AudioProcessor.from_pretrained(HF_REPO).eval()
model = LFM2AudioModel.from_pretrained(HF_REPO).eval()

# Set up inputs for the model
chat = ChatState(processor)

chat.new_turn("system")
chat.add_text("Perform ASR.")
chat.end_turn()

chat.new_turn("user")
wav, sampling_rate = torchaudio.load("assets/asr.wav")
chat.add_audio(wav, sampling_rate)
chat.end_turn()

chat.new_turn("assistant")

# Generate text
for t in model.generate_sequential(**chat, max_new_tokens=512):
    if t.numel() == 1:
        print(processor.text.decode(t), end="", flush=True)

# Output: The stale smell of old beer lingers. It takes heat to bring out the odor. A cold dip restores health and zest. A salt pickle tastes fine with ham. Tacos al pastor are my favorite. A zestful food is the hot cross bun.
```

### TTS
For TTS, we also use sequential generation, with the fixed system prompt `Perform TTS.`. In addition, the voice and speaking style can be specified with a natural-language description.

<details>

<summary>TTS Sample</summary>

**Voice description**: A male speaker delivers his lines with a low-pitched voice and an animated tone. The recording is of excellent quality with almost no noise and a very close-sounding atmosphere.

**Input sentence**: What is this obsession people have with books? They put them in their houses—like they're trophies. What do you need it for after you read it?

**Output audio**

https://github.com/user-attachments/assets/2fa953cf-d8a8-477a-b841-c4f18d9266e6

</details>

```python
import torch
import torchaudio
from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState, LFMModality

# Load models
HF_REPO = "LiquidAI/LFM2-Audio-1.5B"

processor = LFM2AudioProcessor.from_pretrained(HF_REPO).eval()
model = LFM2AudioModel.from_pretrained(HF_REPO).eval()

# Set up inputs for the model
chat = ChatState(processor)

chat.new_turn("system")
chat.add_text("Perform TTS.\nUse the following voice: A male speaker delivers his lines with a low-pitched voice and an animated tone. The recording is of excellent quality with almost no noise and a very close-sounding atmosphere.")
chat.end_turn()

chat.new_turn("user")
chat.add_text("What is this obsession people have with books? They put them in their houses—like they're trophies. What do you need it for after you read it?")
chat.end_turn()

chat.new_turn("assistant")

# Generate audio tokens
audio_out: list[torch.Tensor] = []
for t in model.generate_sequential(**chat, max_new_tokens=512, audio_temperature = 0.8, audio_top_k=64):
    if t.numel() > 1:
        audio_out.append(t)

# Detokenize audio
mimi_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
with torch.no_grad():
    waveform = processor.mimi.decode(mimi_codes)[0]
torchaudio.save("tts.wav", waveform.cpu(), 24_000)
```
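
Instead of decoding all audio tokens at once, Mimi can also decode frame by frame in streaming mode, as the Gradio demo does (see `src/liquid_audio/demo/chat.py`). A sketch reusing `audio_out` from the example above:

```python
# Streaming decode: each Mimi frame (8 codes) yields 1920 samples (80 ms at 24 kHz).
chunks: list[torch.Tensor] = []
with torch.no_grad(), processor.mimi.streaming(1):
    for frame in audio_out[:-1]:  # drop the trailing end-of-audio codes
        chunks.append(processor.mimi.decode(frame[None, :, None])[0])

waveform = torch.cat(chunks, dim=-1)
torchaudio.save("tts_streamed.wav", waveform.cpu(), 24_000)
```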


## License
The code in this repository and associated weights are licensed under the [LFM Open License v1.0](LICENSE).

The code for the audio encoder is based on [NVIDIA NeMo](https://github.com/NVIDIA-NeMo/NeMo/tree/main), licensed under [Apache 2.0](https://github.com/NVIDIA-NeMo/NeMo/blob/294ddff187f68c055d87ffe9400e65975b38693d/LICENSE), and the [canary-180m-flash](https://huggingface.co/nvidia/canary-180m-flash) checkpoint, licensed under [CC-BY 4.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/cc-by-4.0.md). To simplify dependency resolution, we also ship the Python code of [Kyutai Mimi](https://github.com/kyutai-labs/moshi), licensed under the [MIT License](https://github.com/kyutai-labs/moshi/blob/aee53fc0fc0119e4d7343e5ea4dd6ddafd7f09c4/LICENSE-MIT).


## /assets/asr.wav

Binary file available at https://raw.githubusercontent.com/Liquid4All/liquid-audio/refs/heads/main/assets/asr.wav

## /assets/question.wav

Binary file available at https://raw.githubusercontent.com/Liquid4All/liquid-audio/refs/heads/main/assets/question.wav

## /pyproject.toml

```toml path="/pyproject.toml" 
[project]
name = "liquid-audio"
version = "1.0.0"
description = "Liquid Audio - Speech-to-Speech audio models"
readme = "README.md"
authors = [
    { name = "Liquid AI, Inc", email = "support@liquid.ai" }
]
license = "LicenseRef-LFM-Open-License-v1.0"
license-files = ["LICENSE"]
requires-python = ">=3.12"
dependencies = [
    "accelerate>=1.10.1",
    "einops>=0.8.1",
    "librosa>=0.11.0",
    "sentencepiece>=0.2.1",
    "torch>=2.8.0",
    "torchaudio>=2.8.0",
    "transformers>=4.55.4",
]
keywords = ["Liquid AI", "LFM", "LFM2", "Audio", "Speech-to-Speech"]

[project.urls]
Homepage = "https://www.liquid.ai/"
Repository = "https://github.com/Liquid4All/liquid-audio/"
Issues = "https://github.com/Liquid4All/liquid-audio/issues"

[project.scripts]
liquid-audio-demo = "liquid_audio.demo.chat:main [demo]"

[project.optional-dependencies]
demo = [
    "fastrtc[vad]>=0.0.30",
]

[build-system]
requires = ["uv_build>=0.8.13,<0.9.0"]
build-backend = "uv_build"

[dependency-groups]
dev = [
    "ipython>=9.4.0",
    "mypy>=1.17.1",
    "ruff>=0.12.10",
]

[tool.ruff]
line-length = 127 # The GitHub editor is 127 chars wide
extend-exclude = [
    "src/liquid_audio/model/conformer",
    "src/liquid_audio/moshi",
]

[tool.ruff.lint]
select = [
    # pycodestyle
    "E",
    # Pyflakes
    "F",
    # pyupgrade
    "UP",
    # flake8-bugbear
    "B",
    # flake8-simplify
    "SIM",
    # isort
    "I",
    # tryceratops
    "TRY",
    # perflint
    "PERF",
    # refurb
    "FURB",
    #ruff
    "RUF",
]

ignore = [
    # These conflict with ruff format
    "W191",
    "E111",
    "E114",
    "E117",
    "E501",
    "D206",
    "D300",
    "Q000",
    "Q001",
    "Q002",
    "Q003",
    "COM812",
    "COM819",
    "ISC001",
    "ISC002",
    # Probably more readable without ternary operators
    "SIM108",
    # "Ambiguous variable name", probably too strict for kernel variables
    "E741",
    # Allow error message outside of class
    "TRY003",
    # Allow ABCs without abstract methods
    "B024",
]

[tool.mypy]
mypy_path = "src"
packages = [
    "liquid_audio",
]
exclude = [
    "liquid_audio.model.conformer",
]
warn_unused_configs = true
warn_redundant_casts = true
warn_unused_ignores = true

strict_equality = true
extra_checks = true

check_untyped_defs = true

# TODO: check if needed
[[tool.mypy.overrides]]
module = [
    "accelerate.*",
    "torchaudio.*",
]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
    'liquid_audio.moshi.*'
]
ignore_errors = true

```

## /src/liquid_audio/__init__.py

```py path="/src/liquid_audio/__init__.py" 
from liquid_audio.model.lfm2_audio import LFM2AudioModel
from liquid_audio.processor import ChatState, LFM2AudioProcessor
from liquid_audio.utils import LFMModality

__all__ = ["ChatState", "LFM2AudioModel", "LFM2AudioProcessor", "LFMModality"]

```

## /src/liquid_audio/demo/__init__.py

```py path="/src/liquid_audio/demo/__init__.py" 

```

## /src/liquid_audio/demo/chat.py

```py path="/src/liquid_audio/demo/chat.py" 
from queue import Queue
from threading import Thread

import gradio as gr
import numpy as np
import torch
from fastrtc import AdditionalOutputs, ReplyOnPause, WebRTC

from liquid_audio import ChatState, LFMModality

from .model import lfm2_audio, mimi, proc


def chat_producer(
    q: Queue[torch.Tensor | None],
    chat: ChatState,
    temp: float | None,
    topk: int | None,
):
    print(f"Starting generation with state {chat}.")
    with torch.no_grad(), mimi.streaming(1):
        for t in lfm2_audio.generate_interleaved(
            **chat,
            max_new_tokens=1024,
            audio_temperature=temp,
            audio_top_k=topk,
        ):
            q.put(t)

            if t.numel() > 1:
                # Audio tokens carry 8 Mimi codebook entries. Frames containing the
                # end-of-audio code 2048 cannot be decoded by Mimi, so skip them.
                if (t == 2048).any():
                    continue

                # Streaming decode of a single frame: (1, 8, 1) codes -> 1920 samples.
                wav_chunk = mimi.decode(t[None, :, None])[0]
                q.put(wav_chunk)

    q.put(None)


def chat_response(audio: tuple[int, np.ndarray], _id: str, chat: ChatState, temp: float | None = 1.0, topk: int | None = 4):
    if temp == 0:
        temp = None
    if topk == 0:
        topk = None

    if temp is not None:
        temp = float(temp)
    if topk is not None:
        topk = int(topk)

    if len(chat.text) == 1:
        chat.new_turn("system")
        chat.add_text("Respond with interleaved text and audio.")
        chat.end_turn()

        chat.new_turn("user")

    rate, wav = audio
    chat.add_audio(torch.tensor(wav / 32_768, dtype=torch.float), rate)
    chat.end_turn()

    chat.new_turn("assistant")

    q: Queue[torch.Tensor | None] = Queue()
    chat_thread = Thread(target=chat_producer, args=(q, chat, temp, topk))
    chat_thread.start()

    out_text: list[torch.Tensor] = []
    out_audio: list[torch.Tensor] = []
    out_modality: list[LFMModality] = []

    while True:
        t = q.get()
        if t is None:
            break
        elif t.numel() == 1:  # text
            out_text.append(t)
            out_modality.append(LFMModality.TEXT)
            print(proc.text.decode(t), end="")
            cur_string = proc.text.decode(torch.cat(out_text)).removesuffix("<|text_end|>")
            yield AdditionalOutputs(cur_string)
        elif t.numel() == 8:  # audio token (one entry per Mimi codebook)
            out_audio.append(t)
            out_modality.append(LFMModality.AUDIO_OUT)
        elif t.numel() == 1920:  # decoded waveform chunk from the producer thread
            np_chunk = (t.cpu().numpy() * 32_767).astype(np.int16)
            yield (24_000, np_chunk)
        else:
            raise RuntimeError(f"unexpected shape: {t.shape}")

    chat.append(
        text=torch.stack(out_text, 1),
        audio_out=torch.stack(out_audio, 1),
        modality_flag=torch.tensor(out_modality, device="cuda"),
    )

    chat.end_turn()
    chat.new_turn("user")


def clear():
    gr.Info("Cleared chat history", duration=3)
    return ChatState(proc), None


with gr.Blocks() as demo:
    gr.Markdown("# LFM2-Audio speech-to-speech chat")

    chat_state = gr.State(ChatState(proc))
    webrtc = WebRTC(
        modality="audio",
        mode="send-receive",
        # variant="textbox",
        full_screen=False,
    )
    text_out = gr.Textbox(
        lines=4,
        label="Output",
    )
    clear_btn = gr.Button("Reset chat")

    webrtc.stream(
        ReplyOnPause(
            chat_response,  # type: ignore[arg-type]
            input_sample_rate=24_000,
            output_sample_rate=24_000,
            can_interrupt=False,
        ),
        inputs=[webrtc, chat_state],
        outputs=[webrtc],
    )
    webrtc.on_additional_outputs(
        lambda s: s,
        outputs=[text_out],
    )
    clear_btn.click(clear, outputs=[chat_state, text_out])


def main():
    demo.launch()


if __name__ == "__main__":
    main()

```

## /src/liquid_audio/demo/model.py

```py path="/src/liquid_audio/demo/model.py" 
"""Initialize models"""

import logging

import torch

from liquid_audio import LFM2AudioModel, LFM2AudioProcessor

logger = logging.getLogger(__name__)

__all__ = ["lfm2_audio", "mimi", "proc"]

HF_DIR = "LiquidAI/LFM2-Audio-1.5B"

logger.info("Loading processor")
proc = LFM2AudioProcessor.from_pretrained(HF_DIR).eval()
logger.info("Loading model")
lfm2_audio = LFM2AudioModel.from_pretrained(HF_DIR).eval()
logger.info("Loading tokenizer")
mimi = proc.mimi.eval()

logger.info("Warmup tokenizer")
with mimi.streaming(1), torch.no_grad():
    for _ in range(5):
        x = torch.randint(2048, (1, 8, 1), device="cuda")
        mimi.decode(x)

```

## /src/liquid_audio/model/__init__.py

```py path="/src/liquid_audio/model/__init__.py" 

```

## /src/liquid_audio/model/conformer/__init__.py

```py path="/src/liquid_audio/model/conformer/__init__.py" 

```

## /src/liquid_audio/model/conformer/encoder.py

```py path="/src/liquid_audio/model/conformer/encoder.py" 
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
nemo Conformer

adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/collections/asr/modules/conformer_encoder.py
"""

from dataclasses import dataclass
from torch import nn
import torch

from .mha import MultiHeadAttention, RelPositionalEncoding
from .modules import CausalConv1D, ConformerLayer
from .subsampling import ConvSubsampling
from .utils import CacheAwareStreamingConfig, compute_stochastic_depth_drop_probs


@dataclass(kw_only=True)
class ConformerEncoderConfig:
    feat_in: int
    feat_out: int
    n_layers: int
    d_model: int
    subsampling: str
    subsampling_factor: int
    subsampling_conv_channels: int
    causal_downsampling: bool
    reduction: str | None
    reduction_position: int | None
    reduction_factor: int
    ff_expansion_factor: int
    self_attention_model: str
    n_heads: int
    att_context_size: list[list[int]]
    xscaling: bool
    untie_biases: bool
    pos_emb_max_len: int
    conv_kernel_size: int
    conv_norm_type: str
    conv_context_size: list[int] | None
    dropout: float
    dropout_pre_encoder: float
    dropout_emb: float
    dropout_att: float


class ConformerEncoder(nn.Module):
    """
    The encoder for ASR model of Conformer.
    Based on this paper:
    'Conformer: Convolution-augmented Transformer for Speech Recognition' by Anmol Gulati et al.
    https://arxiv.org/abs/2005.08100

    Args:
        feat_in (int): the size of feature channels
        n_layers (int): number of layers of ConformerBlock
        d_model (int): the hidden size of the model
        feat_out (int): the size of the output features
            Defaults to -1 (means feat_out is d_model)
        subsampling (str): the method of subsampling:
            choices = ['vggnet', 'striding', 'dw-striding', 'stacking', 'stacking_norm']
            Defaults to striding.
        subsampling_factor (int): the subsampling factor, which should be a power of 2
            Defaults to 4.
        subsampling_conv_chunking_factor (int): optionally force chunking of the inputs (helpful for large inputs)
            Should be a power of 2, 1 (auto-chunking, default), or -1 (no chunking)
        subsampling_conv_channels (int): the size of the convolutions in the subsampling module
            Defaults to -1 which would set it to d_model.
        reduction (str, Optional): the method of reduction, choices=['pooling', 'striding']. If no value
            is passed, then no reduction is performed and the model runs with the original 4x subsampling.
        reduction_position (int, Optional): the index of the layer to apply reduction. If -1, apply reduction
            at the end.
        reduction_factor (int): the reduction factor which should be either 1 or a power of 2
            Defaults to 1.
        ff_expansion_factor (int): the expansion factor in feed forward layers
            Defaults to 4.
        self_attention_model (str): the type of the attention layer and positional encoding.

            'rel_pos':
                relative positional embedding and Transformer-XL
            'rel_pos_local_attn':
                relative positional embedding and Transformer-XL with local attention using
                overlapping chunks. Attention context is determined by att_context_size parameter.
            'abs_pos':
                absolute positional embedding and Transformer

            Default is rel_pos.
        pos_emb_max_len (int): the maximum length of positional embeddings
            Defaults to 5000
        n_heads (int): number of heads in multi-headed attention layers
            Defaults to 4.
        att_context_size (List[Union[List[int],int]]): specifies the context sizes on each side.
            Each context size should be a list of two integers like `[100, 100]`.
            A list of context sizes like `[[100,100]`, `[100,50]]` can also be passed. -1 means unlimited context.
            Defaults to `[-1, -1]`
        att_context_probs (List[float]): a list of probabilities, one for each entry of att_context_size,
            when a list of them is passed. If not specified, a uniform distribution is used.
            Defaults to None
        att_context_style (str): 'regular' or 'chunked_limited'.
            Defaults to 'regular'
        xscaling (bool): enables scaling the inputs to the multi-headed attention layers by `sqrt(d_model)`.
            Defaults to True.
        untie_biases (bool): whether to not share (untie) the bias weights between layers of Transformer-XL
            Defaults to True.
        conv_kernel_size (int): the size of the convolutions in the convolutional modules
            Defaults to 31.
        conv_norm_type (str): the type of the normalization in the convolutional modules
            Defaults to 'batch_norm'.
        conv_context_size (list): it can be "causal" or a list of two integers
            while `conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size`.
            `None` means `[(conv_kernel_size-1)//2`, `(conv_kernel_size-1)//2]`, and 'causal' means
            `[(conv_kernel_size-1), 0]`.
            Defaults to None.
        conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used.
            When enabled, the left half of the convolution kernel would get masked in streaming cases.
            Defaults to False.
        use_bias (bool): Use bias in all Linear and Conv1d layers from each ConformerLayer to improve
            activation flow and stabilize training of huge models.
            Defaults to True.
        dropout (float): the dropout rate used in all layers except the attention layers
            Defaults to 0.1.
        dropout_pre_encoder (float): the dropout rate used before the encoder
            Defaults to 0.1.
        dropout_emb (float): the dropout rate used for the positional embeddings
            Defaults to 0.1.
        dropout_att (float): the dropout rate used for the attention layer
            Defaults to 0.0.
        stochastic_depth_drop_prob (float): if non-zero, will randomly drop
            layers during training. The higher this value, the more often layers
            are dropped. Defaults to 0.0.
        stochastic_depth_mode (str): can be either "linear" or "uniform". If
            set to "uniform", all layers have the same probability of drop. If
            set to "linear", the drop probability grows linearly from 0 for the
            first layer to the desired value for the final layer. Defaults to
            "linear".
        stochastic_depth_start_layer (int): starting layer for stochastic depth.
            All layers before this will never be dropped. Note that drop
            probability will be adjusted accordingly if mode is "linear" when
            start layer is > 1. Defaults to 1.
        global_tokens (int): number of tokens to be used for global attention.
            Only relevant if self_attention_model is 'rel_pos_local_attn'.
            Defaults to 0.
        global_tokens_spacing (int): how far apart the global tokens are
            Defaults to 1.
        global_attn_separate (bool): whether the q, k, v layers used for global tokens should be separate.
            Defaults to False.
        use_pytorch_sdpa (bool): use torch sdpa instead of manual attention.
            Defaults to False.
        use_pytorch_sdpa_backends (list[str]): list of backend names to use in sdpa.
            None or empty list means all backends. e.g. ["MATH"]
            Defaults to None.
        bypass_pre_encode: if True, skip the pre-encoder module and the `audio_signal` should be pre-encoded
            embeddings. The `audio_signal` input supports two formats depending on the `bypass_pre_encode`
            boolean flag. This determines the required format of the input variable `audio_signal`.
            Defaults to `bypass_pre_encode=False`. `bypass_pre_encode=True` is used for cases
            where frame-level, context-independent embeddings need to be saved or reused
            (e.g., speaker cache in streaming speaker diarization).
        sync_max_audio_length (bool): when true, performs NCCL all_reduce to allocate the same amount of memory for
            positional encoding buffers on all GPUs. Disabling this setting may help with deadlocks in certain
            scenarios such as model parallelism, or generally when this module is not being run on some GPUs
            as a part of the training step.
    """

    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        dev = next(self.parameters()).device
        if self.export_cache_support:
            window_size = max_dim
            if self.streaming_cfg is not None:
                if isinstance(self.streaming_cfg.chunk_size, list):
                    chunk_size = self.streaming_cfg.chunk_size[1]
                else:
                    chunk_size = self.streaming_cfg.chunk_size
                if isinstance(self.streaming_cfg.pre_encode_cache_size, list):
                    pre_encode_cache_size = self.streaming_cfg.pre_encode_cache_size[1]
                else:
                    pre_encode_cache_size = self.streaming_cfg.pre_encode_cache_size
                window_size = chunk_size + pre_encode_cache_size
            input_example = torch.randn(max_batch, self._feat_in, window_size, device=dev)
            input_example_length = torch.randint(
                window_size // 4, window_size, (max_batch,), device=dev, dtype=torch.int64
            )
            cache_last_channel, cache_last_time, cache_last_channel_len = self.get_initial_cache_state(
                batch_size=max_batch, device=dev, max_dim=max_dim
            )
            all_input_example = tuple(
                [
                    input_example,
                    input_example_length,
                    cache_last_channel.transpose(0, 1),
                    cache_last_time.transpose(0, 1),
                    cache_last_channel_len,
                ]
            )
        else:
            input_example = torch.randn(max_batch, self._feat_in, max_dim, device=dev)
            input_example_length = torch.randint(max_dim // 4, max_dim, (max_batch,), device=dev, dtype=torch.int64)
            all_input_example = tuple([input_example, input_example_length])

        return all_input_example

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
                "cache_last_channel": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True),
                "cache_last_time": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True),
                "cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True),
                "bypass_pre_encode": NeuralType(tuple(), BoolType(), optional=True),
            }
        )

    @property
    def input_types_for_export(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
                "cache_last_channel": NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True),
                "cache_last_time": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True),
                "cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True),
                "bypass_pre_encode": NeuralType(tuple(), BoolType(), optional=True),
            }
        )

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
                "cache_last_channel_next": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True),
                "cache_last_time_next": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True),
                "cache_last_channel_next_len": NeuralType(tuple('B'), LengthsType(), optional=True),
            }
        )

    @property
    def output_types_for_export(self):
        """Returns definitions of module output ports."""
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
                "cache_last_channel_next": NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True),
                "cache_last_time_next": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True),
                "cache_last_channel_next_len": NeuralType(tuple('B'), LengthsType(), optional=True),
            }
        )

    @property
    def disabled_deployment_input_names(self):
        if not self.export_cache_support:
            return set(["cache_last_channel", "cache_last_time", "cache_last_channel_len"])
        else:
            return set()

    @property
    def disabled_deployment_output_names(self):
        if not self.export_cache_support:
            return set(["cache_last_channel_next", "cache_last_time_next", "cache_last_channel_next_len"])
        else:
            return set()

    def __init__(
        self,
        feat_in,
        n_layers,
        d_model,
        feat_out=-1,
        causal_downsampling=False,
        subsampling='striding',
        subsampling_factor=4,
        subsampling_conv_chunking_factor=1,
        subsampling_conv_channels=-1,
        reduction=None,
        reduction_position=None,
        reduction_factor=1,
        ff_expansion_factor=4,
        self_attention_model='rel_pos',
        n_heads=4,
        att_context_size=None,
        att_context_probs=None,
        att_context_style='regular',
        xscaling=True,
        untie_biases=True,
        pos_emb_max_len=5000,
        conv_kernel_size=31,
        conv_norm_type='batch_norm',
        conv_context_size=None,
        use_bias=True,
        dropout=0.1,
        dropout_pre_encoder=0.1,
        dropout_emb=0.1,
        dropout_att=0.0,
        stochastic_depth_drop_prob: float = 0.0,
        stochastic_depth_mode: str = "linear",
        stochastic_depth_start_layer: int = 1,
        global_tokens: int = 0,
        global_tokens_spacing: int = 1,
        global_attn_separate: bool = False,
        use_pytorch_sdpa: bool = False,
        use_pytorch_sdpa_backends=None,
        sync_max_audio_length: bool = True,
    ):
        super().__init__()
        d_ff = d_model * ff_expansion_factor
        self.d_model = d_model
        self.n_layers = n_layers
        self._feat_in = feat_in
        self.att_context_style = att_context_style
        self.subsampling_factor = subsampling_factor
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor

        self.self_attention_model = self_attention_model
        self.global_tokens = global_tokens
        self.global_attn_separate = global_attn_separate
        self.global_tokens_spacing = global_tokens_spacing
        self.use_pytorch_sdpa = use_pytorch_sdpa
        if use_pytorch_sdpa_backends is None:
            use_pytorch_sdpa_backends = []
        self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends
        self.sync_max_audio_length = sync_max_audio_length

        # Setting up the att_context_size
        (
            self.att_context_size_all,
            self.att_context_size,
            self.att_context_probs,
            self.conv_context_size,
        ) = self._calc_context_sizes(
            att_context_style=att_context_style,
            att_context_size=att_context_size,
            att_context_probs=att_context_probs,
            conv_context_size=conv_context_size,
            conv_kernel_size=conv_kernel_size,
        )

        if xscaling:
            self.xscale = math.sqrt(d_model)
        else:
            self.xscale = None

        # Subsampling
        if subsampling_conv_channels == -1:
            subsampling_conv_channels = d_model
        if subsampling and subsampling_factor > 1:
            if subsampling in ['stacking', 'stacking_norm']:
                # stacking_norm has an extra layer norm after stacking comparing to stacking
                self.pre_encode = StackingSubsampling(
                    subsampling_factor=subsampling_factor,
                    feat_in=feat_in,
                    feat_out=d_model,
                    norm=True if subsampling == 'stacking_norm' else False,
                )
            else:
                self.pre_encode = ConvSubsampling(
                    subsampling=subsampling,
                    subsampling_factor=subsampling_factor,
                    feat_in=feat_in,
                    feat_out=d_model,
                    conv_channels=subsampling_conv_channels,
                    subsampling_conv_chunking_factor=subsampling_conv_chunking_factor,
                    activation=nn.ReLU(True),
                    is_causal=causal_downsampling,
                )
        else:
            self.pre_encode = nn.Linear(feat_in, d_model)

        # Reduction
        if reduction and reduction_factor > 1:
            assert reduction_position >= -1 and reduction_position < n_layers
            self.reduction_subsampling = SubsamplingReductionModule(
                reduction=reduction,
                d_model=d_model,
                reduction_factor=reduction_factor,
            )
            self.reduction_position = reduction_position
        else:
            self.reduction_subsampling = None
            self.reduction_position = None

        self._feat_out = d_model

        # Biases for relative positional encoding
        if not untie_biases and self_attention_model == "rel_pos":
            d_head = d_model // n_heads
            pos_bias_u = nn.Parameter(torch.Tensor(n_heads, d_head))
            pos_bias_v = nn.Parameter(torch.Tensor(n_heads, d_head))
            nn.init.zeros_(pos_bias_u)
            nn.init.zeros_(pos_bias_v)
        else:
            pos_bias_u = None
            pos_bias_v = None

        # Positional encodings
        self.pos_emb_max_len = pos_emb_max_len
        if self_attention_model == "rel_pos":
            self.pos_enc = RelPositionalEncoding(
                d_model=d_model,
                dropout_rate=dropout_pre_encoder,
                max_len=pos_emb_max_len,
                xscale=self.xscale,
                dropout_rate_emb=dropout_emb,
            )
        elif self_attention_model == 'rel_pos_local_attn':
            if max(att_context_size) <= 0:
                raise ValueError("When using local attention, context size must be set > 0")
            self.pos_enc = LocalAttRelPositionalEncoding(
                att_context_size=att_context_size,
                d_model=d_model,
                dropout_rate=dropout,
                max_len=pos_emb_max_len,
                xscale=self.xscale,
                dropout_rate_emb=dropout_emb,
            )
        elif self_attention_model == "abs_pos":
            pos_bias_u = None
            pos_bias_v = None
            self.pos_enc = PositionalEncoding(
                d_model=d_model, dropout_rate=dropout_pre_encoder, max_len=pos_emb_max_len, xscale=self.xscale
            )
        else:
            raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!")

        self.layers = nn.ModuleList()
        for i in range(n_layers):
            layer = ConformerLayer(
                d_model=d_model,
                d_ff=d_ff,
                self_attention_model=self_attention_model,
                global_tokens=global_tokens,
                global_tokens_spacing=global_tokens_spacing,
                global_attn_separate=global_attn_separate,
                n_heads=n_heads,
                conv_kernel_size=conv_kernel_size,
                conv_norm_type=conv_norm_type,
                conv_context_size=self.conv_context_size,
                dropout=dropout,
                dropout_att=dropout_att,
                pos_bias_u=pos_bias_u,
                pos_bias_v=pos_bias_v,
                att_context_size=self.att_context_size,
                use_bias=use_bias,
                use_pytorch_sdpa=self.use_pytorch_sdpa,
                use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
            )
            self.layers.append(layer)

        if feat_out > 0 and feat_out != self._feat_out:
            self.out_proj = nn.Linear(self._feat_out, feat_out)
            self._feat_out = feat_out
        else:
            self.out_proj = None
            self._feat_out = d_model
        self.set_max_audio_length(self.pos_emb_max_len)
        self.use_pad_mask = True

        self.setup_streaming_params()
        self.export_cache_support = False

        self.layer_drop_probs = compute_stochastic_depth_drop_probs(
            len(self.layers), stochastic_depth_drop_prob, stochastic_depth_mode, stochastic_depth_start_layer
        )
        # will be set in self.forward() if defined in AccessMixin config
        self.interctc_capture_at_layers = None

    def forward_for_export(
        self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
    ):
        """
        Forward function for model export. Please see `forward()` for more details.
        """
        if cache_last_channel is not None:
            cache_last_channel = cache_last_channel.transpose(0, 1)
            cache_last_time = cache_last_time.transpose(0, 1)

        rets = self.forward_internal(
            audio_signal,
            length,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
        )
        rets = self.streaming_post_process(rets, keep_all_outputs=False)
        if len(rets) == 2:
            return rets
        elif rets[2] is None and rets[3] is None and rets[4] is None:
            return (rets[0], rets[1])
        else:
            return (
                rets[0],
                rets[1],
                rets[2].transpose(0, 1),
                rets[3].transpose(0, 1),
                rets[4],
            )

    def streaming_post_process(self, rets, keep_all_outputs=True):
        """
        Post-process the output of the forward function for streaming.

        Args:
            rets: The output of the forward function.
            keep_all_outputs: Whether to keep all outputs.
        """
        if len(rets) == 2:
            return rets[0], rets[1], None, None, None

        (encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len) = rets

        if cache_last_channel_next is not None and self.streaming_cfg.last_channel_cache_size >= 0:
            if self.streaming_cfg.last_channel_cache_size > 0:
                cache_last_channel_next = cache_last_channel_next[
                    :, :, -self.streaming_cfg.last_channel_cache_size :, :
                ]

        if self.streaming_cfg.valid_out_len > 0 and (not keep_all_outputs or self.att_context_style == "regular"):
            encoded = encoded[:, :, : self.streaming_cfg.valid_out_len]
            encoded_len = torch.clamp(encoded_len, max=self.streaming_cfg.valid_out_len)

        return (encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len)

    def forward(
        self,
        audio_signal,
        length,
        cache_last_channel=None,
        cache_last_time=None,
        cache_last_channel_len=None,
        bypass_pre_encode=False,
    ):
        """
        Forward function for the ConformerEncoder accepting an audio signal and its corresponding length.
        The `audio_signal` input supports two formats depending on the `bypass_pre_encode` boolean flag.
        This determines the required format of the input variable `audio_signal`:
        (1) bypass_pre_encode = False (default):
            `audio_signal` must be a tensor containing audio features.
            Shape: (batch, self._feat_in, n_frames)
        (2) bypass_pre_encode = True:
            `audio_signal` must be a tensor containing pre-encoded embeddings.
            Shape: (batch, n_frame, self.d_model)
        """
        if not bypass_pre_encode and audio_signal.shape[-2] != self._feat_in:
            raise ValueError(
                f"If bypass_pre_encode is False, audio_signal should have shape "
                f"(batch, {self._feat_in}, n_frame) but got last dimension {audio_signal.shape[-2]}."
            )
        if bypass_pre_encode and audio_signal.shape[-1] != self.d_model:
            raise ValueError(
                f"If bypass_pre_encode is True, audio_signal should have shape "
                f"(batch, n_frame, {self.d_model}) but got last dimension {audio_signal.shape[-1]}."
            )

        if bypass_pre_encode:
            self.update_max_seq_length(
                seq_length=audio_signal.size(2) * self.subsampling_factor, device=audio_signal.device
            )
        else:
            self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
        return self.forward_internal(
            audio_signal,
            length,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
            bypass_pre_encode=bypass_pre_encode,
        )

    def forward_internal(
        self,
        audio_signal,
        length,
        cache_last_channel=None,
        cache_last_time=None,
        cache_last_channel_len=None,
        bypass_pre_encode=False,
    ):
        """
        The `audio_signal` input supports two formats depending on the `bypass_pre_encode` boolean flag.
        This determines the required format of the input variable `audio_signal`:
        (1) bypass_pre_encode = False (default):
            `audio_signal` must be a tensor containing audio features.
            Shape: (batch, self._feat_in, n_frames)
        (2) bypass_pre_encode = True:
            `audio_signal` must be a tensor containing pre-encoded embeddings.
            Shape: (batch, n_frame, self.d_model)

        `bypass_pre_encode=True` is used in cases where frame-level, context-independent embeddings
        need to be saved or reused (e.g., speaker cache in streaming speaker diarization).
        """
        if length is None:
            length = audio_signal.new_full(
                (audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
            )

        # select a random att_context_size with the distribution specified by att_context_probs during training
        # for non-validation cases like test, validation or inference, it uses the first mode in self.att_context_size
        if self.training and len(self.att_context_size_all) > 1:
            cur_att_context_size = random.choices(self.att_context_size_all, weights=self.att_context_probs)[0]
        else:
            cur_att_context_size = self.att_context_size

        if not bypass_pre_encode:
            audio_signal = torch.transpose(audio_signal, 1, 2)

            if isinstance(self.pre_encode, nn.Linear):
                audio_signal = self.pre_encode(audio_signal)
            else:
                audio_signal, length = self.pre_encode(x=audio_signal, lengths=length)
                length = length.to(torch.int64)
                # `self.streaming_cfg` is set by setup_streaming_cfg(), called in the init
                if self.streaming_cfg.drop_extra_pre_encoded > 0 and cache_last_channel is not None:
                    audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :]
                    length = (length - self.streaming_cfg.drop_extra_pre_encoded).clamp(min=0)

            if self.reduction_position is not None and cache_last_channel is not None:
                raise ValueError("Caching with reduction feature is not supported yet!")

        max_audio_length = audio_signal.size(1)
        if cache_last_channel is not None:
            cache_len = self.streaming_cfg.last_channel_cache_size
            cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size
            max_audio_length = max_audio_length + cache_len
            padding_length = length + cache_len
            offset = torch.neg(cache_last_channel_len) + cache_len
        else:
            padding_length = length
            cache_last_channel_next = None
            cache_len = 0
            offset = None

        audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)

        # Create the self-attention and padding masks
        pad_mask, att_mask = self._create_masks(
            att_context_size=cur_att_context_size,
            padding_length=padding_length,
            max_audio_length=max_audio_length,
            offset=offset,
            device=audio_signal.device,
        )

        if cache_last_channel is not None:
            pad_mask = pad_mask[:, cache_len:]
            if att_mask is not None:
                att_mask = att_mask[:, cache_len:]
            # Convert caches from the tensor to list
            cache_last_time_next = []
            cache_last_channel_next = []

        for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)):
            original_signal = audio_signal
            if cache_last_channel is not None:
                cache_last_channel_cur = cache_last_channel[lth]
                cache_last_time_cur = cache_last_time[lth]
            else:
                cache_last_channel_cur = None
                cache_last_time_cur = None
            audio_signal = layer(
                x=audio_signal,
                att_mask=att_mask,
                pos_emb=pos_emb,
                pad_mask=pad_mask,
                cache_last_channel=cache_last_channel_cur,
                cache_last_time=cache_last_time_cur,
            )
            if cache_last_channel_cur is not None:
                (audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal
                cache_last_channel_next.append(cache_last_channel_cur)
                cache_last_time_next.append(cache_last_time_cur)

            # applying stochastic depth logic from https://arxiv.org/abs/2102.03216
            if self.training and drop_prob > 0.0:
                should_drop = torch.rand(1) < drop_prob
                # adjusting to match expectation
                if should_drop:
                    # that's not efficient, but it's hard to implement distributed
                    # version of dropping layers without deadlock or random seed meddling
                    # so multiplying the signal by 0 to ensure all weights get gradients
                    audio_signal = audio_signal * 0.0 + original_signal
                else:
                    # not doing this operation if drop prob is 0 as it's identity in that case
                    audio_signal = (audio_signal - original_signal) / (1.0 - drop_prob) + original_signal

            if self.reduction_position == lth:
                audio_signal, length = self.reduction_subsampling(x=audio_signal, lengths=length)
                max_audio_length = audio_signal.size(1)
                # Don't update the audio_signal here because then it will again scale the audio_signal
                # and cause an increase in the WER
                _, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
                pad_mask, att_mask = self._create_masks(
                    att_context_size=cur_att_context_size,
                    padding_length=length,
                    max_audio_length=max_audio_length,
                    offset=offset,
                    device=audio_signal.device,
                )

            # saving tensors if required for interctc loss
            # if self.is_access_enabled(getattr(self, "model_guid", None)):
            #     if self.interctc_capture_at_layers is None:
            #         self.interctc_capture_at_layers = self.access_cfg.get('interctc', {}).get('capture_layers', [])
            #     if lth in self.interctc_capture_at_layers:
            #         lth_audio_signal = audio_signal
            #         if self.out_proj is not None:
            #             lth_audio_signal = self.out_proj(audio_signal)
            #         # shape is the same as the shape of audio_signal output, i.e. [B, D, T]
            #         self.register_accessible_tensor(
            #             name=f'interctc/layer_output_{lth}', tensor=torch.transpose(lth_audio_signal, 1, 2)
            #         )
            #         self.register_accessible_tensor(name=f'interctc/layer_length_{lth}', tensor=length)

        if self.out_proj is not None:
            audio_signal = self.out_proj(audio_signal)

        # Reduction
        if self.reduction_position == -1:
            audio_signal, length = self.reduction_subsampling(x=audio_signal, lengths=length)

        audio_signal = torch.transpose(audio_signal, 1, 2)
        length = length.to(dtype=torch.int64)

        if cache_last_channel is not None:
            cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0)
            cache_last_time_next = torch.stack(cache_last_time_next, dim=0)
            return (
                audio_signal,
                length,
                cache_last_channel_next,
                cache_last_time_next,
                torch.clamp(cache_last_channel_len + cache_keep_size, max=cache_len),
            )
        else:
            return audio_signal, length

    def update_max_seq_length(self, seq_length: int, device):
        """
        Updates the maximum sequence length for the model.

        Args:
            seq_length (int): New maximum sequence length.
            device (torch.device): Device to use for computations.
        """
        # Find global max audio length across all nodes
        if self.sync_max_audio_length and torch.distributed.is_initialized():
            global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)

            # Update across all ranks in the distributed system
            torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)

            seq_length = global_max_len.int().item()

        if seq_length > self.max_audio_length:
            self.set_max_audio_length(seq_length)

    def set_max_audio_length(self, max_audio_length):
        """
        Sets maximum input length.
        Pre-calculates internal seq_range mask.

        Args:
            max_audio_length (int): New maximum sequence length.
        """
        self.max_audio_length = max_audio_length
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        self.pos_enc.extend_pe(max_audio_length, device, dtype)

    def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device):
        if self.self_attention_model != "rel_pos_local_attn":
            att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device)

            if self.att_context_style == "regular":
                if att_context_size[0] >= 0:
                    att_mask = att_mask.triu(diagonal=-att_context_size[0])
                if att_context_size[1] >= 0:
                    att_mask = att_mask.tril(diagonal=att_context_size[1])
            elif self.att_context_style == "chunked_limited":
                # When right context is unlimited, only the left side of the masking needs to be updated
                if att_context_size[1] == -1:
                    if att_context_size[0] >= 0:
                        att_mask = att_mask.triu(diagonal=-att_context_size[0])
                else:
                    chunk_size = att_context_size[1] + 1
                    # left_chunks_num specifies the number of chunks to be visible by each chunk on the left side
                    if att_context_size[0] >= 0:
                        left_chunks_num = att_context_size[0] // chunk_size
                    else:
                        left_chunks_num = 10000

                    chunk_idx = torch.arange(0, max_audio_length, dtype=torch.int, device=att_mask.device)
                    chunk_idx = torch.div(chunk_idx, chunk_size, rounding_mode="trunc")
                    diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)
                    chunked_limited_mask = torch.logical_and(
                        torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0)
                    )
                    att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0))
        else:
            att_mask = None

        # pad_mask is the masking to be used to ignore paddings
        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
            padding_length.size(0), -1
        ) < padding_length.unsqueeze(-1)

        if offset is not None:
            pad_mask_off = torch.arange(0, max_audio_length, device=device).expand(
                padding_length.size(0), -1
            ) >= offset.unsqueeze(-1)
            pad_mask = pad_mask_off.logical_and(pad_mask)

        if att_mask is not None:
            # pad_mask_for_att_mask is the mask which helps to ignore paddings
            pad_mask_for_att_mask = pad_mask.unsqueeze(1).repeat([1, max_audio_length, 1])
            pad_mask_for_att_mask = torch.logical_and(pad_mask_for_att_mask, pad_mask_for_att_mask.transpose(1, 2))
            # att_mask is the masking to be used by the MHA layers to ignore the tokens not supposed to be visible
            att_mask = att_mask[:, :max_audio_length, :max_audio_length]
            # paddings should also get ignored, so pad_mask_for_att_mask is used to ignore their corresponding scores
            att_mask = torch.logical_and(pad_mask_for_att_mask, att_mask.to(pad_mask_for_att_mask.device))
            att_mask = ~att_mask

        pad_mask = ~pad_mask
        return pad_mask, att_mask

    def enable_pad_mask(self, on=True):
        """
        Enables or disables the pad mask by assigning the boolean state `on`.

        Returns:
            mask (bool): The previous state of the pad mask.
        """
        # On inference, user may choose to disable pad mask
        mask = self.use_pad_mask
        self.use_pad_mask = on
        return mask

    def _calc_context_sizes(
        self, att_context_size, att_context_probs, att_context_style, conv_context_size, conv_kernel_size
    ):
        # convert att_context_size to a standard list of lists
        if att_context_size:
            att_context_size_all = list(att_context_size)
            if isinstance(att_context_size_all[0], int):
                att_context_size_all = [att_context_size_all]
            for i, att_cs in enumerate(att_context_size_all):
                # if isinstance(att_cs, ListConfig):
                #     att_context_size_all[i] = list(att_cs)
                if att_context_style == "chunked_limited":
                    if att_cs[0] > 0 and att_cs[0] % (att_cs[1] + 1) > 0:
                        raise ValueError(f"att_context_size[{i}][0] % (att_context_size[{i}][1] + 1) should be zero!")
                    if att_cs[1] < 0 and len(att_context_size_all) <= 1:
                        raise ValueError(
                            f"Right context (att_context_size[{i}][1]) can not be unlimited for chunked_limited style!"
                        )
        else:
            att_context_size_all = [[-1, -1]]

        if att_context_probs:
            if len(att_context_probs) != len(att_context_size_all):
                raise ValueError("The size of the att_context_probs should be the same as att_context_size.")
            att_context_probs = list(att_context_probs)
            if sum(att_context_probs) != 1:
                raise ValueError(
                    "The sum of numbers in att_context_probs should be equal to one to be a distribution."
                )
        else:
            att_context_probs = [1.0 / len(att_context_size_all)] * len(att_context_size_all)

        if conv_context_size is not None:
            # if isinstance(conv_context_size, ListConfig):
            #     conv_context_size = list(conv_context_size)
            if not isinstance(conv_context_size, list) and not isinstance(conv_context_size, str):
                raise ValueError(
                    "Invalid conv_context_size! It should be the string 'causal' or a list of two integers."
                )
            if conv_context_size == "causal":
                conv_context_size = [conv_kernel_size - 1, 0]
            else:
                if conv_context_size[0] + conv_context_size[1] + 1 != conv_kernel_size:
                    raise ValueError(f"Invalid conv_context_size: {self.conv_context_size}!")
        else:
            conv_context_size = [(conv_kernel_size - 1) // 2, (conv_kernel_size - 1) // 2]
        return att_context_size_all, att_context_size_all[0], att_context_probs, conv_context_size

    def set_default_att_context_size(self, att_context_size):
        """
        Sets the default attention context size from `att_context_size` argument.

        Args:
            att_context_size (list): The attention context size to be set.
        """
        if att_context_size not in self.att_context_size_all:
            logging.warning(
                f"att_context_size={att_context_size} is not among the list of the supported "
                f"look-aheads: {self.att_context_size_all}"
            )
        if att_context_size is not None:
            self.att_context_size = att_context_size

        self.setup_streaming_params()

    def setup_streaming_params(
        self,
        chunk_size: int = None,
        shift_size: int = None,
        left_chunks: int = None,
        att_context_size: list = None,
        max_context: int = 10000,
    ):
        """
        This function sets the needed values and parameters to perform streaming.
        The configuration would be stored in self.streaming_cfg.
        The streaming configuration is needed to simulate streaming inference.

        Args:
            chunk_size (int): overrides the chunk size
            shift_size (int): overrides the shift size for chunks
            left_chunks (int): overrides the number of left chunks visible to each chunk
            att_context_size (list): overrides the attention context size; when None,
                               the default self.att_context_size is used
            max_context (int): the value used for the cache size of last_channel layers
                               if left context is set to infinity (-1)
                               Defaults to 10000
        """
        streaming_cfg = CacheAwareStreamingConfig()

        # When att_context_size is not specified, it uses the default_att_context_size
        if att_context_size is None:
            att_context_size = self.att_context_size

        if chunk_size is not None:
            if chunk_size < 1:
                raise ValueError("chunk_size needs to be a number larger or equal to one.")
            lookahead_steps = chunk_size - 1
            streaming_cfg.cache_drop_size = chunk_size - shift_size
        elif self.att_context_style == "chunked_limited":
            lookahead_steps = att_context_size[1]
            streaming_cfg.cache_drop_size = 0
        elif self.att_context_style == "regular":
            lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers
            streaming_cfg.cache_drop_size = lookahead_steps
        else:
            streaming_cfg.cache_drop_size = 0
            lookahead_steps = None

        if chunk_size is None:
            streaming_cfg.last_channel_cache_size = att_context_size[0] if att_context_size[0] >= 0 else max_context
        else:
            if left_chunks is None:
                streaming_cfg.last_channel_cache_size = (
                    att_context_size[0] if att_context_size[0] >= 0 else max_context
                )
                logging.warning(
                    f"left_chunks is not set. Setting it to default: {streaming_cfg.last_channel_cache_size}."
                )
            else:
                streaming_cfg.last_channel_cache_size = left_chunks * chunk_size

        if hasattr(self.pre_encode, "get_sampling_frames"):
            sampling_frames = self.pre_encode.get_sampling_frames()
        else:
            sampling_frames = 0

        if isinstance(sampling_frames, list):
            streaming_cfg.chunk_size = [
                sampling_frames[0] + self.subsampling_factor * lookahead_steps,
                sampling_frames[1] + self.subsampling_factor * lookahead_steps,
            ]
        else:
            streaming_cfg.chunk_size = sampling_frames * (1 + lookahead_steps)

        if isinstance(sampling_frames, list):
            streaming_cfg.shift_size = [
                sampling_frames[0] + sampling_frames[1] * (lookahead_steps - streaming_cfg.cache_drop_size),
                sampling_frames[1] + sampling_frames[1] * (lookahead_steps - streaming_cfg.cache_drop_size),
            ]
        else:
            streaming_cfg.shift_size = sampling_frames * (1 + lookahead_steps - streaming_cfg.cache_drop_size)

        if isinstance(streaming_cfg.shift_size, list):
            streaming_cfg.valid_out_len = (
                streaming_cfg.shift_size[1] - sampling_frames[1]
            ) // self.subsampling_factor + 1
        else:
            streaming_cfg.valid_out_len = streaming_cfg.shift_size // self.subsampling_factor

        if hasattr(self.pre_encode, "get_streaming_cache_size"):
            streaming_cfg.pre_encode_cache_size = self.pre_encode.get_streaming_cache_size()
        else:
            streaming_cfg.pre_encode_cache_size = 0

        if isinstance(streaming_cfg.pre_encode_cache_size, list):
            if streaming_cfg.pre_encode_cache_size[1] >= 1:
                streaming_cfg.drop_extra_pre_encoded = (
                    1 + (streaming_cfg.pre_encode_cache_size[1] - 1) // self.subsampling_factor
                )
            else:
                streaming_cfg.drop_extra_pre_encoded = 0
        else:
            streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor

        for m in self.layers.modules():
            if hasattr(m, "_max_cache_len"):
                if isinstance(m, MultiHeadAttention):
                    m.cache_drop_size = streaming_cfg.cache_drop_size
                if isinstance(m, CausalConv1D):
                    m.cache_drop_size = streaming_cfg.cache_drop_size

        self.streaming_cfg = streaming_cfg

    def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None, max_dim=0):
        if device is None:
            device = next(self.parameters()).device
        if max_dim > 0:
            create_tensor = torch.randn
        else:
            create_tensor = torch.zeros
        last_time_cache_size = self.conv_context_size[0]
        cache_last_channel = create_tensor(
            (
                len(self.layers),
                batch_size,
                self.streaming_cfg.last_channel_cache_size,
                self.d_model,
            ),
            device=device,
            dtype=dtype,
        )
        cache_last_time = create_tensor(
            (len(self.layers), batch_size, self.d_model, last_time_cache_size),
            device=device,
            dtype=dtype,
        )
        if max_dim > 0:
            cache_last_channel_len = torch.randint(
                0,
                min(max_dim, self.streaming_cfg.last_channel_cache_size),
                (batch_size,),
                device=device,
                dtype=torch.int64,
            )
            for i in range(batch_size):
                cache_last_channel[:, i, cache_last_channel_len[i] :, :] = 0
                # what is the right rule to zero out cache_last_time?
                if cache_last_channel_len[i] == 0:
                    cache_last_time[:, i, :, :] = 0
        else:
            cache_last_channel_len = torch.zeros(batch_size, device=device, dtype=torch.int64)
        return cache_last_channel, cache_last_time, cache_last_channel_len

    def change_attention_model(
        self,
        self_attention_model: str = None,
        att_context_size: list[int] = None,
        update_config: bool = True,
        device: torch.device = None,
    ):
        """
        Update the self_attention_model which changes the positional encoding and attention layers.

        Args:
            self_attention_model (str): type of the attention layer and positional encoding

                'rel_pos':
                    relative positional embedding and Transformer-XL

                'rel_pos_local_attn':
                    relative positional embedding and Transformer-XL with local attention using
                    overlapping windows. Attention context is determined by att_context_size parameter.

                'abs_pos':
                    absolute positional embedding and Transformer

                If None is provided, the self_attention_model isn't changed. Defaults to None.
            att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes,
                or None to keep as it is. Defaults to None.
            update_config (bool): Whether to update the config or not with the new attention model.
                Defaults to True.
            device (torch.device): If provided, new layers will be moved to the device.
                Defaults to None.
        """

        if att_context_size:
            att_context_size = list(att_context_size)
        else:
            att_context_size = self.att_context_size

        if self_attention_model is None:
            self_attention_model = self.self_attention_model

        if self_attention_model == 'rel_pos_local_attn' and max(att_context_size) <= 0:
            raise ValueError("When using local attention, context size must be set > 0")

        if self_attention_model == "rel_pos":
            new_pos_enc = RelPositionalEncoding(
                d_model=self._cfg.d_model,
                dropout_rate=self._cfg.dropout,
                max_len=self._cfg.pos_emb_max_len,
                xscale=self.xscale,
                dropout_rate_emb=self._cfg.dropout_emb,
            )
        elif self_attention_model == 'rel_pos_local_attn':
            new_pos_enc = LocalAttRelPositionalEncoding(
                att_context_size=att_context_size,
                d_model=self._cfg.d_model,
                dropout_rate=self._cfg.dropout,
                max_len=self._cfg.pos_emb_max_len,
                xscale=self.xscale,
                dropout_rate_emb=self._cfg.dropout_emb,
            )
        elif self_attention_model == "abs_pos":
            new_pos_enc = PositionalEncoding(
                d_model=self._cfg.d_model,
                dropout_rate=self._cfg.dropout,
                max_len=self._cfg.pos_emb_max_len,
                xscale=self.xscale,
            )
        else:
            raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!")

        if device is not None:
            new_pos_enc = new_pos_enc.to(device=device)
        del self.pos_enc
        self.pos_enc = new_pos_enc
        self.self_attention_model = self_attention_model
        self.att_context_size = att_context_size
        self.set_max_audio_length(self.pos_emb_max_len)

        for _, m in self.named_modules():
            if type(m) == ConformerLayer:
                if self_attention_model == 'rel_pos':
                    new_attn = RelPositionMultiHeadAttention(
                        n_head=self._cfg.n_heads,
                        n_feat=self._cfg.d_model,
                        dropout_rate=self._cfg.dropout_att,
                        max_cache_len=att_context_size[0],
                        pos_bias_u=None,
                        pos_bias_v=None,
                        use_pytorch_sdpa=self.use_pytorch_sdpa,
                        use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
                    )
                elif self_attention_model == 'rel_pos_local_attn':
                    new_attn = RelPositionMultiHeadAttentionLongformer(
                        n_head=self._cfg.n_heads,
                        n_feat=self._cfg.d_model,
                        dropout_rate=self._cfg.dropout_att,
                        max_cache_len=att_context_size[0],
                        att_context_size=att_context_size,
                        pos_bias_u=None,
                        pos_bias_v=None,
                        use_pytorch_sdpa=self.use_pytorch_sdpa,
                        use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
                    )
                elif self_attention_model == 'abs_pos':
                    new_attn = MultiHeadAttention(
                        n_head=self._cfg.n_heads,
                        n_feat=self._cfg.d_model,
                        dropout_rate=self._cfg.dropout_att,
                        max_cache_len=att_context_size[0],
                        use_pytorch_sdpa=self.use_pytorch_sdpa,
                        use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
                    )
                else:
                    raise ValueError(
                        f"'{self_attention_model}' is not not a valid value for 'self_attention_model', "
                        f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos']"
                    )
                if device is not None:
                    new_attn = new_attn.to(device=device)
                new_attn.load_state_dict(m.self_attn.state_dict(), strict=False)
                del m.self_attn
                m.self_attn = new_attn
                m.self_attention_model = self_attention_model

        if update_config:
            with open_dict(self._cfg):
                self._cfg.self_attention_model = self_attention_model
                self._cfg.att_context_size = att_context_size

    def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
        """
        Update the conv_chunking_factor (int).
        Default is 1 (auto).
        Set it to -1 (disabled) or to a specific value (a power of 2) if you hit OOM in the conv subsampling layers.

        Args:
            subsampling_conv_chunking_factor (int)
        """

        if not hasattr(self.pre_encode, "change_subsampling_conv_chunking_factor"):
            logging.info("Model pre_encoder doesn't have a change_subsampling_conv_chunking_factor method ")
            return

        self.pre_encode.change_subsampling_conv_chunking_factor(
            subsampling_conv_chunking_factor=subsampling_conv_chunking_factor
        )

```
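
To make the streaming mask logic above concrete, here is a minimal, self-contained sketch of the `chunked_limited` visibility pattern that `_create_masks` builds; the helper name `chunked_limited_mask` and the example sizes are illustrative only and are not part of the library.

```py
import torch

def chunked_limited_mask(max_len: int, left_context: int, right_context: int) -> torch.Tensor:
    """Rebuild the boolean visibility pattern of the chunked_limited branch in _create_masks."""
    chunk_size = right_context + 1
    left_chunks_num = left_context // chunk_size if left_context >= 0 else 10000
    chunk_idx = torch.arange(max_len) // chunk_size
    diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)
    # a frame may attend to frames in its own chunk and up to left_chunks_num chunks to its left
    return (diff_chunks >= 0) & (diff_chunks <= left_chunks_num)

# e.g. att_context_size = [2, 1] -> chunks of 2 frames, 1 chunk of left context
print(chunked_limited_mask(6, left_context=2, right_context=1).int())
```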

## /src/liquid_audio/model/conformer/mha.py

```py path="/src/liquid_audio/model/conformer/mha.py" 
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""NeMo multhead attention

adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/collections/asr/parts/submodules/multi_head_attention.py
"""


import torch
from torch import nn
import math

from .utils import avoid_float16_autocast_context

INF_VAL = 10000.0

class PositionalEncoding(torch.nn.Module):
    """Fixed sinusoidal positional encoding.
    Args:
        d_model (int): embedding dim
        dropout_rate (float): dropout rate
        max_len (int): maximum input length
        xscale (bool): whether to scale the input by sqrt(d_model)
        dropout_rate_emb (float): dropout rate for the positional embeddings
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, xscale=None, dropout_rate_emb=0.0):
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = xscale
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len
        if dropout_rate_emb > 0:
            self.dropout_emb = nn.Dropout(dropout_rate_emb)
        else:
            self.dropout_emb = None

    def create_pe(self, positions, dtype):
        pos_length = positions.size(0)
        pe = torch.zeros(pos_length, self.d_model, device=positions.device)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32, device=positions.device)
            * -(math.log(INF_VAL) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(positions * div_term)
        pe[:, 1::2] = torch.cos(positions * div_term)
        pe = pe.unsqueeze(0).to(dtype)
        if hasattr(self, 'pe'):
            self.pe = pe
        else:
            self.register_buffer('pe', pe, persistent=False)

    def extend_pe(self, length, device, dtype):
        """Reset and extend the positional encodings if needed."""
        if hasattr(self, 'pe') and self.pe.size(1) >= length:
            return
        positions = torch.arange(0, length, dtype=torch.float32, device=device).unsqueeze(1)
        self.create_pe(positions=positions, dtype=dtype)

    def forward(self, x: torch.Tensor, cache_len=0):
        """Adds positional encoding.
        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, feature_size)
            cache_len (int): the size of the cache which is used to shift positions
        Returns:
            x+pos_emb (torch.Tensor): Its shape is (batch, time, feature_size)
            pos_emb (torch.Tensor): Its shape is (1, time, feature_size)
        """
        input_len = x.size(1) + cache_len
        if self.xscale:
            x = x * self.xscale
        pos_emb = self.pe[:, :input_len]
        if self.dropout_emb:
            pos_emb = self.dropout_emb(pos_emb)
        x = x + pos_emb
        return self.dropout(x), pos_emb


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding for TransformerXL's layers
    See : Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): embedding dim
        dropout_rate (float): dropout rate
        max_len (int): maximum input length
        xscale (bool): whether to scale the input by sqrt(d_model)
        dropout_rate_emb (float): dropout rate for the positional embeddings
    """

    def extend_pe(self, length, device, dtype):
        """Reset and extend the positional encodings if needed."""
        needed_size = 2 * length - 1
        if hasattr(self, 'pe') and self.pe.size(1) >= needed_size:
            return
        # positions would be from negative numbers to positive
        # positive positions would be used for left positions and negative for right positions
        positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1)
        self.create_pe(positions=positions, dtype=dtype)

    def forward(self, x, cache_len=0):
        """Compute positional encoding.
        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, feature_size)
            cache_len (int): the size of the cache which is used to shift positions
        Returns:
            x (torch.Tensor): Its shape is (batch, time, feature_size)
            pos_emb (torch.Tensor): Its shape is (1, time, feature_size)
        """

        if self.xscale:
            x = x * self.xscale

        # center_pos would be the index of position 0
        # negative positions would be used for right and positive for left tokens
        # for input of length L, 2*L-1 positions are needed, positions from (L-1) to -(L-1)
        input_len = x.size(1) + cache_len
        center_pos = self.pe.size(1) // 2 + 1
        start_pos = center_pos - input_len
        end_pos = center_pos + input_len - 1
        pos_emb = self.pe[:, start_pos:end_pos]
        if self.dropout_emb:
            pos_emb = self.dropout_emb(pos_emb)
        return self.dropout(x), pos_emb


class MultiHeadAttention(nn.Module):
    """Multi-Head Attention layer of Transformer.
    Args:
        n_head (int): number of heads
        n_feat (int): size of the features
        dropout_rate (float): dropout rate
        use_bias (bool): whether to apply bias in linear and conv layers
        use_pytorch_sdpa (bool): use torch sdpa instead of manual attention
        use_pytorch_sdpa_backends (list[str]): list of backend names to use in sdpa. None or empty list means all backends. e.g. ["MATH"]
    """

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        max_cache_len=0,
        use_bias=True,
        use_pytorch_sdpa=False,
        use_pytorch_sdpa_backends=None,
    ):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadAttention, self).__init__()
        self.use_pytorch_sdpa = use_pytorch_sdpa
        if self.use_pytorch_sdpa and use_pytorch_sdpa_backends:
            use_pytorch_sdpa_backends = list(
                map(
                    lambda backend_name: getattr(torch.nn.attention.SDPBackend, backend_name),
                    use_pytorch_sdpa_backends,
                )
            )
        self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends

        self.cache_drop_size = None
        self.use_bias = use_bias
        self.dropout_rate = dropout_rate
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.s_d_k = math.sqrt(self.d_k)
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat, bias=use_bias)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=use_bias)
        self.linear_v = nn.Linear(n_feat, n_feat, bias=use_bias)
        self.linear_out = nn.Linear(n_feat, n_feat, bias=use_bias)
        self.dropout = nn.Dropout(p=dropout_rate)

        self._max_cache_len = max_cache_len

    def forward_qkv(self, query, key, value):
        """Transforms query, key and value.
        Args:
            query (torch.Tensor): (batch, time1, size)
            key (torch.Tensor): (batch, time2, size)
            value (torch.Tensor): (batch, time2, size)
        returns:
            q (torch.Tensor): (batch, head, time1, size)
            k (torch.Tensor): (batch, head, time2, size)
            v (torch.Tensor): (batch, head, time2, size)
        """
        n_batch = query.size(0)
        t1 = query.size(1)
        t2 = key.size(1)
        q = self.linear_q(query).view(n_batch, t1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, t2, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, t2, self.h, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.
        Args:
            value (torch.Tensor): (batch, head, time2, size)
            scores(torch.Tensor): (batch, head, time1, time2)
            mask(torch.Tensor): (batch, time1, time2)
        returns:
            value (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the attention scores
        """
        n_batch = value.size(0)
        time = scores.size(2)
        if mask is not None:
            mask = mask.unsqueeze(1)  # (batch, 1, time1, time2)
            scores = scores.masked_fill(mask, -INF_VAL)
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).reshape(n_batch, time, self.h * self.d_k)  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask, pos_emb=None, cache=None):
        """Compute 'Scaled Dot Product Attention'.
        Args:
            query (torch.Tensor): (batch, time1, size)
            key (torch.Tensor): (batch, time2, size)
            value(torch.Tensor): (batch, time2, size)
            mask (torch.Tensor): (batch, time1, time2)
            cache (torch.Tensor) : (batch, time_cache, size)

        returns:
            output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
            cache (torch.Tensor) : (batch, time_cache_next, size)
        """
        key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)

        if torch.is_autocast_enabled():
            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)

        # temporary until we solve this more gracefully
        with avoid_float16_autocast_context():
            q, k, v = self.forward_qkv(query, key, value)

            if self.use_pytorch_sdpa:
                n_batch = value.size(0)

                if mask is not None:
                    mask = ~mask.unsqueeze(1)

                dropout_rate = self.dropout_rate if self.training else 0
                if self.use_pytorch_sdpa_backends:
                    with torch.nn.attention.sdpa_kernel(self.use_pytorch_sdpa_backends):
                        out = torch.nn.functional.scaled_dot_product_attention(
                            q, k, v, attn_mask=mask, dropout_p=dropout_rate
                        )
                else:
                    out = torch.nn.functional.scaled_dot_product_attention(
                        q, k, v, attn_mask=mask, dropout_p=dropout_rate
                    )

                # this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
                if mask is not None:
                    all_masked_rows = torch.all(~mask, dim=-1)
                    all_masked_rows.unsqueeze_(-1)
                    out = out.masked_fill(all_masked_rows, 0.0)

                out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
                out = self.linear_out(out)  # (batch, time1, d_model)
            else:
                scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
                out = self.forward_attention(v, scores, mask)

        if cache is None:
            return out
        else:
            return out, cache

    def update_cache(self, key, value, query, cache):
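        # keys/values are prepended with frames cached from previous streaming chunks; the cache
        # is then rolled forward by the kept part of the current chunk (its first
        # query_len - cache_drop_size frames), discarding the oldest cached entries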
        if cache is not None:
            key = value = torch.cat([cache, key], dim=1)
            q_keep_size = query.shape[1] - self.cache_drop_size
            cache = torch.cat([cache[:, q_keep_size:, :], query[:, :q_keep_size, :]], dim=1)
        return key, value, query, cache


class RelPositionMultiHeadAttention(MultiHeadAttention):
    """Multi-Head Attention layer of Transformer-XL with support of relative positional encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): number of heads
        n_feat (int): size of the features
        dropout_rate (float): dropout rate
        use_bias (bool): whether to apply bias in linear and conv layers of MultiHeadAttention
    """

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        pos_bias_u,
        pos_bias_v,
        max_cache_len=0,
        use_bias=True,
        use_pytorch_sdpa=False,
        use_pytorch_sdpa_backends=None,
    ):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(
            n_head=n_head,
            n_feat=n_feat,
            dropout_rate=dropout_rate,
            max_cache_len=max_cache_len,
            use_bias=use_bias,
            use_pytorch_sdpa=use_pytorch_sdpa,
            use_pytorch_sdpa_backends=use_pytorch_sdpa_backends,
        )
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        if pos_bias_u is None or pos_bias_v is None:
            self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
            self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
            # nn.init.normal_(self.pos_bias_u, 0.0, 0.02)
            # nn.init.normal_(self.pos_bias_v, 0.0, 0.02)
            nn.init.zeros_(self.pos_bias_u)
            nn.init.zeros_(self.pos_bias_v)
        else:
            self.pos_bias_u = pos_bias_u
            self.pos_bias_v = pos_bias_v

    def rel_shift(self, x):
        """Compute relative positional encoding.
        Args:
            x (torch.Tensor): (batch, nheads, time, 2*time-1)
        """
        b, h, qlen, pos_len = x.size()  # (b, h, t1, t2)
        # need to add a column of zeros on the left side of last dimension to perform the relative shifting
        x = torch.nn.functional.pad(x, pad=(1, 0))  # (b, h, t1, t2+1)
        x = x.view(b, h, pos_len+1, qlen)  # (b, h, t2+1, t1)
        # need to drop the first row
        x = x[:, :, 1:].view(b, h, qlen, pos_len)  # (b, h, t1, t2)
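        # after this shift (and the later slice down to the key length), element [i, j]
        # holds the score for the relative distance between query i and key j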
        return x

    def forward(self, query, key, value, mask, pos_emb, cache=None):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): (batch, time1, size)
            key (torch.Tensor): (batch, time2, size)
            value(torch.Tensor): (batch, time2, size)
            mask (torch.Tensor): (batch, time1, time2)
            pos_emb (torch.Tensor) : (batch, time1, size)
            cache (torch.Tensor) : (batch, time_cache, size)

        Returns:
            output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
            cache (torch.Tensor) : (batch, time_cache_next, size)
        """
        key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)

        if torch.is_autocast_enabled():
            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)

        # temporary until we solve this more gracefully
        with avoid_float16_autocast_context():
            q, k, v = self.forward_qkv(query, key, value)
            q = q.transpose(1, 2)  # (batch, time1, head, d_k)

            n_batch_pos = pos_emb.size(0)
            n_batch = value.size(0)
            p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
            p = p.transpose(1, 2)  # (batch, head, time1, d_k)

            # (batch, head, time1, d_k)
            q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
            # (batch, head, time1, d_k)
            q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

            # compute attention score
            # first compute matrix a and matrix c
            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
            # (batch, head, time1, time2)

            # compute matrix b and matrix d
            # (batch, head, time1, time2)
            matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
            matrix_bd = self.rel_shift(matrix_bd)

            if self.use_pytorch_sdpa:
                scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1))
                matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor

                if mask is not None:
                    mask = mask.unsqueeze(1)
                    matrix_bd.masked_fill_(mask, -INF_VAL)

                dropout_rate = self.dropout_rate if self.training else 0
                if self.use_pytorch_sdpa_backends:
                    with torch.nn.attention.sdpa_kernel(self.use_pytorch_sdpa_backends):
                        out = torch.nn.functional.scaled_dot_product_attention(
                            q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate
                        )
                else:
                    out = torch.nn.functional.scaled_dot_product_attention(
                        q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate
                    )

                # this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
                if mask is not None:
                    all_masked_rows = torch.all(mask, dim=-1)
                    all_masked_rows.unsqueeze_(-1)
                    all_masked_rows = all_masked_rows.expand(-1, out.size(1), -1, out.size(-1))
                    out = out.masked_fill(all_masked_rows, 0.0)

                out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
                out = self.linear_out(out)  # (batch, time1, d_model)
            else:
                # drops extra elements in the matrix_bd to match the matrix_ac's size
                matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
                matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
                scores = (matrix_ac + matrix_bd) / self.s_d_k  # (batch, head, time1, time2)
                out = self.forward_attention(v, scores, mask)

        if cache is None:
            return out
        else:
            return out, cache

```
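
As a quick sanity check on the two encodings above: `PositionalEncoding` produces one embedding per absolute position, while `RelPositionalEncoding` produces embeddings for all 2*T - 1 relative offsets. The snippet below is a hedged sketch; it assumes the repository's `src/` layout is installed so that the module is importable as `liquid_audio.model.conformer.mha`.

```py
import torch

from liquid_audio.model.conformer.mha import PositionalEncoding, RelPositionalEncoding

abs_pe = PositionalEncoding(d_model=8, dropout_rate=0.0, max_len=16)
rel_pe = RelPositionalEncoding(d_model=8, dropout_rate=0.0, max_len=16)
abs_pe.extend_pe(16, torch.device("cpu"), torch.float32)
rel_pe.extend_pe(16, torch.device("cpu"), torch.float32)

x = torch.zeros(1, 5, 8)  # (batch, time, d_model)
_, abs_emb = abs_pe(x)
_, rel_emb = rel_pe(x)
print(abs_emb.shape)  # (1, 5, 8): one embedding per absolute position
print(rel_emb.shape)  # (1, 9, 8): 2*5 - 1 relative offsets, from +4 down to -4
```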

## /src/liquid_audio/model/conformer/modules.py

```py path="/src/liquid_audio/model/conformer/modules.py" 
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""NeMo conformer modules and layers

adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/collections/asr/parts/submodules/conformer_modules.py
"""

import torch
from torch import nn
from torch.nn import functional as F

from .mha import RelPositionMultiHeadAttention


class ConformerLayer(torch.nn.Module):
    """A single block of the Conformer encoder.

    Args:
        d_model (int): input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward
        d_ff (int): hidden dimension of PositionwiseFeedForward
        self_attention_model (str): type of the attention layer and positional encoding
            'rel_pos': relative positional embedding and Transformer-XL
            'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using
                overlapping chunks. Attention context is determined by att_context_size parameter.
            'abs_pos': absolute positional embedding and Transformer
            Default is rel_pos.
        global_tokens (int): number of tokens to be used for global attention.
            Only relevant if self_attention_model is 'rel_pos_local_attn'.
            Defaults to 0.
        global_tokens_spacing (int): how far apart the global tokens are
            Defaults to 1.
        global_attn_separate (bool): whether the q, k, v layers used for global tokens should be separate.
            Defaults to False.
        n_heads (int): number of heads for multi-head attention
        conv_kernel_size (int): kernel size for depthwise convolution in convolution module
        dropout (float): dropout probabilities for linear layers
        dropout_att (float): dropout probabilities for attention distributions
        use_bias (bool): Apply bias to all Linear and Conv1d layers from each ConformerLayer to improve activation flow and stabilize training of huge models.
            Defaults to True.
    """

    def __init__(
        self,
        d_model,
        d_ff,
        self_attention_model='rel_pos',
        global_tokens=0,
        global_tokens_spacing=1,
        global_attn_separate=False,
        n_heads=4,
        conv_kernel_size=31,
        conv_norm_type='batch_norm',
        conv_context_size=None,
        dropout=0.1,
        dropout_att=0.1,
        pos_bias_u=None,
        pos_bias_v=None,
        att_context_size=[-1, -1],
        use_bias=True,
        use_pytorch_sdpa=False,
        use_pytorch_sdpa_backends=None,
    ):
        super().__init__()

        self.use_pytorch_sdpa = use_pytorch_sdpa
        if use_pytorch_sdpa_backends is None:
            use_pytorch_sdpa_backends = []
        self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends
        self.self_attention_model = self_attention_model
        self.n_heads = n_heads
        self.fc_factor = 0.5
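        # the two feed-forward modules are "macaron" half-steps: each adds
        # 0.5 * FFN(x) to the residual stream, as in the Conformer paper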

        # first feed forward module
        self.norm_feed_forward1 = nn.LayerNorm(d_model)
        self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout, use_bias=use_bias)

        # convolution module
        self.norm_conv = nn.LayerNorm(d_model)
        self.conv = ConformerConvolution(
            d_model=d_model,
            kernel_size=conv_kernel_size,
            norm_type=conv_norm_type,
            conv_context_size=conv_context_size,
            use_bias=use_bias,
        )

        # multi-headed self-attention module
        self.norm_self_att = nn.LayerNorm(d_model)
        MHA_max_cache_len = att_context_size[0]

        if self_attention_model == 'rel_pos':
            self.self_attn = RelPositionMultiHeadAttention(
                n_head=n_heads,
                n_feat=d_model,
                dropout_rate=dropout_att,
                pos_bias_u=pos_bias_u,
                pos_bias_v=pos_bias_v,
                max_cache_len=MHA_max_cache_len,
                use_bias=use_bias,
                use_pytorch_sdpa=self.use_pytorch_sdpa,
                use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
            )
        elif self_attention_model == 'rel_pos_local_attn':
            self.self_attn = RelPositionMultiHeadAttentionLongformer(
                n_head=n_heads,
                n_feat=d_model,
                dropout_rate=dropout_att,
                pos_bias_u=pos_bias_u,
                pos_bias_v=pos_bias_v,
                max_cache_len=MHA_max_cache_len,
                att_context_size=att_context_size,
                global_tokens=global_tokens,
                global_tokens_spacing=global_tokens_spacing,
                global_attn_separate=global_attn_separate,
                use_bias=use_bias,
            )
        elif self_attention_model == 'abs_pos':
            self.self_attn = MultiHeadAttention(
                n_head=n_heads,
                n_feat=d_model,
                dropout_rate=dropout_att,
                max_cache_len=MHA_max_cache_len,
                use_bias=use_bias,
                use_pytorch_sdpa=self.use_pytorch_sdpa,
                use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
            )
        else:
            raise ValueError(
                f"'{self_attention_model}' is not not a valid value for 'self_attention_model', "
                f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos']"
            )

        # second feed forward module
        self.norm_feed_forward2 = nn.LayerNorm(d_model)
        self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout, use_bias=use_bias)

        self.dropout = nn.Dropout(dropout)
        self.norm_out = nn.LayerNorm(d_model)

    def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None):
        """
        Args:
            x (torch.Tensor): input signals (B, T, d_model)
            att_mask (torch.Tensor): attention masks(B, T, T)
            pos_emb (torch.Tensor): (L, 1, d_model)
            pad_mask (torch.tensor): padding mask
            cache_last_channel (torch.tensor) : cache for MHA layers (B, T_cache, d_model)
            cache_last_time (torch.tensor) : cache for convolutional layers (B, d_model, T_cache)
        Returns:
            x (torch.Tensor): (B, T, d_model)
            cache_last_channel (torch.tensor) : next cache for MHA layers (B, T_cache, d_model)
            cache_last_time (torch.tensor) : next cache for convolutional layers (B, d_model, T_cache)
        """
        residual = x
        x = self.norm_feed_forward1(x)
        x = self.feed_forward1(x)
        residual = residual + self.dropout(x) * self.fc_factor

        x = self.norm_self_att(residual)
        if self.self_attention_model == 'rel_pos':
            x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel)
        elif self.self_attention_model == 'rel_pos_local_attn':
            x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel)
        elif self.self_attention_model == 'abs_pos':
            x = self.self_attn(query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel)
        else:
            x = None

        if x is not None and cache_last_channel is not None:
            (x, cache_last_channel) = x

        residual = residual + self.dropout(x)

        # if self.is_adapter_available():
        #     # Call the MHA adapters
        #     pack_input = {
        #         'x': residual,
        #         'loc': 'mha',
        #         'att_mask': att_mask,
        #         'pos_emb': pos_emb,
        #     }
        #     pack_input = self.forward_enabled_adapters(pack_input)
        #     residual = pack_input['x']

        x = self.norm_conv(residual)
        x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time)
        if cache_last_time is not None:
            (x, cache_last_time) = x
        residual = residual + self.dropout(x)

        x = self.norm_feed_forward2(residual)
        x = self.feed_forward2(x)
        residual = residual + self.dropout(x) * self.fc_factor

        x = self.norm_out(residual)

        # if self.is_adapter_available():
        #     # Call the adapters
        #     pack_input = {
        #         'x': x,
        #         'loc': 'post',
        #     }
        #     pack_input = self.forward_enabled_adapters(pack_input)
        #     x = pack_input['x']

        # if self.is_access_enabled(getattr(self, "model_guid", None)) and self.access_cfg.get(
        #     'save_encoder_tensors', False
        # ):
        #     self.register_accessible_tensor(name='encoder', tensor=x)
        if cache_last_channel is None:
            return x
        else:
            return x, cache_last_channel, cache_last_time


class ConformerConvolution(nn.Module):
    """The convolution module for the Conformer model.
    Args:
        d_model (int): hidden dimension
        kernel_size (int): kernel size for depthwise convolution
        pointwise_activation (str): name of the activation function to be used for the pointwise conv.
            Note that Conformer uses a special key `glu_` which is treated as the original default from
            the paper.
        use_bias (bool): Use bias in all Linear and Conv1d layers to improve activation flow and stabilize training of huge models.
            Defaults to True
    """

    def __init__(
        self,
        d_model,
        kernel_size,
        norm_type='batch_norm',
        conv_context_size=None,
        pointwise_activation='glu_',
        use_bias=True,
    ):
        super(ConformerConvolution, self).__init__()
        assert (kernel_size - 1) % 2 == 0
        self.d_model = d_model
        self.kernel_size = kernel_size
        self.norm_type = norm_type
        self.use_bias = use_bias

        if conv_context_size is None:
            conv_context_size = (kernel_size - 1) // 2

        # if pointwise_activation in activation_registry:
        #     self.pointwise_activation = activation_registry[pointwise_activation]()
        #     dw_conv_input_dim = d_model * 2

        #     if hasattr(self.pointwise_activation, 'inplace'):
        #         self.pointwise_activation.inplace = True
        # else:
        assert pointwise_activation == 'glu_'
        self.pointwise_activation = pointwise_activation
        dw_conv_input_dim = d_model

        self.pointwise_conv1 = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model * 2,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=self.use_bias,
        )

        self.depthwise_conv = CausalConv1D(
            in_channels=dw_conv_input_dim,
            out_channels=dw_conv_input_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=conv_context_size,
            groups=dw_conv_input_dim,
            bias=self.use_bias,
        )

        if norm_type == 'batch_norm':
            self.batch_norm = nn.BatchNorm1d(dw_conv_input_dim)
        elif norm_type == 'instance_norm':
            self.batch_norm = nn.InstanceNorm1d(dw_conv_input_dim)
        elif norm_type == 'layer_norm':
            self.batch_norm = nn.LayerNorm(dw_conv_input_dim)
        elif norm_type == 'fused_batch_norm':
            self.batch_norm = FusedBatchNorm1d(dw_conv_input_dim)
        elif norm_type.startswith('group_norm'):
            num_groups = int(norm_type.replace("group_norm", ""))
            self.batch_norm = nn.GroupNorm(num_groups=num_groups, num_channels=d_model)
        else:
            raise ValueError(f"conv_norm_type={norm_type} is not valid!")

        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(
            in_channels=dw_conv_input_dim,
            out_channels=d_model,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=self.use_bias,
        )

    def forward(self, x, pad_mask=None, cache=None):
        x = x.transpose(1, 2)
        x = self.pointwise_conv1(x)

        # Compute the activation function or use GLU for original Conformer
        if self.pointwise_activation == 'glu_':
            x = nn.functional.glu(x, dim=1)
        else:
            x = self.pointwise_activation(x)

        if pad_mask is not None:
            x = x.masked_fill(pad_mask.unsqueeze(1), 0.0)

        x = self.depthwise_conv(x, cache=cache)
        if cache is not None:
            x, cache = x

        if self.norm_type == "layer_norm":
            x = x.transpose(1, 2)
            x = self.batch_norm(x)
            x = x.transpose(1, 2)
        else:
            x = self.batch_norm(x)

        x = self.activation(x)
        x = self.pointwise_conv2(x)
        x = x.transpose(1, 2)
        if cache is None:
            return x
        else:
            return x, cache

    def reset_parameters_conv(self):
        pw1_max = pw2_max = self.d_model**-0.5
        dw_max = self.kernel_size**-0.5

        with torch.no_grad():
            nn.init.uniform_(self.pointwise_conv1.weight, -pw1_max, pw1_max)
            nn.init.uniform_(self.pointwise_conv2.weight, -pw2_max, pw2_max)
            nn.init.uniform_(self.depthwise_conv.weight, -dw_max, dw_max)
            if self.use_bias:
                nn.init.uniform_(self.pointwise_conv1.bias, -pw1_max, pw1_max)
                nn.init.uniform_(self.pointwise_conv2.bias, -pw2_max, pw2_max)
                nn.init.uniform_(self.depthwise_conv.bias, -dw_max, dw_max)


class ConformerFeedForward(nn.Module):
    """
    feed-forward module of Conformer model.
    use_bias (bool): Apply bias to all Linear and Conv1d layers to improve activation flow and stabilize training of huge models.
    """

    def __init__(self, d_model, d_ff, dropout, activation=nn.SiLU(), use_bias=True):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.use_bias = use_bias
        self.linear1 = nn.Linear(d_model, d_ff, bias=self.use_bias)
        self.activation = activation
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(d_ff, d_model, bias=self.use_bias)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

    def reset_parameters_ff(self):
        ffn1_max = self.d_model**-0.5
        ffn2_max = self.d_ff**-0.5
        with torch.no_grad():
            nn.init.uniform_(self.linear1.weight, -ffn1_max, ffn1_max)
            nn.init.uniform_(self.linear2.weight, -ffn2_max, ffn2_max)
            if self.use_bias:
                nn.init.uniform_(self.linear1.bias, -ffn1_max, ffn1_max)
                nn.init.uniform_(self.linear2.bias, -ffn2_max, ffn2_max)

class CausalConv1D(nn.Conv1d):
    """
    A causal version of nn.Conv1d where each step has limited access to locations on its right or left.
    All arguments are the same as nn.Conv1d except padding.

    If padding is set to None, paddings are chosen automatically to make the convolution causal, so that each location does not see any steps on its right.

    If padding is set as a list of size 2, then padding[0] is used as left padding and padding[1] as right padding.
    This makes it possible to control the number of steps accessible on the right and left.
    This mode is not supported when stride > 1, and padding[0] + padding[1] must equal (kernel_size - 1).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int | list[int] | None = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None,
    ) -> None:
        self.cache_drop_size = None
        if padding is None:
            self._left_padding = kernel_size - 1
            self._right_padding = stride - 1
        else:
            if stride != 1 and padding != kernel_size - 1:
                raise ValueError("No striding allowed for non-symmetric convolutions!")
            if isinstance(padding, int):
                self._left_padding = padding
                self._right_padding = padding
            elif isinstance(padding, list) and len(padding) == 2 and padding[0] + padding[1] == kernel_size - 1:
                self._left_padding = padding[0]
                self._right_padding = padding[1]
            else:
                raise ValueError(f"Invalid padding param: {padding}!")

        self._max_cache_len = self._left_padding

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def update_cache(self, x, cache=None):
        if cache is None:
            new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
            next_cache = cache
        else:
            new_x = F.pad(x, pad=(0, self._right_padding))
            new_x = torch.cat([cache, new_x], dim=-1)
            if self.cache_drop_size > 0:
                next_cache = new_x[:, :, : -self.cache_drop_size]
            else:
                next_cache = new_x
            next_cache = next_cache[:, :, -cache.size(-1) :]
        return new_x, next_cache

    def forward(self, x, cache=None):
        x, cache = self.update_cache(x, cache=cache)
        x = super().forward(x)
        if cache is None:
            return x
        else:
            return x, cache

```
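
Below is a brief illustrative sketch (not part of the repository) of how `CausalConv1D` keeps the time dimension intact with left-only padding; the import path is assumed from the file layout, and no streaming cache is used.

```py
import torch

from liquid_audio.model.conformer.modules import CausalConv1D  # assumed import path

# padding=None -> left pad of (kernel_size - 1) and right pad of (stride - 1),
# so each output step only depends on current and past inputs.
conv = CausalConv1D(in_channels=8, out_channels=8, kernel_size=3, stride=1, padding=None)

x = torch.randn(2, 8, 50)  # (batch, channels, time)
y = conv(x)                # cache=None -> a plain tensor is returned
print(y.shape)             # torch.Size([2, 8, 50]): causal padding preserves the length
```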

## /src/liquid_audio/model/conformer/processor.py

```py path="/src/liquid_audio/model/conformer/processor.py" 
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NeMo audio processing

Adapted from https://github.com/NVIDIA/NeMo/blob/09a962536c7a52f1964224e6e687ffb4a34fef79/nemo/collections/asr/modules/audio_preprocessing.py#L61
"""

import logging
import math
import random
from abc import ABC, abstractmethod

import librosa
import torch
from torch import nn


class AudioPreprocessor(nn.Module, ABC):
    """
    An interface for Neural Modules that perform audio pre-processing,
    transforming wav files to features.
    """

    def __init__(self, win_length, hop_length):
        super().__init__()

        self.win_length = win_length
        self.hop_length = hop_length

        self.torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'ones': torch.ones,
            None: torch.ones,
        }

        # Normally, when you call to(dtype) on a torch.nn.Module, all
        # floating point parameters and buffers will change to that
        # dtype, rather than being float32. The AudioPreprocessor
        # classes, uniquely, don't actually have any parameters or
        # buffers from what I see. In addition, we want the input to
        # the preprocessor to be float32, but need to create the
        # output in appropriate precision. We have this empty tensor
        # here just to detect which dtype tensor this module should
        # output at the end of execution.
        self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)

    @torch.no_grad()
    def forward(self, input_signal, length):
        if input_signal.dtype != torch.float32:
            logging.warning(
                f"AudioPreprocessor received an input signal of dtype {input_signal.dtype}, rather than torch.float32. In sweeps across multiple datasets, we have found that the preprocessor is not robust to low precision  mathematics. As such, it runs in float32. Your input will be cast to float32, but this is not necessarily enough to recovery full accuracy. For example, simply casting input_signal from torch.float32 to torch.bfloat16, then back to torch.float32 before running AudioPreprocessor causes drops in absolute WER of up to 0.1%. torch.bfloat16 simply does not have enough mantissa bits to represent enough values in the range [-1.0,+1.0] correctly.",
            )
        processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
        processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
        return processed_signal, processed_length

    @abstractmethod
    def get_features(self, input_signal, length):
        # Called by forward(). Subclasses should implement this.
        pass

class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
    """Featurizer module that converts wavs to mel spectrograms.

    Args:
        sample_rate (int): Sample rate of the input audio data.
            Defaults to 16000
        window_size (float): Size of window for fft in seconds
            Defaults to 0.02
        window_stride (float): Stride of window for fft in seconds
            Defaults to 0.01
        n_window_size (int): Size of window for fft in samples
            Defaults to None. Use one of window_size or n_window_size.
        n_window_stride (int): Stride of window for fft in samples
            Defaults to None. Use one of window_stride or n_window_stride.
        window (str): Windowing function for fft. Can be one of ['hann',
            'hamming', 'blackman', 'bartlett']
            Defaults to "hann"
        normalize (str): Can be one of ['per_feature', 'all_features']; all
            other options disable feature normalization. 'all_features'
            normalizes the entire spectrogram to be mean 0 with std 1.
            'per_feature' normalizes per channel / freq instead.
            Defaults to "per_feature"
        n_fft (int): Length of FFT window. If None, it uses the smallest power
            of 2 that is larger than n_window_size.
            Defaults to None
        preemph (float): Amount of pre emphasis to add to audio. Can be
            disabled by passing None.
            Defaults to 0.97
        features (int): Number of mel spectrogram freq bins to output.
            Defaults to 64
        lowfreq (int): Lower bound on mel basis in Hz.
            Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
            Defaults to None
        log (bool): Log features.
            Defaults to True
        log_zero_guard_type (str): Need to avoid taking the log of zero. There
            are two options: "add" or "clamp".
            Defaults to "add".
        log_zero_guard_value (float or str): Add or clamp requires the number
            to add with or clamp to. log_zero_guard_value can either be a float
            or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is
            passed.
            Defaults to 2**-24.
        dither (float): Amount of white-noise dithering.
            Defaults to 1e-5
        pad_to (int): Ensures that the output size of the time dimension is
            a multiple of pad_to.
            Defaults to 16
        frame_splicing (int): Defaults to 1
        exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
            // hop_length. Defaults to False.
        pad_value (float): The value that shorter mels are padded with.
            Defaults to 0
        mag_power (float): The power that the linear spectrogram is raised to
            prior to multiplication with mel basis.
            Defaults to 2 for a power spec
        rng : Random number generator
        nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to
            samples in the batch.
            Defaults to 0.0
        nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
            Defaults to 4000
        use_torchaudio: Whether to use the `torchaudio` implementation.
        mel_norm: Normalization used for mel filterbank weights.
            Defaults to 'slaney' (area normalization)
        stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
        stft_conv: Deprecated argument, kept for compatibility with older checkpoints.
    """

    def save_to(self, save_path: str):
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        pass

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return {
            "input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),
            "length": NeuralType(
                tuple('B'), LengthsType()
            ),  # Please note that length should be in samples not seconds.
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.

        processed_signal:
            0: AxisType(BatchTag)
            1: AxisType(MelSpectrogramSignalTag)
            2: AxisType(ProcessedTimeTag)
        processed_length:
            0: AxisType(BatchTag)
        """
        return {
            "processed_signal": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "processed_length": NeuralType(tuple('B'), LengthsType()),
        }

    def __init__(
        self,
        sample_rate=16000,
        window_size=0.02,
        window_stride=0.01,
        n_window_size=None,
        n_window_stride=None,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        features=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2**-24,
        dither=1e-5,
        pad_to=16,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        rng=None,
        nb_augmentation_prob=0.0,
        nb_max_freq=4000,
        use_torchaudio: bool = False,
        mel_norm="slaney",
        stft_exact_pad=False,  # Deprecated arguments; kept for config compatibility
        stft_conv=False,  # Deprecated arguments; kept for config compatibility
    ):
        self._sample_rate = sample_rate
        if window_size and n_window_size:
            raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
        if window_stride and n_window_stride:
            raise ValueError(
                f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
            )
        if window_size:
            n_window_size = int(window_size * self._sample_rate)
        if window_stride:
            n_window_stride = int(window_stride * self._sample_rate)
        super().__init__(n_window_size, n_window_stride)

        # Given the long and similar argument list, point to the class and instantiate it by reference
        if not use_torchaudio:
            featurizer_class = FilterbankFeatures
        else:
            featurizer_class = FilterbankFeaturesTA
        self.featurizer = featurizer_class(
            sample_rate=self._sample_rate,
            n_window_size=n_window_size,
            n_window_stride=n_window_stride,
            window=window,
            normalize=normalize,
            n_fft=n_fft,
            preemph=preemph,
            nfilt=features,
            lowfreq=lowfreq,
            highfreq=highfreq,
            log=log,
            log_zero_guard_type=log_zero_guard_type,
            log_zero_guard_value=log_zero_guard_value,
            dither=dither,
            pad_to=pad_to,
            frame_splicing=frame_splicing,
            exact_pad=exact_pad,
            pad_value=pad_value,
            mag_power=mag_power,
            rng=rng,
            nb_augmentation_prob=nb_augmentation_prob,
            nb_max_freq=nb_max_freq,
            mel_norm=mel_norm,
            stft_exact_pad=stft_exact_pad,  # Deprecated arguments; kept for config compatibility
            stft_conv=stft_conv,  # Deprecated arguments; kept for config compatibility
        )

    def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200):
        dev = self.filter_banks.device

        signals = torch.randn(size=[max_batch, max_dim], device=dev)
        lengths = torch.randint(low=min_length, high=max_dim, size=[max_batch], device=dev)
        lengths[0] = max_dim
        return signals, lengths

    def get_features(self, input_signal, length):
        return self.featurizer(input_signal, length)

    @property
    def filter_banks(self):
        return self.featurizer.filter_banks

CONSTANT = 1e-5
class FilterbankFeatures(nn.Module):
    """Featurizer that converts wavs to Mel Spectrograms.
    See AudioToMelSpectrogramPreprocessor for args.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_window_size=320,
        n_window_stride=160,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        nfilt=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2**-24,
        dither=CONSTANT,
        pad_to=16,
        max_duration=16.7,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        use_grads=False,
        rng=None,
        nb_augmentation_prob=0.0,
        nb_max_freq=4000,
        mel_norm="slaney",
        stft_exact_pad=False,  # Deprecated arguments; kept for config compatibility
        stft_conv=False,  # Deprecated arguments; kept for config compatibility
    ):
        super().__init__()
        if stft_conv or stft_exact_pad:
            logging.warning(
                "Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
                "for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
                "as needed."
            )
        if exact_pad and n_window_stride % 2 == 1:
            raise NotImplementedError(
                f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
                "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
            )
        self.log_zero_guard_value = log_zero_guard_value
        if (
            n_window_size is None
            or n_window_stride is None
            or not isinstance(n_window_size, int)
            or not isinstance(n_window_stride, int)
            or n_window_size <= 0
            or n_window_stride <= 0
        ):
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                f"n_window_stride. Both must be positive ints."
            )
        logging.info(f"PADDING: {pad_to}")

        self.sample_rate = sample_rate
        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
        self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
        self.exact_pad = exact_pad
        self.sample_rate = sample_rate

        if exact_pad:
            logging.info("STFT using exact pad")
        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
        self.register_buffer("window", window_tensor)

        self.normalize = normalize
        self.log = log
        self.dither = dither
        self.frame_splicing = frame_splicing
        self.nfilt = nfilt
        self.preemph = preemph
        self.pad_to = pad_to
        highfreq = highfreq or sample_rate / 2

        filterbanks = torch.tensor(
            librosa.filters.mel(
                sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm
            ),
            dtype=torch.float,
        ).unsqueeze(0)
        self.register_buffer("fb", filterbanks)

        # Calculate maximum sequence length
        max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
        max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power

        # We want to avoid taking the log of zero
        # There are two options: either adding or clamping to a small value
        if log_zero_guard_type not in ["add", "clamp"]:
            raise ValueError(
                f"{self} received {log_zero_guard_type} for the "
                f"log_zero_guard_type parameter. It must be either 'add' or "
                f"'clamp'."
            )

        self.use_grads = use_grads
        if not use_grads:
            self.forward = torch.no_grad()(self.forward)
        self._rng = random.Random() if rng is None else rng
        self.nb_augmentation_prob = nb_augmentation_prob
        if self.nb_augmentation_prob > 0.0:
            if nb_max_freq >= sample_rate / 2:
                self.nb_augmentation_prob = 0.0
            else:
                self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)

        # log_zero_guard_value is the small value we want to use; we support
        # an actual number, or "tiny", or "eps"
        self.log_zero_guard_type = log_zero_guard_type
        logging.debug(f"sr: {sample_rate}")
        logging.debug(f"n_fft: {self.n_fft}")
        logging.debug(f"win_length: {self.win_length}")
        logging.debug(f"hop_length: {self.hop_length}")
        logging.debug(f"n_mels: {nfilt}")
        logging.debug(f"fmin: {lowfreq}")
        logging.debug(f"fmax: {highfreq}")
        logging.debug(f"using grads: {use_grads}")
        logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")

    def stft(self, x):
        return torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=not self.exact_pad,
            window=self.window.to(dtype=torch.float, device=x.device),
            return_complex=True,
            pad_mode="constant",
        )

    def log_zero_guard_value_fn(self, x):
        if isinstance(self.log_zero_guard_value, str):
            if self.log_zero_guard_value == "tiny":
                return torch.finfo(x.dtype).tiny
            elif self.log_zero_guard_value == "eps":
                return torch.finfo(x.dtype).eps
            else:
                raise ValueError(
                    f"{self} received {self.log_zero_guard_value} for the "
                    f"log_zero_guard_type parameter. It must be either a "
                    f"number, 'tiny', or 'eps'"
                )
        else:
            return self.log_zero_guard_value

    def get_seq_len(self, seq_len):
        # Assuming center=True when stft_pad_amount is None, i.e. the pad amount is n_fft // 2 on each side
        pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
        seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length)
        return seq_len.to(dtype=torch.long)

    @property
    def filter_banks(self):
        return self.fb

    def forward(self, x, seq_len, linear_spec=False):
        seq_len_time = seq_len
        seq_len_unfixed = self.get_seq_len(seq_len)
        # fix for seq_len = 0 for streaming; if size was 0, it is always padded to 1, and normalizer fails
        seq_len = torch.where(seq_len == 0, torch.zeros_like(seq_len_unfixed), seq_len_unfixed)

        if self.stft_pad_amount is not None:
            x = torch.nn.functional.pad(
                x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "constant"
            ).squeeze(1)

        # dither (only in training mode for eval determinism)
        if self.training and self.dither > 0:
            x += self.dither * torch.randn_like(x)

        # do preemphasis
        if self.preemph is not None:
            timemask = torch.arange(x.shape[1], device=x.device).unsqueeze(0) < seq_len_time.unsqueeze(1)
            x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
            x = x.masked_fill(~timemask, 0.0)

        # disable autocast to get full range of stft values
        with torch.amp.autocast(x.device.type, enabled=False):
            x = self.stft(x)

        # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
        # guard is needed for sqrt if grads are passed through
        guard = 0 if not self.use_grads else CONSTANT
        x = torch.view_as_real(x)
        x = torch.sqrt(x.pow(2).sum(-1) + guard)

        if self.training and self.nb_augmentation_prob > 0.0:
            for idx in range(x.shape[0]):
                if self._rng.random() < self.nb_augmentation_prob:
                    x[idx, self._nb_max_fft_bin :, :] = 0.0

        # get power spectrum
        if self.mag_power != 1.0:
            x = x.pow(self.mag_power)

        # return plain spectrogram if required
        if linear_spec:
            return x, seq_len

        # disable autocast, otherwise it might be automatically cast to fp16
        # on fp16 compatible GPUs and get NaN values for input value of 65520
        with torch.amp.autocast(x.device.type, enabled=False):
            # dot with filterbank energies
            x = torch.matmul(self.fb.to(x.dtype), x)
        # log features if required
        if self.log:
            if self.log_zero_guard_type == "add":
                x = torch.log(x + self.log_zero_guard_value_fn(x))
            elif self.log_zero_guard_type == "clamp":
                x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
            else:
                raise ValueError("log_zero_guard_type was not understood")

        # frame splicing if required
        if self.frame_splicing > 1:
            x = splice_frames(x, self.frame_splicing)

        # normalize if required
        if self.normalize:
            x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)

        # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
        max_len = x.size(-1)
        mask = torch.arange(max_len, device=x.device)
        mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
        x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
        del mask
        pad_to = self.pad_to
        if pad_to == "max":
            x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
        elif pad_to > 0:
            pad_amt = x.size(-1) % pad_to
            if pad_amt != 0:
                x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
        return x, seq_len

def normalize_batch(x, seq_len, normalize_type):
    x_mean = None
    x_std = None
    if normalize_type == "per_feature":
        batch_size = x.shape[0]
        max_time = x.shape[2]

        # When doing stream capture to a graph, item() is not allowed
        # because it calls cudaStreamSynchronize(). Therefore, we are
        # sacrificing some error checking when running with cuda graphs.
        if (
            torch.cuda.is_available()
            and not torch.cuda.is_current_stream_capturing()
            and torch.any(seq_len == 1).item()
        ):
            raise ValueError(
                "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
                "in torch.std() returning nan. Make sure your audio length has enough samples for a single "
                "feature (ex. at least `hop_length` for Mel Spectrograms)."
            )
        time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
        valid_mask = time_steps < seq_len.unsqueeze(1)
        x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2)
        x_mean_denominator = valid_mask.sum(axis=1)
        x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1)

        # Subtract 1 in the denominator to correct for the bias.
        x_std = torch.sqrt(
            torch.sum(torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2, axis=2)
            / (x_mean_denominator.unsqueeze(1) - 1.0)
        )
        x_std = x_std.masked_fill(x_std.isnan(), 0.0)  # edge case: only 1 frame in denominator
        # make sure x_std is not zero
        x_std += CONSTANT
        return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
    elif normalize_type == "all_features":
        x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
        x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
        for i in range(x.shape[0]):
            x_mean[i] = x[i, :, : seq_len[i].item()].mean()
            x_std[i] = x[i, :, : seq_len[i].item()].std()
        # make sure x_std is not zero
        x_std += CONSTANT
        return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
    elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
        x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
        x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
        return (
            (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
            x_mean,
            x_std,
        )
    else:
        return x, x_mean, x_std

```
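
As a rough usage sketch (illustrative, not from the repository): the preprocessor takes float32 waveforms plus per-item lengths in samples and returns log-mel features with valid frame counts. The import path is assumed from the file layout.

```py
import torch

from liquid_audio.model.conformer.processor import AudioToMelSpectrogramPreprocessor  # assumed import path

pre = AudioToMelSpectrogramPreprocessor(sample_rate=16000, features=80)
pre.eval()  # dithering is only applied in training mode

wavs = torch.randn(2, 16000)         # one second of audio per item, float32
lens = torch.tensor([16000, 12000])  # lengths in samples, not seconds

feats, feat_lens = pre(wavs, lens)
print(feats.shape)  # (batch, n_mels, frames), frames padded to a multiple of pad_to=16
print(feat_lens)    # valid frame count per item, e.g. tensor([100, 75]) for the default 10 ms hop
```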

## /src/liquid_audio/model/conformer/subsampling.py

```py path="/src/liquid_audio/model/conformer/subsampling.py" 
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NeMo subsampling

adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/collections/asr/parts/submodules/subsampling.py
"""

import logging
import math

import torch
from torch import nn

logger = logging.getLogger(__name__)

class ConvSubsampling(torch.nn.Module):
    """Convolutional subsampling which supports VGGNet and striding approach introduced in:
    VGGNet Subsampling: Transformer-transducer: end-to-end speech recognition with self-attention (https://arxiv.org/pdf/1910.12977.pdf)
    Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506)
    Args:
        subsampling (str): The subsampling technique from {"vggnet", "striding", "dw-striding"}
        subsampling_factor (int): The subsampling factor which should be a power of 2
        subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking),
        1 (auto), or a power of 2. Defaults to 1
        feat_in (int): size of the input features
        feat_out (int): size of the output features
        conv_channels (int): Number of channels for the convolution layers.
        activation (Module): activation function, default is nn.ReLU()
    """

    def __init__(
        self,
        subsampling,
        subsampling_factor,
        feat_in,
        feat_out,
        conv_channels,
        subsampling_conv_chunking_factor=1,
        activation=nn.ReLU(),
        is_causal=False,
    ):
        super().__init__()
        self._subsampling = subsampling
        self._conv_channels = conv_channels
        self._feat_in = feat_in
        self._feat_out = feat_out

        if subsampling_factor % 2 != 0:
            raise ValueError("Sampling factor should be a multiply of 2!")
        self._sampling_num = int(math.log2(subsampling_factor))
        self.subsampling_factor = subsampling_factor
        self.is_causal = is_causal

        if (
            subsampling_conv_chunking_factor != -1
            and subsampling_conv_chunking_factor != 1
            and subsampling_conv_chunking_factor % 2 != 0
        ):
            raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2")
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor

        in_channels = 1
        layers = []

        if subsampling == 'vggnet':
            self._stride = 2
            self._kernel_size = 2
            self._ceil_mode = True

            self._left_padding = 0
            self._right_padding = 0

            for i in range(self._sampling_num):
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1
                    )
                )
                layers.append(activation)
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=conv_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1
                    )
                )
                layers.append(activation)
                layers.append(
                    torch.nn.MaxPool2d(
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
                        ceil_mode=self._ceil_mode,
                    )
                )
                in_channels = conv_channels

        elif subsampling == 'dw_striding':
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False

            if self.is_causal:
                self._left_padding = self._kernel_size - 1
                self._right_padding = self._stride - 1
                self._max_cache_len = subsampling_factor + 1
            else:
                self._left_padding = (self._kernel_size - 1) // 2
                self._right_padding = (self._kernel_size - 1) // 2
                self._max_cache_len = 0

            # Layer 1
            if self.is_causal:
                layers.append(
                    CausalConv2D(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=None,
                    )
                )
            else:
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
                    )
                )
            in_channels = conv_channels
            layers.append(activation)

            for i in range(self._sampling_num - 1):
                if self.is_causal:
                    layers.append(
                        CausalConv2D(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                            groups=in_channels,
                        )
                    )
                else:
                    layers.append(
                        torch.nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
                        )
                    )

                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    )
                )
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == 'striding':
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False

            if self.is_causal:
                self._left_padding = self._kernel_size - 1
                self._right_padding = self._stride - 1
                self._max_cache_len = subsampling_factor + 1
            else:
                self._left_padding = (self._kernel_size - 1) // 2
                self._right_padding = (self._kernel_size - 1) // 2
                self._max_cache_len = 0

            for i in range(self._sampling_num):
                if self.is_causal:
                    layers.append(
                        CausalConv2D(
                            in_channels=in_channels,
                            out_channels=conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                        )
                    )
                else:
                    layers.append(
                        torch.nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                        )
                    )
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == 'striding_conv1d':

            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False

            if self.is_causal:
                self._left_padding = self._kernel_size - 1
                self._right_padding = self._stride - 1
                self._max_cache_len = subsampling_factor + 1
            else:
                self._left_padding = (self._kernel_size - 1) // 2
                self._right_padding = (self._kernel_size - 1) // 2
                self._max_cache_len = 0

            for i in range(self._sampling_num):
                if self.is_causal:
                    layers.append(
                        CausalConv1D(
                            in_channels=in_channels,
                            out_channels=feat_out if self._sampling_num == i + 1 else conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                        )
                    )
                else:
                    layers.append(
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=feat_out if self._sampling_num == i + 1 else conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                        )
                    )
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == 'dw_striding_conv1d':

            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False

            self._left_padding = (self._kernel_size - 1) // 2
            self._right_padding = (self._kernel_size - 1) // 2

            # Layer 1
            layers.extend(
                [
                    torch.nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
                        groups=in_channels,
                    ),
                    torch.nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=feat_out if self._sampling_num == 1 else conv_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    ),
                ]
            )
            in_channels = conv_channels
            layers.append(activation)

            for i in range(self._sampling_num - 1):
                layers.extend(
                    [
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
                        ),
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=feat_out if self._sampling_num == i + 2 else conv_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            groups=1,
                        ),
                    ]
                )
                layers.append(activation)
                in_channels = conv_channels

        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        if subsampling in ["vggnet", "dw_striding", "striding"]:

            in_length = torch.tensor(feat_in, dtype=torch.float)
            out_length = calc_length(
                lengths=in_length,
                all_paddings=self._left_padding + self._right_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                ceil_mode=self._ceil_mode,
                repeat_num=self._sampling_num,
            )
            self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
            self.conv2d_subsampling = True
        elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
            self.out = None
            self.conv2d_subsampling = False
        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        self.conv = MaskedConvSequential(*layers)

    def get_sampling_frames(self):
        return [1, self.subsampling_factor]

    def get_streaming_cache_size(self):
        return [0, self.subsampling_factor + 1]

    def forward(self, x, lengths):
        out_lengths = calc_length(
            lengths,
            all_paddings=self._left_padding + self._right_padding,
            kernel_size=self._kernel_size,
            stride=self._stride,
            ceil_mode=self._ceil_mode,
            repeat_num=self._sampling_num,
        )

        # Transpose to Channel First mode
        if not self.conv2d_subsampling:
            x = x.transpose(1, 2)

        # split inputs if chunking_factor is set
        if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling:
            if self.subsampling_conv_chunking_factor == 1:
                # if subsampling_conv_chunking_factor is 1, we split only if needed
                # avoiding a bug / feature limiting indexing of tensors to 2**31
                # see https://github.com/pytorch/pytorch/issues/80020
                # Fix NeMo bug: compute exactly the largest output size
                x_ceil = 2 ** 31
                out_size = x.shape[0] * self._conv_channels * ((x.shape[1] + 1) // self._stride) * ((x.shape[2] + 1) // self._stride)
                need_to_split = out_size >= x_ceil
            else:
                # if subsampling_conv_chunking_factor > 1 we always split
                need_to_split = True

            if need_to_split:
                x, lengths, success = self.conv_split_by_batch(x, lengths)
                if not success:  # if unable to split by batch, try by channel
                    if self._subsampling == 'dw_striding':
                        # TODO: implement lengths inside conv_split_by_channel
                        x = self.conv_split_by_channel(x)
                        lengths = out_lengths
                    else:
                        x, lengths = self.conv(x, lengths)  # try anyway
            else:
                x, lengths = self.conv(x, lengths)
        else:
            x, lengths = self.conv(x, lengths)

        # Flatten Channel and Frequency Axes
        if self.conv2d_subsampling:
            b, c, t, f = x.size()
            x = self.out(x.transpose(1, 2).reshape(b, t, c*f))
        # Transpose to Channel Last mode
        else:
            x = x.transpose(1, 2)

        return x, lengths

    def reset_parameters(self):
        # initialize weights
        if self._subsampling == 'dw_striding':
            with torch.no_grad():
                # init conv
                scale = 1.0 / self._kernel_size
                dw_max = (self._kernel_size**2) ** -0.5
                pw_max = self._conv_channels**-0.5

                torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
                torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)

                for idx in range(2, len(self.conv), 3):
                    torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max)

                # init fc (80 * 64 = 5120 from https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487)
                fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5
                torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
                torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)

    def conv_split_by_batch(self, x, lengths):
        """Tries to split input by batch, run conv and concat results"""
        b, *_ = x.size()
        if b == 1:  # can't split if batch size is 1
            return x, lengths, False

        if self.subsampling_conv_chunking_factor > 1:
            cf = self.subsampling_conv_chunking_factor
            logger.debug(f'using manually set chunking factor: {cf}')
        else:
            # avoiding a bug / feature limiting indexing of tensors to 2**31
            # see https://github.com/pytorch/pytorch/issues/80020
            x_ceil = 2 ** 31
            out_size = x.shape[0] * self._conv_channels * ((x.shape[1] + 1) // self._stride) * ((x.shape[2] + 1) // self._stride)
            p = math.ceil(math.log2((out_size+1) / x_ceil))
            cf = 2**p
            logger.debug(f'using auto set chunking factor: {cf}')

        new_batch_size = b // cf
        if new_batch_size == 0:  # input is too big
            return x, lengths, False

        logger.debug(f'conv subsampling: using split batch size {new_batch_size}')

        ans = [
            self.conv(chunk, ln)
            for chunk, ln in zip(
                torch.split(x, new_batch_size, 0),
                torch.split(lengths, new_batch_size, 0),
            )
        ]
        return torch.cat([a[0] for a in ans]), torch.cat([a[1] for a in ans]), True

    def conv_split_by_channel(self, x):
        """For dw convs, tries to split input by time, run conv and concat results"""

        # Note: this method doesn't use the convolution masking implemented in MaskedConvSequential
        x = x.unsqueeze(0)
        x = self.conv[0](x)  # full conv2D
        x = self.conv[1](x)  # activation

        for i in range(self._sampling_num - 1):
            _, c, t, _ = x.size()

            if self.subsampling_conv_chunking_factor > 1:
                cf = self.subsampling_conv_chunking_factor
                logger.debug(f'using manually set chunking factor: {cf}')
            else:
                # avoiding a bug / feature limiting indexing of tensors to 2**31
                # see https://github.com/pytorch/pytorch/issues/80020
                p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
                cf = 2**p
                logger.debug(f'using auto set chunking factor: {cf}')

            new_c = int(c // cf)
            if new_c == 0:
                logger.warning(f'chunking factor {cf} is too high; splitting down to one channel.')
                new_c = 1

            new_t = int(t // cf)
            if new_t == 0:
                logger.warning(f'chunking factor {cf} is too high; splitting down to one timestep.')
                new_t = 1

            logger.debug(f'conv dw subsampling: using split C size {new_c} and split T size {new_t}')
            x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x)  # conv2D, depthwise

            # splitting pointwise convs by time
            x = torch.cat([self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2)  # conv2D, pointwise
            x = self.conv[i * 3 + 4](x)  # activation
        return x

    def channel_chunked_conv(self, conv, chunk_size, x):
        """Performs channel chunked convolution"""

        ind = 0
        out_chunks = []
        for chunk in torch.split(x, chunk_size, 1):
            step = chunk.size()[1]

            if self.is_causal:
                chunk = nn.functional.pad(
                    chunk, pad=(self._kernel_size - 1, self._stride - 1, self._kernel_size - 1, self._stride - 1)
                )
                ch_out = nn.functional.conv2d(
                    chunk,
                    conv.weight[ind : ind + step, :, :, :],
                    bias=conv.bias[ind : ind + step],
                    stride=self._stride,
                    padding=0,
                    groups=step,
                )
            else:
                ch_out = nn.functional.conv2d(
                    chunk,
                    conv.weight[ind : ind + step, :, :, :],
                    bias=conv.bias[ind : ind + step],
                    stride=self._stride,
                    padding=self._left_padding,
                    groups=step,
                )
            out_chunks.append(ch_out)
            ind += step

        return torch.cat(out_chunks, 1)

    def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
        if (
            subsampling_conv_chunking_factor != -1
            and subsampling_conv_chunking_factor != 1
            and subsampling_conv_chunking_factor % 2 != 0
        ):
            raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2")
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor


def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1):
    """Calculates the output length of a Tensor passed through a convolution or max pooling layer"""
    add_pad: float = all_paddings - kernel_size
    one: float = 1.0
    for i in range(repeat_num):
        lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
        if ceil_mode:
            lengths = torch.ceil(lengths)
        else:
            lengths = torch.floor(lengths)
    return lengths.to(dtype=torch.int)


class MaskedConvSequential(nn.Sequential):
    def forward(self, x, lengths):
        # Convert input (batch, time, features) to conv format
        x = x.unsqueeze(1)  # (batch, 1, time, features)
        current_lengths = lengths.clone().float()
        mask = self._create_mask(x, current_lengths.long())

        # Process through each layer with mask propagation
        for i, layer in enumerate(self):
            # Apply current mask before layer
            x = apply_channel_mask(x, mask)

            # Apply layer
            x = layer(x)

            # Update lengths for stride operations with proper padding
            if hasattr(layer, 'stride') and layer.stride != (1, 1):
                if hasattr(layer, "_left_padding"):
                    padding = (layer._left_padding, layer._right_padding)  # CausalConv2D
                else:
                    padding = layer.padding
                current_lengths = calculate_conv_output_size(
                    current_lengths, layer.kernel_size[0], layer.stride[0], padding
                )
                mask = self._create_mask(x, current_lengths.long())

        # Final masking
        x = apply_channel_mask(x, mask)
        return x, current_lengths.long()

    def _create_mask(self, tensor, lengths):
        """Create mask matching tensor dimensions."""
        batch_size, channels, time, features = tensor.shape
        time_mask = torch.arange(time, device=tensor.device).expand(batch_size, time) < lengths.unsqueeze(1)
        return time_mask.unsqueeze(-1).expand(batch_size, time, features).to(tensor.dtype)

def apply_channel_mask(tensor, mask):
    """Apply mask to tensor with channel dimension."""
    # tensor: (batch, channels, time, features)
    # mask: (batch, time, features)
    batch_size, channels, time, features = tensor.shape
    expanded_mask = mask.unsqueeze(1).expand(batch_size, channels, time, features)
    return tensor * expanded_mask


def calculate_conv_output_size(input_size: torch.Tensor, kernel_size: int, stride: int, padding: tuple[int, int]):
    """Calculate exact output size after convolution."""
    return (input_size + padding[0] + padding[1] - kernel_size) // stride + 1

```
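
A small worked example (illustrative only) of how `calc_length` tracks feature lengths through repeated stride-2 subsampling stages; the import path is assumed from the file layout.

```py
import torch

from liquid_audio.model.conformer.subsampling import calc_length  # assumed import path

lengths = torch.tensor([100.0, 64.0])  # input frame counts per batch item

# Two stride-2 stages (subsampling_factor = 4) with kernel_size=3 and
# symmetric padding of (3 - 1) // 2 = 1 per side, i.e. all_paddings = 2.
out = calc_length(
    lengths,
    all_paddings=2,
    kernel_size=3,
    stride=2,
    ceil_mode=False,
    repeat_num=2,
)
print(out)  # tensor([25, 16], dtype=torch.int32) -- each stage roughly halves the length
```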

## /src/liquid_audio/model/conformer/utils.py

```py path="/src/liquid_audio/model/conformer/utils.py" 
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NeMo casting utils

adapted from https://github.com/NVIDIA/NeMo/blob/c83adff36efaa549f7bdd26e97c01a60e9f9026b/nemo/utils/cast_utils.py
"""

from contextlib import nullcontext
from dataclasses import dataclass

import torch

def avoid_float16_autocast_context():
    """
    If the current autocast context is float16, cast it to bfloat16
    if available (unless we're in jit) or float32
    """

    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return torch.amp.autocast('cuda', dtype=torch.float32)

        if torch.cuda.is_bf16_supported():
            return torch.amp.autocast('cuda', dtype=torch.bfloat16)
        else:
            return torch.amp.autocast('cuda', dtype=torch.float32)
    else:
        return nullcontext()

@dataclass
class CacheAwareStreamingConfig:
    chunk_size: int = (
        0  # the size of each chunk at each step, it can be a list of two integers to specify different chunk sizes for the first step and others
    )
    shift_size: int = (
        0  # the size of the shift in each step, it can be a list of two integers to specify different shift sizes for the first step and others
    )

    cache_drop_size: int = 0  # the number of steps to drop from the cache
    last_channel_cache_size: int = 0  # the size of the needed cache for last channel layers

    valid_out_len: int = (
        0  # the number of the steps in the final output which are valid (have the same value as in the offline mode)
    )

    pre_encode_cache_size: int = (
        0  # the size of the needed cache for the pre-encoding part of the model to avoid caching inside the pre-encoding layers
    )
    drop_extra_pre_encoded: int = 0  # the number of steps to get dropped after the pre-encoding layer

    last_channel_num: int = 0  # number of the last channel layers (like MHA layers) which need caching in the model
    last_time_num: int = 0  # number of the last time layers (like convolutions) which need caching in the model

def compute_stochastic_depth_drop_probs(
    num_layers: int,
    stochastic_depth_drop_prob: float = 0.0,
    stochastic_depth_mode: str = "linear",
    stochastic_depth_start_layer: int = 1,
) -> list[float]:
    """Computes drop probabilities for stochastic depth regularization technique.
    The first layer is never dropped and the starting layer needs to be greater
    than or equal to 1.

    Args:
        num_layers (int): number of layers in the network.
        stochastic_depth_drop_prob (float): if non-zero, will randomly drop
            layers during training. The higher this value, the more often layers
            are dropped. Defaults to 0.0.
        stochastic_depth_mode (str): can be either "linear" or "uniform". If
            set to "uniform", all layers have the same probability of drop. If
            set to "linear", the drop probability grows linearly from 0 for the
            first layer to the desired value for the final layer. Defaults to
            "linear".
        stochastic_depth_start_layer (int): starting layer for stochastic depth.
            All layers before this will never be dropped. Note that drop
            probability will be adjusted accordingly if mode is "linear" when
            start layer is > 1. Defaults to 1.
    Returns:
        List[float]: list of drop probabilities for all layers
    """
    if not (0 <= stochastic_depth_drop_prob < 1.0):
        raise ValueError("stochastic_depth_drop_prob has to be in [0, 1).")
    if not (1 <= stochastic_depth_start_layer <= num_layers):
        raise ValueError("stochastic_depth_start_layer has to be in [1, num layers].")

    # Layers before `stochastic_depth_start_layer` are never dropped
    layer_drop_probs = [0.0] * stochastic_depth_start_layer

    # Layers starting with `stochastic_depth_start_layer` may be dropped
    if (L := num_layers - stochastic_depth_start_layer) > 0:
        if stochastic_depth_mode == "linear":
            # we start with 1/L * drop_prob and end with the desired drop probability.
            layer_drop_probs += [l / L * stochastic_depth_drop_prob for l in range(1, L + 1)]
        elif stochastic_depth_mode == "uniform":
            layer_drop_probs += [stochastic_depth_drop_prob] * L
        else:
            raise ValueError(
                f'stochastic_depth_mode has to be one of ["linear", "uniform"]. Current value: {stochastic_depth_mode}'
            )
    return layer_drop_probs
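
# Worked example (illustrative): num_layers=4, stochastic_depth_drop_prob=0.2, start_layer=1
#   mode="linear"  -> [0.0, 0.2/3, 0.4/3, 0.2] ~ [0.0, 0.067, 0.133, 0.2]
#   mode="uniform" -> [0.0, 0.2, 0.2, 0.2]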
```

## /src/liquid_audio/model/lfm2_audio.py

```py path="/src/liquid_audio/model/lfm2_audio.py" 
from __future__ import annotations

import json
import math
from collections.abc import Generator
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import ClassVar, Literal, Self, TypedDict

import torch
from accelerate import init_on_device, load_checkpoint_in_model
from einops import rearrange
from torch import nn
from transformers import Lfm2Config, Lfm2Model
from transformers.models.lfm2.modeling_lfm2 import Lfm2HybridConvCache

from liquid_audio.model.conformer.encoder import ConformerEncoder, ConformerEncoderConfig
from liquid_audio.model.mlp import MLP
from liquid_audio.model.transformer import MHA, RawLMBackbone, SharedEmbedding, StandardBlock
from liquid_audio.processor import PreprocessorConfig
from liquid_audio.utils import LFMModality, get_model_dir, mel2emb_len, module_exists


class LFM2_HFConfig(TypedDict):
    pretrained_model_name_or_path: str
    revision: str


@dataclass(kw_only=True)
class LFM2AudioConfig:
    architectures: list[str]  # for huggingface compatibility

    codebooks: int
    tie_audio_embeddings: bool

    semantic_codebook_factor: float
    codebook_weight: Literal["log", "linear"]

    interleaved_n_text: int
    interleaved_n_audio: int

    preprocessor: PreprocessorConfig
    encoder: ConformerEncoderConfig
    lfm: Lfm2Config
    depthformer: DepthformerConfig


@dataclass(kw_only=True)
class DepthformerConfig:
    layers: int
    dim: int
    tie: bool


class LFM2AudioModel(nn.Module):
    audio_vocab_size: ClassVar[int] = 2048 + 1  # Includes +1 for EOAudio

    def __init__(
        self,
        conf: LFM2AudioConfig,
    ):
        super().__init__()

        self.conf = conf
        self.codebooks = conf.codebooks

        ## LFM2 ##
        self.lfm = Lfm2Model(conf.lfm)

        ## Audio encoder ##
        self.conformer = ConformerEncoder(**asdict(conf.encoder))
        self.audio_adapter = MLP(self.conformer._feat_out, self.lfm.config.hidden_size, [self.lfm.config.hidden_size])

        ## Depthformer ##
        self.depthformer_layers = conf.depthformer.layers
        self.depthformer_dim = conf.depthformer.dim
        self.depthformer_tie = conf.depthformer.tie
        self.audio_embedding = SharedEmbedding(
            dim=self.lfm.config.hidden_size,
            vocab_size=self.audio_vocab_size * self.conf.codebooks,
            embed_init_scale=1.0,
            norm_eps=0.00001,
            tie_embedding=conf.tie_audio_embeddings,
        )

        self.codebook_offsets: torch.Tensor
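        # Each codebook owns a disjoint id range of size audio_vocab_size inside the flat
        # audio embedding table; adding these offsets maps per-codebook token ids into it.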
        self.register_buffer("codebook_offsets", torch.arange(self.conf.codebooks) * self.audio_vocab_size)

        self.audio_loss_weights: torch.Tensor
        if conf.codebook_weight == "log":
            weights = (torch.linspace(1, 0, self.codebooks) * math.log(conf.semantic_codebook_factor)).exp()
        else:
            weights = torch.ones((self.codebooks,))
            weights[0] *= conf.semantic_codebook_factor
        self.register_buffer(
            "audio_loss_weights",
            weights,
        )

        scale = 1 / math.sqrt(2 * self.depthformer_layers)

        layers = [
            StandardBlock(MHA(self.depthformer_dim, out_init_scale=scale), out_init_scale=scale)
            for _ in range(self.depthformer_layers)
        ]
        self.depthformer = RawLMBackbone(layers, has_embedding=False)

        self.depth_linear = nn.Linear(self.lfm.config.hidden_size, self.depthformer_dim * self.codebooks)
        self.depth_embeddings = nn.ModuleList(
            [
                SharedEmbedding(
                    dim=self.depthformer_dim,
                    vocab_size=self.audio_vocab_size,
                    tie_embedding=self.depthformer_tie,
                )
                for _ in range(self.codebooks)
            ]
        )

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str | Path,
        *,
        revision: str | None = None,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | str = "cuda",
    ) -> Self:
        cache_path = get_model_dir(repo_id, revision=revision)

        with (cache_path / "config.json").open() as f:
            config = json.load(f)

        conf = LFM2AudioConfig(
            lfm=Lfm2Config(**config.pop("lfm")),
            encoder=ConformerEncoderConfig(**config.pop("encoder")),
            depthformer=DepthformerConfig(**config.pop("depthformer")),
            **config,
        )

        if isinstance(device, str):
            device = torch.device(device)

        with init_on_device(device, include_buffers=True):
            model = cls(conf).to(device=device, dtype=dtype)

        if module_exists("flash_attn"):
            model.lfm.set_attn_implementation("flash_attention_2")
        else:
            model.lfm.set_attn_implementation("sdpa")

        load_checkpoint_in_model(model, cache_path)

        return model

    @torch.no_grad()
    def generate_sequential(
        self,
        *,
        text: torch.Tensor,
        audio_in: torch.Tensor,
        audio_in_lens: torch.Tensor,
        audio_out: torch.Tensor,
        modality_flag: torch.Tensor,
        max_new_tokens: int = 20,
        text_temperature: float | None = None,
        text_top_k: int | None = None,
        audio_temperature: float | None = None,
        audio_top_k: int | None = None,
    ) -> Generator[torch.Tensor, None, None]:
        in_emb = self._prefill(
            text=text,
            audio_in=audio_in,
            audio_in_lens=audio_in_lens,
            audio_out=audio_out,
            modality_flag=modality_flag,
        )

        current_modality: LFMModality = LFMModality.TEXT
        cache: Lfm2HybridConvCache | None = None

        for _ in range(max_new_tokens):
            lfm_out = self.lfm(
                inputs_embeds=in_emb,
                past_key_values=cache,
                use_cache=True,
            )
            output_embeddings = lfm_out.last_hidden_state
            cache = lfm_out.past_key_values

            if current_modality == LFMModality.TEXT:
                text_logits = nn.functional.linear(output_embeddings[0, -1], self.lfm.embed_tokens.weight)
                next_token = self._sample_text_token(text_logits, temperature=text_temperature, top_k=text_top_k)
                yield next_token

                if next_token == 128:  # <|audio_start|>
                    current_modality = LFMModality.AUDIO_OUT
                if next_token == 7:  # <|im_end|>
                    break

                in_emb = self.lfm.embed_tokens(next_token)[None, :]

            elif current_modality == LFMModality.AUDIO_OUT:
                next_token = self._sample_audio_frame(
                    output_embeddings[0, -1],
                    temperature=audio_temperature,
                    top_k=audio_top_k,
                )

                if next_token[0] == 2048:
                    next_token[:] = 2048
                    current_modality = LFMModality.TEXT

                yield next_token
                in_emb = self.audio_embedding(next_token + self.codebook_offsets).sum(0)[None, None, :]

    @torch.no_grad()
    def generate_interleaved(
        self,
        *,
        text: torch.Tensor,
        audio_in: torch.Tensor,
        audio_in_lens: torch.Tensor,
        audio_out: torch.Tensor,
        modality_flag: torch.Tensor,
        max_new_tokens: int = 20,
        text_temperature: float | None = None,
        text_top_k: int | None = None,
        audio_temperature: float | None = None,
        audio_top_k: int | None = None,
    ) -> Generator[torch.Tensor, None, None]:
        in_emb = self._prefill(
            text=text,
            audio_in=audio_in,
            audio_in_lens=audio_in_lens,
            audio_out=audio_out,
            modality_flag=modality_flag,
        )

        current_modality: LFMModality = LFMModality.TEXT
        modality_left: int = self.conf.interleaved_n_text
        cache: Lfm2HybridConvCache | None = None

        text_done: bool = False

        for _ in range(max_new_tokens):
            modality_left -= 1
            lfm_out = self.lfm(
                inputs_embeds=in_emb,
                past_key_values=cache,
                use_cache=True,
            )
            output_embeddings = lfm_out.last_hidden_state
            cache = lfm_out.past_key_values

            if current_modality == LFMModality.TEXT:
                text_logits = nn.functional.linear(output_embeddings[0, -1], self.lfm.embed_tokens.weight)
                next_token = self._sample_text_token(text_logits, temperature=text_temperature, top_k=text_top_k)

                if next_token == 7:  # <|im_end|>
                    break

                yield next_token

                if next_token == 130:  # <|text_end|>
                    text_done = True
                if not modality_left or text_done:
                    current_modality = LFMModality.AUDIO_OUT
                    modality_left = self.conf.interleaved_n_audio

                in_emb = self.lfm.embed_tokens(next_token)[None, :]

            elif current_modality == LFMModality.AUDIO_OUT:
                next_token = self._sample_audio_frame(
                    output_embeddings[0, -1],
                    temperature=audio_temperature,
                    top_k=audio_top_k,
                )

                if not modality_left and not text_done:
                    current_modality = LFMModality.TEXT
                    modality_left = self.conf.interleaved_n_text

                if next_token[0] == 2048:
                    next_token[:] = 2048
                    current_modality = LFMModality.TEXT

                yield next_token
                in_emb = self.audio_embedding(next_token + self.codebook_offsets).sum(0)[None, None, :]

    def _prefill(
        self,
        *,
        text: torch.Tensor,
        audio_in: torch.Tensor,
        audio_in_lens: torch.Tensor,
        audio_out: torch.Tensor,
        modality_flag: torch.Tensor,
    ) -> torch.Tensor:
        ## Sanity check
        assert len(text.shape) == 2
        assert len(audio_in.shape) == 2
        assert len(audio_in_lens.shape) == 1
        assert len(audio_out.shape) == 2
        assert len(modality_flag.shape) == 2

        assert text.shape[0] == 1

        assert audio_in.shape[0] == 128
        assert audio_out.shape[0] >= self.codebooks
        assert modality_flag.shape[0] == 1

        assert (modality_flag == LFMModality.TEXT).sum() == text.shape[1]
        assert (modality_flag == LFMModality.AUDIO_OUT).sum() == audio_out.shape[1]
        assert (modality_flag == LFMModality.AUDIO_IN).sum() == mel2emb_len(audio_in_lens).sum()
        assert audio_in.shape[1] == audio_in_lens.sum()

        # Text embeddings
        text_emb = self.lfm.embed_tokens(text[0])
        text_mask = modality_flag == LFMModality.TEXT

        # Audio-in embeddings
        ## Batch and pad
        audio_in_list = audio_in.mT.split(audio_in_lens.tolist())
        if audio_in_list:
            padded_audio_in = nn.utils.rnn.pad_sequence(audio_in_list, batch_first=True)
        else:
            padded_audio_in = text_emb.new_empty((0, 8 + 1, 128))

        ## Encode
        audio_enc, audio_in_len = self.conformer(padded_audio_in.mT, audio_in_lens)

        ## Unbatch, unpad
        len_mask = torch.arange(audio_enc.shape[-1], device=audio_enc.device).unsqueeze(0) < audio_in_len.unsqueeze(1)
        audio_enc_concatenated = audio_enc.mT[len_mask]

        ## Adapt
        audio_in_emb = self.audio_adapter(audio_enc_concatenated)
        audio_in_mask = modality_flag == LFMModality.AUDIO_IN
        assert audio_in_emb.shape[0] == audio_in_mask.sum()

        # Audio-out embeddings
        offset_audio_tokens = audio_out[: self.codebooks] + self.codebook_offsets.unsqueeze(1)
        audio_out_emb = self.audio_embedding(offset_audio_tokens).sum(0)
        audio_out_mask = modality_flag == LFMModality.AUDIO_OUT
        assert audio_out_emb.shape[0] == audio_out_mask.sum()

        # Assemble LFM input
        B, L, D = *modality_flag.shape, self.lfm.config.hidden_size

        in_emb = text_emb.new_empty((B, L, D))

        in_emb[text_mask] = text_emb
        in_emb[audio_in_mask] = audio_in_emb
        in_emb[audio_out_mask] = audio_out_emb

        return in_emb

    def _sample_text_token(
        self, logits: torch.Tensor, *, temperature: float | None = None, top_k: int | None = None
    ) -> torch.Tensor:
        greedy = temperature is None or temperature <= 0 or top_k == 1
        if greedy:
            next_token = logits.argmax(keepdim=True)
        else:
            assert isinstance(temperature, float) and temperature > 0
            logits /= temperature
            if top_k is not None:
                min_score = torch.topk(logits, top_k).values[-1]
                to_remove = logits < min_score
                logits = torch.masked_fill(logits, to_remove, -float("inf"))
            probs = logits.softmax(0)
            next_token = torch.multinomial(probs, 1)

        return next_token

    def _sample_audio_frame(
        self,
        embedding: torch.Tensor,  # lfm_dim-sized vector
        *,
        temperature: float | None = None,
        top_k: int | None = None,
    ) -> torch.Tensor:
        greedy = temperature is None or temperature <= 0 or top_k == 1
        depthformer_in = rearrange(self.depth_linear(embedding), "(C D) -> C D", C=self.codebooks, D=self.depthformer_dim)
        depthformer_token = torch.zeros_like(depthformer_in[0])
        cache = None

        out_tokens: list[torch.Tensor] = []
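        # Decode the frame one codebook at a time: each step feeds the per-codebook projection
        # of the backbone hidden state plus the embedding of the previously sampled token
        # through the depthformer, then samples the next codebook's token.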
        for i in range(self.codebooks):
            cur_depthformer_input = depthformer_in[i] + depthformer_token
            depthformer_out, cache = self.depthformer.forward_cached(cur_depthformer_input[None, None, :], cache)
            depthformer_logits = self.depth_embeddings[i].get_logits(depthformer_out.squeeze())  # type: ignore[operator]

            if greedy:
                next_token = depthformer_logits.argmax(keepdim=True)
            else:
                assert isinstance(temperature, float) and temperature > 0
                depthformer_logits /= temperature
                if top_k is not None:
                    min_score = torch.topk(depthformer_logits, top_k).values[-1]
                    to_remove = depthformer_logits < min_score
                    depthformer_logits = torch.masked_fill(depthformer_logits, to_remove, -float("inf"))
                probs = depthformer_logits.softmax(0)
                next_token = torch.multinomial(probs, 1)

            out_tokens.append(next_token)
            depthformer_token = self.depth_embeddings[i](next_token).squeeze()

        return torch.cat(out_tokens)

```

## /src/liquid_audio/model/mlp.py

```py path="/src/liquid_audio/model/mlp.py" 
import torch
from torch import nn


class MLP(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_dim: list[int],
        bias: bool = True,
        use_layer_norm: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()

        channels = [in_channels, *hidden_dim, out_channels]

        layers: list[nn.Module] = []
        if use_layer_norm:
            layers.append(nn.LayerNorm(channels[0]))

        for i in range(len(channels) - 1):
            layers.append(
                nn.Linear(
                    in_features=channels[i],
                    out_features=channels[i + 1],
                    bias=bias,
                )
            )

            if i != (len(channels) - 2):
                layers.append(nn.GELU())
                if dropout > 0:
                    layers.append(nn.Dropout(p=dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
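
# Example (illustrative): MLP(80, 512, [256]) expands to
# LayerNorm(80) -> Linear(80, 256) -> GELU -> Linear(256, 512).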

```

## /src/liquid_audio/model/transformer.py

```py path="/src/liquid_audio/model/transformer.py" 
import math
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from functools import partial
from typing import Any, Literal, TypeGuard, cast

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn.attention.bias import causal_lower_right

type CacheType = torch.Tensor | None | Sequence["CacheType"]


class SequenceModel(nn.Module, ABC):
    """Models operating on sequences

    Assumptions:
        input shape [N, T, dim]
        output shape [N, T', dim_out]
    """

    dim: int
    dim_out: int

    @abstractmethod
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor: ...

    @abstractmethod
    def forward_cached(self, x: torch.Tensor, cache: CacheType = None) -> tuple[torch.Tensor, CacheType]: ...


class LayerKVCache:
    """
    Assumes input cache is a tuple of two tensors (key, value)
    """

    def __init__(self, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> None:
        assert cache is None or len(cache) == 2
        self.key_cache = cache[0] if cache is not None else None
        self.value_cache = cache[1] if cache is not None else None

    def update(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if self.key_cache is None or self.value_cache is None:
            # initialize
            self.key_cache = k
            self.value_cache = v
        else:
            self.key_cache = torch.cat([self.key_cache, k], dim=1)
            self.value_cache = torch.cat([self.value_cache, v], dim=1)
        return self.key_cache, self.value_cache

    def get_cache_size(self) -> int:
        if self.key_cache is None:
            return 0
        else:
            return self.key_cache.shape[1]


class RMSNorm(SequenceModel):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor:
        assert cache is None, "RMSNorm expects None cache"

        output = self._norm(x.float())
        return (output * self.weight).type_as(x)

    def forward_cached(self, x: torch.Tensor, cache: CacheType = None) -> tuple[torch.Tensor, CacheType]:
        return self(x, cache), None


class GLU(SequenceModel):
    def __init__(
        self,
        dim: int,
        ff_dim: int | None = None,
        mlp_init_scale: float = 1.0,
        out_init_scale: float = 0.14434,
        use_swiglu: bool = True,
        multiple_of: int = 256,
        ffn_dim_multiplier: float = 1.0,
    ):
        super().__init__()
        # linear_cls = FlexLinear
        self.dim = dim
        self.dim_out = dim
        self.use_swiglu = use_swiglu
        self.num_params = 0
        if ff_dim is None:  # from LFMv1Block
            ff_dim = 4 * dim
        if use_swiglu:
            ff_dim = int(2 * ff_dim / 3)
            # custom dim factor multiplier
            if ffn_dim_multiplier is not None:
                ff_dim = int(ffn_dim_multiplier * ff_dim)
            ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
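            # e.g. (illustrative) dim=1024: ff_dim 4*1024=4096 -> 2/3 -> 2730 -> rounded up to 2816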

        self.w1 = nn.Linear(dim, ff_dim, bias=False)

        self.num_params += dim * ff_dim
        std = mlp_init_scale / math.sqrt(dim)
        torch.nn.init.normal_(self.w1.weight, mean=0.0, std=std)

        if use_swiglu:
            self.w3 = nn.Linear(dim, ff_dim, bias=False)

            self.num_params += dim * ff_dim
            std = mlp_init_scale / math.sqrt(dim)
            torch.nn.init.normal_(self.w3.weight, mean=0.0, std=std)

        self.w2 = nn.Linear(ff_dim, dim, bias=False)

        self.num_params += ff_dim * dim
        std = out_init_scale * mlp_init_scale / math.sqrt(ff_dim)
        torch.nn.init.normal_(self.w2.weight, mean=0.0, std=std)

    def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor:
        assert cache is None, "expected None cache for GLU"
        if self.use_swiglu:
            return cast(torch.Tensor, self.w2(F.silu(self.w1(x)) * self.w3(x)))
        else:
            return cast(torch.Tensor, self.w2(F.gelu(self.w1(x))))

    def forward_cached(self, x: torch.Tensor, cache: CacheType = None) -> tuple[torch.Tensor, CacheType]:
        return self(x, cache), None


class BoundedAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 32,
        head_style: Literal["mha", "gqa", "mqa"] = "mha",
        gqa_dim: int | None = None,
        qk_layernorm: bool = False,
        norm_eps: float = 1e-5,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.head_style = head_style

        if self.head_style == "gqa":
            # only access attribute if using gqa head style
            assert gqa_dim is not None
            self.gqa_dim = gqa_dim
            assert self.num_heads % self.gqa_dim == 0, f"{self.num_heads} % {self.gqa_dim} != 0"

        self.qk_layernorm = qk_layernorm

        if self.qk_layernorm:
            self.q_layernorm = RMSNorm(self.head_dim, eps=norm_eps)

            self.k_layernorm = RMSNorm(self.head_dim, eps=norm_eps)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        freqs_cis: torch.Tensor | None = None,
        cache: LayerKVCache | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        bsz, seqlen = q.shape[0], q.shape[1]

        if self.head_style == "mqa":
            q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim)
            k = k.reshape(bsz, seqlen, 1, self.head_dim)
            v = v.reshape(bsz, seqlen, 1, self.head_dim)
        elif self.head_style == "mha":
            q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim)
            k = k.reshape(bsz, seqlen, self.num_heads, self.head_dim)
            v = v.reshape(bsz, seqlen, self.num_heads, self.head_dim)
        elif self.head_style == "gqa":
            q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim)
            k = k.reshape(bsz, seqlen, self.gqa_dim, self.head_dim)
            v = v.reshape(bsz, seqlen, self.gqa_dim, self.head_dim)

        if self.qk_layernorm:
            q = self.q_layernorm(q)
            k = self.k_layernorm(k)

        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        if cache is not None:
            k, v = cache.update(k, v)

        q_len = q.shape[1]
        kv_len = k.shape[1]

        query = q.transpose(1, 2)
        key = k.transpose(1, 2)
        value = v.transpose(1, 2)

        enable_gqa = self.head_style in ("mqa", "gqa")
        if q_len == kv_len:
            output = nn.functional.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=enable_gqa)
        else:
            if q_len == 1:
                attn_mask = None
            else:
                attn_mask = causal_lower_right(q_len, kv_len)
            output = nn.functional.scaled_dot_product_attention(
                query, key, value, is_causal=False, enable_gqa=enable_gqa, attn_mask=attn_mask
            )

        output = output.transpose(1, 2)

        output = output.reshape(bsz, seqlen, self.num_heads * self.head_dim)
        return output, (k, v)


class MHA(SequenceModel):
    def __init__(
        self,
        dim: int,
        num_heads: int = 32,
        head_style: Literal["mha", "gqa", "mqa"] = "gqa",
        out_init_scale: float = 0.125,  # 1/sqrt(2*n_layers) (from gpt3)
        proj_init_scale: float = 1.0,
        qk_layernorm: bool = True,
        norm_eps: float = 0.00001,
        gqa_dim: int = 8,  # Optional if head_style is not "gqa"
        freqs_cis: torch.Tensor | None = None,
        max_seq_len: int = 128_000,  # Stored positional encodings. Optional if freqs_cis is given
        theta: float = 1_000_000.0,  # Positional encoding theta. Optional if freqs_cis is given
    ):
        super().__init__()

        self.dim = self.dim_out = dim
        self.num_heads = num_heads
        assert self.dim % self.num_heads == 0, "expected dim to be divisible by num_heads"
        self.head_dim = self.dim // self.num_heads
        self.head_style = head_style

        if self.head_style == "gqa":
            # only access attribute if using gqa head style
            assert gqa_dim is not None
            self.gqa_dim = gqa_dim

        # q, k, v + optional w and z projections
        if self.head_style == "mha":
            self.total_width = 3 * self.dim
        elif self.head_style == "mqa":
            self.total_width = self.dim + 2 * self.head_dim
        elif self.head_style == "gqa":
            assert self.gqa_dim is not None
            self.total_width = self.dim + 2 * self.head_dim * self.gqa_dim
        else:
            raise NotImplementedError(f"head style {self.head_style} not implemented")

        self.qkv_proj = nn.Linear(
            self.dim,
            self.total_width,
            bias=False,
        )
        std = proj_init_scale / math.sqrt(self.dim)
        torch.nn.init.normal_(self.qkv_proj.weight, mean=0.0, std=std)

        self.out_proj = nn.Linear(self.dim, self.dim, bias=False)

        std = out_init_scale * proj_init_scale / math.sqrt(self.dim)
        torch.nn.init.normal_(self.out_proj.weight, mean=0.0, std=std)

        self.bounded_attention = BoundedAttention(
            dim=dim,
            num_heads=num_heads,
            head_style=head_style,
            gqa_dim=gqa_dim,
            qk_layernorm=qk_layernorm,
            norm_eps=norm_eps,
        )

        if freqs_cis is not None:
            self.freqs_cis = freqs_cis
        else:
            self.freqs_cis = precompute_freqs_cis(self.head_dim, max_seq_len, theta)

    def _validate_cache(self, cache: CacheType) -> TypeGuard[tuple[torch.Tensor, torch.Tensor]]:
        return (
            isinstance(cache, tuple)
            and len(cache) == 2
            and isinstance(cache[0], torch.Tensor)
            and isinstance(cache[1], torch.Tensor)
        )

    def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor:
        return self.forward_cached(x, cache)[0]

    def forward_cached(self, x: torch.Tensor, cache: CacheType = None) -> tuple[torch.Tensor, CacheType]:
        if cache is not None:
            assert self._validate_cache(cache)
            kv_cache = LayerKVCache(cache)
        else:
            kv_cache = None

        # x is (bsz, seqlen, d_model)
        seq_len = x.shape[1]

        x = self.qkv_proj(x)
        if self.head_style == "mha":
            xq, xk, xv = x.split(self.dim, dim=-1)
        elif self.head_style == "mqa":
            xq, xk, xv = x.split([self.dim, self.head_dim, self.head_dim], dim=-1)
        elif self.head_style == "gqa":
            xq, xk, xv = x.split(
                [self.dim, self.head_dim * self.gqa_dim, self.head_dim * self.gqa_dim],
                dim=-1,
            )

        # TODO: Need to clean up, hack for now to allow rpes in grafted model if using e.g. mqa for grafting
        self.freqs_cis = self.freqs_cis.to(xq.device)
        if kv_cache is not None:
            # If using cache, get freqs for all new tokens starting from cache size
            cache_size = kv_cache.get_cache_size()
            freqs_cis = self.freqs_cis[cache_size : cache_size + seq_len]
        else:
            # Otherwise get freqs for full sequence
            freqs_cis = self.freqs_cis[:seq_len]

        ys, new_cache = self.bounded_attention(xq, xk, xv, freqs_cis=freqs_cis, cache=kv_cache)

        ys = self.out_proj(ys)

        return cast(torch.Tensor, ys), new_cache


class StandardBlock(SequenceModel):
    """Block with an operator + norm + skip connection, followed by a GLU + norm + skip connection"""

    def __init__(
        self,
        operator: SequenceModel,
        ff_dim: int | None = None,
        mlp_init_scale: float = 1.0,
        out_init_scale: float = 0.125,  # 1/sqrt(2*n_layers) (from gpt3)
        use_swiglu: bool = True,
        multiple_of: int = 256,
        ffn_dim_multiplier: float = 1.0,
        norm_eps: float = 0.00001,
    ):
        super().__init__()
        self.operator = operator
        self.dim = self.dim_out = self.operator.dim

        if ff_dim is None:
            ff_dim = 4 * self.dim

        self.feed_forward = GLU(
            dim=self.dim,
            ff_dim=ff_dim,
            mlp_init_scale=mlp_init_scale,
            out_init_scale=out_init_scale,
            use_swiglu=use_swiglu,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        self.operator_norm = RMSNorm(self.dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(self.dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, cache: CacheType = None) -> torch.Tensor:
        h = self.operator(self.operator_norm(x), cache)
        h += x
        h_glu = self.feed_forward(self.ffn_norm(h))
        out = h + h_glu
        return cast(torch.Tensor, out)

    def forward_cached(self, x: torch.Tensor, cache: CacheType | None = None) -> tuple[torch.Tensor, CacheType]:
        h, new_cache = self.operator.forward_cached(self.operator_norm(x), cache)
        h += x
        h_glu = self.feed_forward.forward(self.ffn_norm(h))
        out = h + h_glu
        return out, new_cache


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(rearrange(xq.float(), "... (D two) -> ... D two", two=2))
    xk_ = torch.view_as_complex(rearrange(xk.float(), "... (D two) -> ... D two", two=2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
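
# Shape note (illustrative): with xq of shape [B, T, H, head_dim] and freqs_cis of shape
# [T, head_dim // 2], channel pairs are rotated in the complex plane and the outputs keep
# the input shapes and dtypes.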


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
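
# Example (illustrative): precompute_freqs_cis(64, 1024) returns a complex64 tensor of shape
# [1024, 32], one unit-magnitude rotation per (position, frequency) pair.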


class SharedEmbedding(nn.Module):
    def __init__(
        self,
        dim: int,
        vocab_size: int = 65_536,
        embed_init_scale: float = 1.0,
        norm_eps: float = 0.00001,
        *,
        tie_embedding: bool = True,
    ) -> None:
        super().__init__()

        self.embedding = torch.nn.Embedding(vocab_size, dim)

        std = embed_init_scale / math.sqrt(dim)
        torch.nn.init.normal_(self.embedding.weight, mean=0.0, std=std)

        self.embedding_norm = RMSNorm(dim, eps=norm_eps)  # Note: this is really the norm before output projection
        self.to_logits = nn.Linear(dim, vocab_size, bias=False)

        if tie_embedding:
            self.to_logits.weight = self.embedding.weight
        else:
            # If not tying embedding, scale the output weights
            std = embed_init_scale / math.sqrt(dim)
            torch.nn.init.normal_(self.to_logits.weight, mean=0.0, std=std)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.embed(tokens)

    def embed(self, tokens: torch.Tensor) -> torch.Tensor:
        return cast(torch.Tensor, self.embedding(tokens))

    def get_logits(self, embeddings: torch.Tensor) -> torch.Tensor:
        return cast(torch.Tensor, self.to_logits(self.embedding_norm(embeddings)))


class RawLMBackbone(SequenceModel):
    """
    "Raw" Backbone for LM models:
    - input: continuous embeddings
    - output: continuous embeddings
    """

    def __init__(
        self,
        layers: Iterable[SequenceModel],
        vocab_size: int = 65_536,
        norm_eps: float = 0.00001,
        embed_init_scale: float = 1.0,
        *,
        has_embedding: bool = True,
        tie_embedding: bool = True,
    ) -> None:
        super().__init__()

        self.layers = cast(Sequence[SequenceModel], nn.ModuleList(layers))
        self.dim = self.layers[0].dim
        self.dim_out = self.layers[-1].dim_out
        assert self.dim == self.dim_out, "expected first layer input dim to be equal to last layer's output dim"

        if has_embedding:
            # TODO: possibly wrap in wrap_sharded
            self.embedding = SharedEmbedding(
                self.dim, vocab_size, embed_init_scale=embed_init_scale, norm_eps=norm_eps, tie_embedding=tie_embedding
            )
        self.has_embedding = has_embedding
        self.vocab_size = vocab_size

    def forward(self, x: torch.Tensor, cache: CacheType | None = None) -> torch.Tensor:
        if cache is not None:
            assert isinstance(cache, list)
            assert len(cache) == len(self.layers)
        else:
            cache = [None] * len(self.layers)

        for layer, layer_cache in zip(self.layers, cache, strict=True):
            x = layer(x, layer_cache)

        return x

    def forward_cached(self, x: torch.Tensor, cache: CacheType | None = None) -> tuple[torch.Tensor, CacheType]:
        if cache is not None:
            assert isinstance(cache, list)
            assert len(cache) == len(self.layers)
        else:
            cache = [None] * len(self.layers)

        cache_out: list[CacheType] = []
        for layer, layer_cache in zip(self.layers, cache, strict=True):
            x, new_cache = layer.forward_cached(x, layer_cache)
            cache_out.append(new_cache)

        return x, cache_out


def wrap_activation_checkpoint[T: nn.Module](mod: T) -> T:
    # NOTE: we're using torch.utils.checkpoint.checkpoint here to avoid the error in backward pass when using zero-3 + activation checkpointing (https://github.com/microsoft/DeepSpeed/issues/4595)
    checkpoint_fn = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
    mod.forward_ = mod.forward  # type: ignore[assignment]

    def forward(*args, **kwargs):
        return checkpoint_fn(mod.forward_, *args, **kwargs)

    mod.forward = forward
    return mod

```

## /src/liquid_audio/moshi/__init__.py

```py path="/src/liquid_audio/moshi/__init__.py" 
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
moshi is the inference codebase for Kyutai audio generation models.

The code has been adapted from Audiocraft, see LICENSE.audiocraft
  Copyright (c) Meta Platforms, Inc. and affiliates.
"""

# flake8: noqa
from . import conditioners
from . import models
from . import modules
from . import quantization
from . import utils

__version__ = "0.2.12a3"

```

## /src/liquid_audio/moshi/client.py

```py path="/src/liquid_audio/moshi/client.py" 
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Client for the Moshi server."""

import argparse
import asyncio
import queue
import sys

import aiohttp
import numpy as np
import sphn
import sounddevice as sd

from .client_utils import AnyPrinter, Printer, RawPrinter


class Connection:
    def __init__(
        self,
        printer: AnyPrinter,
        websocket: aiohttp.ClientWebSocketResponse,
        sample_rate: float = 24000,
        channels: int = 1,
        frame_size: int = 1920,
    ) -> None:
        self.printer = printer
        self.websocket = websocket
        self.sample_rate = sample_rate
        self.frame_size = frame_size
        self.channels = channels

        self._done = False
        self._in_stream = sd.InputStream(
            samplerate=sample_rate,
            channels=channels,
            blocksize=self.frame_size,
            callback=self._on_audio_input,
        )

        self._out_stream = sd.OutputStream(
            samplerate=sample_rate,
            channels=channels,
            blocksize=frame_size,
            callback=self._on_audio_output,
        )
        self._opus_writer = sphn.OpusStreamWriter(sample_rate)
        self._opus_reader = sphn.OpusStreamReader(sample_rate)
        self._output_queue = queue.Queue()

    async def _queue_loop(self) -> None:
        while True:
            if self._done:
                return
            await asyncio.sleep(0.001)
            msg = self._opus_writer.read_bytes()
            if len(msg) > 0:
                try:
                    await self.websocket.send_bytes(b"\x01" + msg)
                except Exception as e:
                    print(e)
                    self._lost_connection()
                    return

    async def _decoder_loop(self) -> None:
        all_pcm_data = None
        while True:
            if self._done:
                return
            await asyncio.sleep(0.001)
            pcm = self._opus_reader.read_pcm()
            if all_pcm_data is None:
                all_pcm_data = pcm
            else:
                all_pcm_data = np.concatenate((all_pcm_data, pcm))
            while all_pcm_data.shape[-1] >= self.frame_size:
                self._output_queue.put(all_pcm_data[: self.frame_size])
                all_pcm_data = np.array(all_pcm_data[self.frame_size :])

    async def _recv_loop(self) -> None:
        try:
            async for message in self.websocket:
                if message.type == aiohttp.WSMsgType.CLOSED:
                    self.printer.log("info", "Connection closed")
                    break
                elif message.type == aiohttp.WSMsgType.ERROR:
                    self.printer.log("error", f"{self.websocket.exception()}")
                    break
                elif message.type != aiohttp.WSMsgType.BINARY:
                    self.printer.log("error", f"received from server: {message.type}")
                    continue
                message = message.data
                if not isinstance(message, bytes):
                    self.printer.log(
                        "warning", f"unsupported message type {type(message)}"
                    )
                    continue
                if len(message) == 0:
                    self.printer.log("warning", "empty message")
                    continue
                kind = message[0]
                if kind == 1:  # audio
                    payload = message[1:]
                    self._opus_reader.append_bytes(payload)
                    self.printer.print_pending()
                elif kind == 2:  # text
                    payload = message[1:]
                    self.printer.print_token(payload.decode())
                else:
                    self.printer.log("warning", f"unknown message kind {kind}")
        except Exception as e:
            print(e)
            self._lost_connection()
            return

    def _lost_connection(self) -> None:
        if not self._done:
            self.printer.log("error", "Lost connection with the server!")
            self._done = True

    def _on_audio_input(self, in_data, frames, time_, status) -> None:
        assert in_data.shape == (self.frame_size, self.channels), in_data.shape
        self._opus_writer.append_pcm(in_data[:, 0])

    def _on_audio_output(self, out_data, frames, time_, status) -> None:
        assert out_data.shape == (self.frame_size, self.channels), out_data.shape
        try:
            pcm_data = self._output_queue.get(block=False)
            # TODO: handle other shapes by using some form of fifo/ring buffer.
            assert pcm_data.shape == (self.frame_size,), pcm_data.shape
            out_data[:, 0] = pcm_data
        except queue.Empty:
            out_data.fill(0)
            self.printer.print_lag()

    async def run(self) -> None:
        with self._in_stream, self._out_stream:
            await asyncio.gather(
                self._recv_loop(), self._decoder_loop(), self._queue_loop()
            )


async def run(printer: AnyPrinter, args):
    if args.url is None:
        proto = "ws"
        if args.https:
            proto += "s"
        uri = f"{proto}://{args.host}:{args.port}/api/chat"
    else:
        proto = "wss"
        if '://' in args.url:
            proto, without_proto = args.url.split('://', 1)
            if proto in ['ws', 'http']:
                proto = "ws"
            elif proto in ['wss', 'https']:
                proto = "wss"
            else:
                printer.log("error", "The provided URL {args.url} seems to contain a protocol but it is unknown.")
                sys.exit(1)
        else:
            without_proto = args.url
        uri = f"{proto}://{without_proto}/api/chat"

    printer.log("info", f"Connecting to {uri}.")
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(uri) as ws:
            printer.log("info", "connected!")
            printer.print_header()
            connection = Connection(printer, ws)
            await connection.run()


def main():
    parser = argparse.ArgumentParser("client_opus")
    parser.add_argument("--host", default="localhost", type=str, help="Hostname to connect to.")
    parser.add_argument("--port", default=8998, type=int, help="Port to connect to.")
    parser.add_argument("--https", action='store_true',
                        help="Set this flag for using a https connection.")
    parser.add_argument("--url", type=str, help='Provides directly a URL, e.g. to a gradio tunnel.')
    args = parser.parse_args()
    printer: AnyPrinter

    if sys.stdout.isatty():
        printer = Printer()
    else:
        printer = RawPrinter()
    try:
        asyncio.run(run(printer, args))
    except KeyboardInterrupt:
        printer.log("warning", "Interrupting, exiting connection.")
    printer.log("info", "All done!")


if __name__ == "__main__":
    main()

```

## /src/liquid_audio/moshi/client_gradio.py

```py path="/src/liquid_audio/moshi/client_gradio.py" 
import argparse
from typing import Generator, Literal, cast

import numpy as np
import sphn
from numpy.typing import NDArray

try:
    import gradio as gr  # type: ignore
    import websockets.sync.client
    from gradio_webrtc import AdditionalOutputs, StreamHandler, WebRTC  # type: ignore
except ImportError:
    raise ImportError("Please install gradio-webrtc>=0.0.18 to run this script.")

# See https://freddyaboulton.github.io/gradio-webrtc/deployment/ for
# instructions on how to set the rtc_configuration variable for deployment
# on cloud platforms like Heroku, Spaces, etc.
rtc_configuration = None


class MoshiHandler(StreamHandler):
    def __init__(
        self,
        url: str,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        self.url = url
        proto, without_proto = self.url.split("://", 1)
        if proto in ["ws", "http"]:
            proto = "ws"
        elif proto in ["wss", "https"]:
            proto = "wss"

        self._generator = None
        self.output_chunk_size = 1920
        self.ws = None
        self.ws_url = f"{proto}://{without_proto}/api/chat"
        self.stream_reader = sphn.OpusStreamReader(output_sample_rate)
        self.stream_writer = sphn.OpusStreamWriter(output_sample_rate)
        self.all_output_data = None
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=24000,
        )

    def receive(self, frame: tuple[int, NDArray]) -> None:
        if not self.ws:
            self.ws = websockets.sync.client.connect(self.ws_url)
        _, array = frame
        array = array.squeeze().astype(np.float32) / 32768.0
        self.stream_writer.append_pcm(array)
        bytes = b"\x01" + self.stream_writer.read_bytes()
        self.ws.send(bytes)

    def generator(
        self,
    ) -> Generator[tuple[int, NDArray] | None | AdditionalOutputs, None, None]:
        for message in cast(websockets.sync.client.ClientConnection, self.ws):
            if len(message) == 0:
                yield None
                continue  # nothing to parse; skip the kind byte below
            kind = message[0]
            if kind == 1:
                payload = message[1:]
                self.stream_reader.append_bytes(payload)
                pcm = self.stream_reader.read_pcm()
                if self.all_output_data is None:
                    self.all_output_data = pcm
                else:
                    self.all_output_data = np.concatenate((self.all_output_data, pcm))
                while self.all_output_data.shape[-1] >= self.output_chunk_size:
                    yield (
                        self.output_sample_rate,
                        self.all_output_data[: self.output_chunk_size].reshape(1, -1),
                    )
                    self.all_output_data = np.array(
                        self.all_output_data[self.output_chunk_size :]
                    )
            elif kind == 2:
                payload = cast(bytes, message[1:])
                yield AdditionalOutputs(payload.decode())

    def emit(self) -> tuple[int, NDArray] | AdditionalOutputs | None:
        if not self.ws:
            return
        if not self._generator:
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            self.reset()

    def reset(self) -> None:
        self._generator = None
        self.all_output_data = None

    def copy(self) -> StreamHandler:
        return MoshiHandler(
            self.url,
            self.expected_layout,  # type: ignore
            self.output_sample_rate,
            self.output_frame_size,
        )

    def shutdown(self) -> None:
        if self.ws:
            self.ws.close()


def main():
    parser = argparse.ArgumentParser("client_gradio")
    parser.add_argument("--url", type=str, help="URL to moshi server.")
    args = parser.parse_args()

    with gr.Blocks() as demo:
        gr.HTML(
            """
        <div style='text-align: center'>
            <h1>
                Talk To Moshi (Powered by WebRTC ⚡️)
            </h1>
            <p>
                Each conversation is limited to 90 seconds. Once the time limit is up you can rejoin the conversation.
            </p>
        </div>
        """
        )
        chatbot = gr.Chatbot(type="messages", value=[])
        webrtc = WebRTC(
            label="Conversation",
            modality="audio",
            mode="send-receive",
            rtc_configuration=rtc_configuration,
        )
        webrtc.stream(
            MoshiHandler(args.url),
            inputs=[webrtc, chatbot],
            outputs=[webrtc],
            time_limit=90,
        )

        def add_text(chat_history, response):
            if len(chat_history) == 0:
                chat_history.append({"role": "assistant", "content": ""})
            chat_history[-1]["content"] += response
            return chat_history

        webrtc.on_additional_outputs(
            add_text,
            inputs=[chatbot],
            outputs=chatbot,
            queue=False,
            show_progress="hidden",
        )

        demo.launch()


if __name__ == "__main__":
    main()

```

## /src/liquid_audio/moshi/client_utils.py

```py path="/src/liquid_audio/moshi/client_utils.py" 
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities for the command line client, in particular for handling interactions with the terminal.
"""

from dataclasses import dataclass
import sys


def colorize(text, color):
    code = f"\033[{color}m"
    restore = "\033[0m"
    return "".join([code, text, restore])


def make_log(level: str, msg: str) -> str:
    if level == "warning":
        prefix = colorize("[Warn]", "1;31")
    elif level == "info":
        prefix = colorize("[Info]", "1;34")
    elif level == "error":
        prefix = colorize("[Err ]", "1;31")
    else:
        raise ValueError(f"Unknown level {level}")
    return prefix + " " + msg


def log(level: str, msg: str) -> None:
    """Log something with a given level."""
    print(make_log(level, msg))


class RawPrinter:
    def __init__(self, stream=sys.stdout, err_stream=sys.stderr):
        self.stream = stream
        self.err_stream = err_stream

    def print_header(self):
        pass

    def print_token(self, token: str):
        self.stream.write(token)
        self.stream.flush()

    def log(self, level: str, msg: str):
        print(f"{level.capitalize()}: {msg}", file=self.err_stream)

    def print_lag(self):
        self.err_stream.write(colorize(" [LAG]", "31"))
        self.err_stream.flush()

    def print_pending(self):
        pass


@dataclass
class LineEntry:
    msg: str
    color: str | None = None

    def render(self):
        if self.color is None:
            return self.msg
        else:
            return colorize(self.msg, self.color)

    def __len__(self):
        return len(self.msg)


class Line:
    def __init__(self, stream):
        self.stream = stream
        self._line: list[LineEntry] = []
        self._has_padding: bool = False
        self._max_line_length = 0

    def __bool__(self):
        return bool(self._line)

    def __len__(self):
        return sum(len(entry) for entry in self._line)

    def add(self, msg: str, color: str | None = None) -> int:
        entry = LineEntry(msg, color)
        return self._add(entry)

    def _add(self, entry: LineEntry) -> int:
        if self._has_padding:
            self.erase(count=0)
        self._line.append(entry)
        self.stream.write(entry.render())
        self._max_line_length = max(self._max_line_length, len(self))
        return len(entry)

    def erase(self, count: int = 1):
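        # Drop the last `count` entries and redraw the line from the start of the row;
        # count=0 keeps every entry and only re-renders (used to overwrite trailing padding).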
        if count:
            entries = list(self._line[:-count])
        else:
            entries = list(self._line)
        self._line.clear()
        self.stream.write("\r")
        for entry in entries:
            self._line.append(entry)
            self.stream.write(entry.render())

        self._has_padding = False

    def newline(self):
        missing = self._max_line_length - len(self)
        if missing > 0:
            self.stream.write(" " * missing)
        self.stream.write("\n")
        self._line.clear()
        self._max_line_length = 0
        self._has_padding = False

    def flush(self):
        missing = self._max_line_length - len(self)
        if missing > 0:
            self.stream.write(" " * missing)
            self._has_padding = True
        self.stream.flush()


class Printer:
    def __init__(self, max_cols: int = 80, stream=sys.stdout, err_stream=sys.stderr):
        self.max_cols = max_cols
        self.line = Line(stream)
        self.stream = stream
        self.err_stream = err_stream
        self._pending_count = 0
        self._pending_printed = False

    def print_header(self):
        self.line.add(" " + "-" * (self.max_cols) + " ")
        self.line.newline()
        self.line.flush()
        self.line.add("| ")

    def _remove_pending(self) -> bool:
        if self._pending_printed:
            self._pending_printed = False
            self.line.erase(1)
            return True
        return False

    def print_token(self, token: str, color: str | None = None):
        self._remove_pending()
        remaining = self.max_cols - len(self.line)
        if len(token) <= remaining:
            self.line.add(token, color)
        else:
            end = " " * remaining + " |"
            if token.startswith(" "):
                token = token.lstrip()
                self.line.add(end)
                self.line.newline()
                self.line.add("| ")
                self.line.add(token, color)
            else:
                assert color is None
                erase_count = None
                cumulated = ""
                for idx, entry in enumerate(self.line._line[::-1]):
                    if entry.color:
                        # probably a LAG message
                        erase_count = idx
                        break
                    if entry.msg.startswith(" "):
                        erase_count = idx + 1
                        cumulated = entry.msg + cumulated
                        break
                if erase_count is not None:
                    if erase_count > 0:
                        self.line.erase(erase_count)
                    remaining = self.max_cols - len(self.line)
                    end = " " * remaining + " |"
                    self.line.add(end)
                    self.line.newline()
                    self.line.add("| ")
                    token = cumulated.lstrip() + token
                    self.line.add(token)
                else:
                    self.line.add(token[:remaining])
                    self.line.add(" |")
                    self.line.newline()
                    self.line.add("| ")
                    self.line.add(token[remaining:])
        self.line.flush()

    def log(self, level: str, msg: str):
        msg = make_log(level, msg)
        self._remove_pending()
        if self.line:
            self.line.newline()
        self.line.flush()
        print(msg, file=self.err_stream)
        self.err_stream.flush()

    def print_lag(self):
        self.print_token(" [LAG]", "31")

    def print_pending(self):
        chars = ["|", "/", "-", "\\"]
        count = int(self._pending_count / 5)
        char = chars[count % len(chars)]
        colors = ["32", "33", "31"]
        self._remove_pending()
        self.line.add(char, colors[count % len(colors)])
        self._pending_printed = True
        self._pending_count += 1


AnyPrinter = Printer | RawPrinter

```
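
A minimal usage sketch of the `Printer` above (not part of the repository source); the import path is assumed from the file layout and the printed tokens are purely illustrative.

```py
import sys

# Assumed import path, inferred from the repository layout.
from liquid_audio.moshi.client_utils import Printer

printer = Printer(max_cols=40, stream=sys.stdout, err_stream=sys.stderr)
printer.print_header()                       # draws the top border, then opens a "| " line
for word in "hello from the streaming decoder".split():
    printer.print_token(" " + word)          # tokens wrap onto fresh "| " lines when the row fills up
printer.print_lag()                          # appends a red " [LAG]" marker
printer.log("info", "generation finished")   # closes the current line and logs to stderr
```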

## /src/liquid_audio/moshi/conditioners/__init__.py

```py path="/src/liquid_audio/moshi/conditioners/__init__.py" 
# flake8: noqa
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Modules to help doing generations under some fixed conditions.
"""

from .base import (ConditionType, ConditionAttributes, ConditionFuser, ConditionProvider,
                   BaseConditioner, TensorCondition, ConditionTensors, dropout_all_conditions)

```

## /src/liquid_audio/moshi/conditioners/base.py

```py path="/src/liquid_audio/moshi/conditioners/base.py" 
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Adapted from
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import logging
import typing as tp
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import chain

import torch
from torch import nn

from ..modules.transformer import create_sin_embedding

logger = logging.getLogger(__name__)
TextCondition = tp.Optional[str]  # a text condition is either a string or None (if it doesn't exist)
ConditionTensors = dict[str, 'ConditionType']


class ConditionType(tp.NamedTuple):
    """Return type for a conditioner: both a condition tensor, and a mask indicating valid positions.
    """
    condition: torch.Tensor
    mask: torch.Tensor


@dataclass(frozen=True)
class TensorCondition:
    """Looks quite similar to ConditionType, but represents the input to TensorConditioners.
    `tensor` should be [B | 1, T, D], and `mask` should be `[B | 1, T]`.
    """
    tensor: torch.Tensor
    mask: torch.Tensor

    @staticmethod
    def from_tensor(tensor: torch.Tensor):
        B, T, _ = tensor.shape
        mask = torch.ones(B, T, dtype=torch.bool, device=tensor.device)
        return TensorCondition(tensor, mask)

    @staticmethod
    def cat(conditions: tp.Sequence['TensorCondition']) -> 'TensorCondition':
        assert conditions, "Cannot cat empty list."
        ref_tensor = conditions[0].tensor
        B, _, D = ref_tensor.shape
        assert B == 1
        B = len(conditions)
        T = max(condition.tensor.shape[1] for condition in conditions)
        mask = torch.zeros(B, T, dtype=torch.bool, device=ref_tensor.device)
        tensor = torch.zeros(B, T, D, dtype=ref_tensor.dtype, device=ref_tensor.device)
        for b, condition in enumerate(conditions):
            tensor[b, :condition.tensor.shape[1], :] = condition.tensor[0]
            mask[b, :condition.mask.shape[1]] = condition.mask[0]
        return TensorCondition(tensor, mask)


@dataclass
class ConditionAttributes:
    """Standard class for representing the set of potential inputs to the conditioners.
    Typically, `audiocraft.data.audio_dataset.SegmentInfo` will convert
    to this class to make conditioning agnostic to the type of dataset.

    There are two kinds of conditioning: text (or None), and raw torch tensors (with a mask).

    """
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    tensor: tp.Dict[str, TensorCondition] = field(default_factory=dict)

    @property
    def text_attributes(self) -> tp.Iterable[str]:
        return self.text.keys()

    @property
    def tensor_attributes(self) -> tp.Iterable[str]:
        return self.tensor.keys()

    @staticmethod
    def condition_types() -> tp.FrozenSet[str]:
        return frozenset(["text", "tensor"])

    def copy(self) -> 'ConditionAttributes':
        return ConditionAttributes(dict(self.text), dict(self.tensor))


Prepared = tp.TypeVar('Prepared')  # represents the prepared condition input type.


class BaseConditioner(nn.Module, tp.Generic[Prepared]):
    """Base model for all conditioner modules.

    Args:
        dim (int): internal dim of the model.
        output_dim (int): Output dim of the conditioner.
        force_linear (bool, optional): Force linear projection even when `dim == output_dim`.
        pad_empty (bool): if True, conditionings of 0 length will be padded to have length 1.
        output_bias (bool): if True, the output projection will have a bias.
        learn_padding (bool): if True, the padding value will be learnt, zero otherwise.
    """

    def __init__(self, dim: int,
                 output_dim: int,
                 device: tp.Union[torch.device, str],
                 force_linear: bool = True,
                 pad_empty: bool = True,
                 output_bias: bool = False,
                 learn_padding: bool = True):
        super().__init__()
        self.dim = dim
        self.output_dim = output_dim
        self.pad_empty = pad_empty
        self.device = device
        self.output_proj: nn.Module
        if force_linear or dim != output_dim:
            self.output_proj = nn.Linear(dim, output_dim, bias=output_bias, device=device)
            assert not output_bias
        else:
            self.output_proj = nn.Identity()
        self.learnt_padding: tp.Optional[torch.Tensor]
        if learn_padding:
            self.learnt_padding = nn.Parameter(
                torch.randn(1, 1, output_dim, device=device), requires_grad=True)
            self.learnt_padding.data *= 0.2
        else:
            self.learnt_padding = None

    def prepare(self, *args, **kwargs) -> Prepared:
        """Should be any part of the processing that will lead to a synchronization
        point, e.g. BPE tokenization with transfer to the GPU.

        The returned value will be saved and passed back to forward() later.
        """
        raise NotImplementedError()

    def _get_condition(self, inputs: Prepared) -> ConditionType:
        """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
        Outputs a ConditionType, after the input data was embedded as a dense vector.

        Returns:
            ConditionType:
                - A tensor of size [B, T, dim] where B is the batch size, T is the length of the
                  output embedding and `dim` is the internal dimension of the embedding.
                - And a mask indicating the padding positions, of shape `[B, T]`.
        """
        raise NotImplementedError()

    def forward(self, inputs: Prepared) -> ConditionType:
        cond, mask = self._get_condition(inputs)
        B, T, C = cond.shape
        if T == 0 and self.pad_empty:
            cond = torch.zeros(B, 1, C, device=cond.device, dtype=cond.dtype)  # pad empty conditioning to length 1
            mask = torch.zeros_like(cond[..., 0], dtype=torch.bool)

        cond = self.output_proj(cond)

        maskf = mask.float()[..., None]
        if self.learnt_padding is not None:
            cond = cond * maskf + self.learnt_padding * (1 - maskf)
        else:
            cond = cond * maskf
        return ConditionType(cond, mask)


class _BaseTextConditioner(BaseConditioner[Prepared]):
    pass


class _BaseTensorConditioner(BaseConditioner[Prepared]):
    pass


def dropout_tensor(condition: TensorCondition) -> TensorCondition:
    """Utility function for nullifying a WavCondition object.
    """
    return TensorCondition(
        tensor=torch.zeros_like(condition.tensor),
        mask=torch.zeros_like(condition.mask))


def dropout_condition_(sample: ConditionAttributes, condition_type: str, condition: str) -> None:
    """Utility function for nullifying an attribute inside a ConditionAttributes object.
    Works in-place.
    """
    valid_conditions = ConditionAttributes.condition_types()
    if condition_type not in valid_conditions:
        raise ValueError(
            "dropout_condition got an unexpected condition type!"
            f" expected one of {valid_conditions} but got '{condition_type}'")

    if condition not in getattr(sample, condition_type):
        raise ValueError(
            "dropout_condition received an unexpected condition!"
            f" expected tensor={sample.tensor.keys()} and text={sample.text.keys()}"
            f" but got '{condition}' of type '{condition_type}'!"
        )

    if condition_type == 'tensor':
        tensor_condition = sample.tensor[condition]
        sample.tensor[condition] = dropout_tensor(tensor_condition)
    elif condition_type == 'text':
        sample.text[condition] = None
    else:
        assert False


def dropout_all_conditions(attributes: tp.Sequence[ConditionAttributes]) -> list[ConditionAttributes]:
    """
    Args:
        attributes (list[ConditionAttributes]): All conditions attributes.
    Returns:
        list[ConditionAttributes]: Same with all conditions dropped.
    """
    attributes = [attribute.copy() for attribute in attributes]
    for condition_type in ConditionAttributes.condition_types():
        for attribute in attributes:
            for condition in getattr(attribute, condition_type):
                dropout_condition_(attribute, condition_type, condition)
    return attributes


class ConditionProvider(nn.Module):
    """Prepare and provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
        device (torch.device or str, optional): Device for conditioners and output condition types.
    """

    def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
        super().__init__()
        self.device = device
        self.conditioners = nn.ModuleDict(conditioners).to(device)

    @property
    def text_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, _BaseTextConditioner)]

    @property
    def tensor_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, _BaseTensorConditioner)]

    def _collate_text(self, samples: tp.Sequence[ConditionAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditionAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.
        For example:
        Input:
        [
            ConditionAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
            ConditionAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
        ]
        Output:
        {
            "genre": ["Rock", "Hip-hop"],
            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
        }

        Args:
            samples (list of ConditionAttributes): List of ConditionAttributes samples.
        Returns:
            dict[str, list[str | None]]: A dictionary mapping an attribute name to its text batch.
        """
        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                out[condition].append(text[condition])
        return out

    def _collate_tensors(self, samples: tp.Sequence[ConditionAttributes]) -> tp.Dict[str, TensorCondition]:
        """For each tensor attribute, collate the tensor from individual batch items.

        Args:
            samples (list of ConditionAttributes): List of ConditionAttributes samples.
        Returns:
            dict[str, TensorCondition]: A dictionary mapping an attribute name to tensor.
        """
        per_attribute = defaultdict(list)
        out: tp.Dict[str, TensorCondition] = {}
        for sample in samples:
            for attribute in self.tensor_conditions:
                per_attribute[attribute].append(sample.tensor[attribute])

        # stack all tensors to a single tensor
        for attribute in self.tensor_conditions:
            out[attribute] = TensorCondition.cat(per_attribute[attribute])

        return out

    def prepare(self, inputs: tp.Sequence[ConditionAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/tensors with existing conditioners in self, and call `prepare` for each one.
        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary prepared representations.

        Args:
            inputs (list[ConditionAttributes]): List of ConditionAttributes objects containing
                text and tensors conditions.
        """
        assert all([isinstance(x, ConditionAttributes) for x in inputs]), (
            "Got unexpected types input for conditioner! should be tp.List[ConditionAttributes]",
            f" but types were {set([type(x) for x in inputs])}"
        )

        output = {}
        text = self._collate_text(inputs)
        tensors = self._collate_tensors(inputs)

        assert set(text.keys() | tensors.keys()).issubset(set(self.conditioners.keys())), (
            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
            f"got {text.keys(), tensors.keys()}"
        )

        missing_inputs = set(self.conditioners.keys()) - (set(text.keys()) | set(tensors.keys()))
        if missing_inputs:
            raise RuntimeError(f"Some conditioners did not receive an input: {missing_inputs}")
        for attribute, batch in chain(text.items(), tensors.items()):
            conditioner = self.conditioners[attribute]
            assert isinstance(conditioner, BaseConditioner)
            output[attribute] = conditioner.prepare(batch)
        return output

    def forward(self, prepared: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
        """Compute pairs of `(embedding, mask)` using the configured conditioners and the prepared representations.
        The output is for example:
        {
            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
            ...
        }

        Args:
            prepared (dict): Dict of prepared representations as returned by `prepare()`.
        """
        output = {}
        for name, inputs in prepared.items():
            condition, mask = self.conditioners[name](inputs)
            output[name] = ConditionType(condition, mask)
        return output

    def prepare_and_provide(self, inputs: tp.Sequence[ConditionAttributes]):
        """See .prepare() and .forward()."""
        prepared = self.prepare(inputs)
        return self(prepared)


class ConditionFuser(nn.Module):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
                "cross": ["description"],
            }
        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
        cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
    """
    FUSING_METHODS = ["sum", "prepend", "cross"]

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
                 cross_attention_pos_emb_scale: float = 1.0):
        super().__init__()
        assert all(
            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.cross_attention_pos_emb = cross_attention_pos_emb
        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        self.cond2fuse: tp.Dict[str, str] = {}
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse[condition] = fuse_method
                if fuse_method not in ['cross', 'sum']:
                    raise RuntimeError("only `sum` and `cross` conditionings are supported "
                                       f"for now, got {fuse_method}.")

    @property
    def has_conditions(self) -> bool:
        return bool(self.cond2fuse)

    @property
    def has_prepend(self) -> bool:
        """Is there a conditioning that needs to be prepending to the Transformer sequence."""
        return bool(self.fuse2cond['prepend'])

    def get_cross(self, conditions: ConditionTensors) -> torch.Tensor | None:
        """Return the tensor to be provided for the cross attention."""
        cross = None
        for name in self.fuse2cond['cross']:
            cond, _ = conditions[name]
            if cross is None:
                cross = cond
            else:
                cross = torch.cat([cross, cond], dim=1)

        if self.cross_attention_pos_emb and cross is not None:
            positions = torch.arange(
                cross.shape[1],
                device=cross.device
            ).view(1, -1, 1)
            pos_emb = create_sin_embedding(positions, cross.shape[-1]).to(cross.dtype)
            cross = cross + self.cross_attention_pos_emb_scale * pos_emb
        return cross

    def get_sum(self, conditions: ConditionTensors) -> torch.Tensor | None:
        """Return the tensor to be provided as an extra sum offset shared for each step."""
        sum = None
        for name in self.fuse2cond['sum']:
            cond, _ = conditions[name]
            assert cond.shape[1] == 1, cond.shape
            if sum is None:
                sum = cond
            else:
                sum = sum + cond
        return sum

    def get_prepend(self, conditions: ConditionTensors) -> torch.Tensor | None:
        """Return the tensor to be prepended to the transformer."""
        prepend = None
        for name in self.fuse2cond['prepend']:
            cond, _ = conditions[name]
            if prepend is None:
                prepend = cond
            else:
                prepend = torch.cat([cond, prepend], dim=1)
        if prepend is not None:
            sum = self.get_sum(conditions)
            if sum is not None:
                prepend = prepend + sum
        return prepend

```
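
A minimal sketch of how the pieces above fit together, covering condition dropout and cross-attention fusing; it is not part of the repository source, and the attribute names, shapes, and import path are assumptions.

```py
import torch

# Assumed import path; these names are re-exported by conditioners/__init__.py above.
from liquid_audio.moshi.conditioners import (
    ConditionAttributes,
    ConditionFuser,
    ConditionType,
    TensorCondition,
    dropout_all_conditions,
)

# One batch item with a text attribute and a tensor attribute (toy shapes).
attrs = ConditionAttributes(
    text={"description": "a calm speaking voice"},
    tensor={"speaker_emb": TensorCondition.from_tensor(torch.randn(1, 4, 16))},
)

# Drop every condition, e.g. for the unconditional branch of classifier-free guidance.
dropped = dropout_all_conditions([attrs])
assert dropped[0].text["description"] is None            # text conditions become None
assert not dropped[0].tensor["speaker_emb"].mask.any()   # tensor conditions are zeroed out

# Fuse an already-embedded condition as a cross-attention input.
fuser = ConditionFuser({"cross": ["speaker_emb"], "prepend": [], "sum": []})
conditions = {
    "speaker_emb": ConditionType(torch.randn(2, 4, 32), torch.ones(2, 4, dtype=torch.bool)),
}
cross = fuser.get_cross(conditions)                      # [2, 4, 32] tensor fed to cross-attention
```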

## /src/liquid_audio/moshi/conditioners/tensors.py

```py path="/src/liquid_audio/moshi/conditioners/tensors.py" 
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .base import _BaseTensorConditioner, TensorCondition, ConditionType


class TensorConditioner(_BaseTensorConditioner[TensorCondition]):
    """Does basically nothing.
    """

    def prepare(self, tensor: TensorCondition) -> TensorCondition:
        device = next(iter(self.parameters())).device
        return TensorCondition(tensor.tensor.to(device=device), tensor.mask.to(device=device))

    def _get_condition(self, inputs: TensorCondition) -> ConditionType:
        return ConditionType(inputs.tensor, inputs.mask)

```
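
A minimal end-to-end sketch tying the passthrough `TensorConditioner` to the `ConditionProvider` from base.py; it is not part of the repository source, and the dimensions and import paths are assumptions mirroring the file layout.

```py
import torch

# Assumed import paths, mirroring the repository layout.
from liquid_audio.moshi.conditioners import ConditionAttributes, ConditionProvider, TensorCondition
from liquid_audio.moshi.conditioners.tensors import TensorConditioner

# A single tensor conditioner: projects 16-dim inputs to 32-dim conditioning vectors.
conditioner = TensorConditioner(dim=16, output_dim=32, device="cpu")
provider = ConditionProvider({"speaker_emb": conditioner}, device="cpu")

# One ConditionAttributes per batch item; each tensor condition is [1, T, dim].
attrs = [
    ConditionAttributes(tensor={"speaker_emb": TensorCondition.from_tensor(torch.randn(1, 3, 16))}),
    ConditionAttributes(tensor={"speaker_emb": TensorCondition.from_tensor(torch.randn(1, 5, 16))}),
]

prepared = provider.prepare(attrs)    # collate and move to device, before any heavy GPU work
conditions = provider(prepared)       # name -> ConditionType(embedding, mask)
cond, mask = conditions["speaker_emb"]
print(cond.shape, mask.shape)         # torch.Size([2, 5, 32]) torch.Size([2, 5])
```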

## /src/liquid_audio/moshi/models/__init__.py

```py path="/src/liquid_audio/moshi/models/__init__.py" 
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Models for the Moshi compression model.
"""

# flake8: noqa
from .compression import (
    CompressionModel,
    MimiModel,
)
from .lm import LMModel, LMGen
from .loaders import get_mimi, get_moshi_lm

```

## /src/liquid_audio/moshi/modules/__init__.py

```py path="/src/liquid_audio/moshi/modules/__init__.py" 
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Modules used for building the models."""

# flake8: noqa
from .conv import (
    NormConv1d,
    NormConvTranspose1d,
    StreamingConv1d,
    StreamingConvTranspose1d,
    pad_for_conv1d,
    pad1d,
    unpad1d,
)
from .seanet import SEANetEncoder, SEANetDecoder
from .transformer import StreamingTransformer

```

## /src/liquid_audio/moshi/modules/resample.py

```py path="/src/liquid_audio/moshi/modules/resample.py" 
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from einops import rearrange
import torch
from torch import nn

from .conv import StreamingConv1d, StreamingConvTranspose1d


class ConvDownsample1d(nn.Module):
    """
    Downsampling by some integer amount `stride` using convolutions
    with a kernel size of twice the stride.
    If `causal` is True, the output uses a causal convolution.
    """

    def __init__(
        self,
        stride: int,
        dimension: tp.Optional[int] = None,
        causal: bool = False,
        learnt: bool = False,
        channel_wise: bool = False,
    ):
        super().__init__()
        self.learnt = learnt
        self.channel_wise = channel_wise
        groups = 1
        if learnt:
            assert dimension is not None, "Dimension required for learnt convolutions."
            in_channels = dimension
            out_channels = dimension
            if channel_wise:
                groups = dimension
        else:
            in_channels = 1
            out_channels = 1

        self.conv = StreamingConv1d(
            in_channels,
            out_channels,
            kernel_size=2 * stride,
            stride=stride,
            causal=causal,
            groups=groups,
            bias=False,
            pad_mode="replicate",
        )
        if not learnt:
            actual_conv = self.conv.conv.conv
            actual_conv.weight.requires_grad_(False)
            actual_conv.weight.data.fill_(1.0 / (2 * stride))

    def forward(self, x: torch.Tensor):
        batch_size = len(x)
        if not self.learnt:
            x = rearrange(x, "b c t -> (b c) () t")
        y = self.conv(x)
        if not self.learnt:
            y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
        return y


class ConvTrUpsample1d(nn.Module):
    """
    Upsample by some integer amount `stride` using transposed convolutions.
    """

    def __init__(
        self,
        stride: int,
        dimension: tp.Optional[int] = None,
        causal: bool = False,
        learnt: bool = False,
        channel_wise: bool = False,
    ):
        super().__init__()
        self.learnt = learnt
        self.channel_wise = channel_wise
        groups = 1
        if learnt:
            assert dimension is not None, "Dimension required for learnt convolutions."
            in_channels = dimension
            out_channels = dimension
            if channel_wise:
                groups = dimension
        else:
            in_channels = 1
            out_channels = 1

        self.convtr = StreamingConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size=2 * stride,
            stride=stride,
            causal=causal,
            groups=groups,
            bias=False,
        )
        if not learnt:
            actual_convtr = self.convtr.convtr.convtr
            actual_convtr.weight.requires_grad_(False)
            actual_convtr.weight.data.fill_(1.0)

    def forward(self, x: torch.Tensor):
        batch_size = len(x)
        if not self.learnt:
            x = rearrange(x, "b c t -> (b c) () t")
        y = self.convtr(x)
        if not self.learnt:
            x_for_normalization = torch.ones_like(x[:1])
            normalization = self.convtr(x_for_normalization)
            y = y / normalization
            y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
        return y

```
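
A minimal sketch of the two resampling blocks above (not part of the repository source); it assumes the package imports as `liquid_audio` and that the streaming convolution wrappers from `modules/conv.py` accept plain full-sequence `[B, C, T]` tensors outside of streaming mode.

```py
import torch

# Assumed import path, mirroring the repository layout.
from liquid_audio.moshi.modules.resample import ConvDownsample1d, ConvTrUpsample1d

# Non-learnt variants: a fixed averaging kernel for downsampling and a fixed,
# normalized transposed convolution for upsampling (see the constructors above).
down = ConvDownsample1d(stride=2, causal=True)
up = ConvTrUpsample1d(stride=2, causal=True)

x = torch.randn(1, 4, 32)   # [batch, channels, time]
y = down(x)                 # time axis shrinks by roughly the stride
z = up(y)                   # time axis grows back by roughly the stride
print(x.shape, y.shape, z.shape)
```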

