Skip to content

Clean up xcodec addition. #40271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,6 @@
title: UL2
- local: model_doc/umt5
title: UMT5
- local: model_doc/xcodec
title: X-CODEC
- local: model_doc/xmod
title: X-MOD
- local: model_doc/xglm
Expand Down Expand Up @@ -945,6 +943,8 @@
title: WavLM
- local: model_doc/whisper
title: Whisper
- local: model_doc/xcodec
title: X-Codec
- local: model_doc/xls_r
title: XLS-R
- local: model_doc/xlsr_wav2vec2
Expand Down
29 changes: 19 additions & 10 deletions docs/source/en/model_doc/xcodec.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
Expand All @@ -13,6 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.

-->
*This model was released on 2024-08-30 and added to Hugging Face Transformers on 2025-08-15.*

# X-Codec

Expand All @@ -22,23 +23,28 @@ rendered properly in your Markdown viewer.

## Overview

The X-Codec model was proposed in [Codec Does Matter: Exploring the Semantic Shortcoming of Codec for Audio Language Model](https://arxiv.org/abs/2408.17175) by Zhen Ye, Peiwen Sun, Jiahe Lei, Hongzhan Lin, Xu Tan, Zheqi Dai, Qiuqiang Kong, Jianyi Chen, Jiahao Pan, Qifeng Liu, Yike Guo, Wei Xue
The X-Codec model was proposed in [Codec Does Matter: Exploring the Semantic Shortcoming of Codec for Audio Language Model](https://huggingface.co/papers/2408.17175) by Zhen Ye, Peiwen Sun, Jiahe Lei, Hongzhan Lin, Xu Tan, Zheqi Dai, Qiuqiang Kong, Jianyi Chen, Jiahao Pan, Qifeng Liu, Yike Guo, Wei Xue.

The X-Codec model is a neural audio codec that integrates semantic information from self-supervised models (e.g., HuBERT) alongside traditional acoustic information. This enables :
The X-Codec model is a neural audio codec that integrates semantic information from self-supervised models (e.g., HuBERT) alongside traditional acoustic information. This enables:

- **Music continuation** : Better modeling of musical semantics yields more coherent continuations.
- **Text-to-Sound Synthesis** : X-Codec captures semantic alignment between text prompts and generated audio.
- **Music continuation**: Better modeling of musical semantics yields more coherent continuations.
- **Text-to-Sound Synthesis**: X-Codec captures semantic alignment between text prompts and generated audio.
- **Semantic aware audio tokenization**: X-Codec is used as an audio tokenizer in the YuE lyrics to song generation model.

The abstract of the paper states the following:

*Recent advancements in audio generation have been significantly propelled by the capabilities of Large Language Models (LLMs). The existing research on audio LLM has primarily focused on enhancing the architecture and scale of audio language models, as well as leveraging larger datasets, and generally, acoustic codecs, such as EnCodec, are used for audio tokenization. However, these codecs were originally designed for audio compression, which may lead to suboptimal performance in the context of audio LLM. Our research aims to address the shortcomings of current audio LLM codecs, particularly their challenges in maintaining semantic integrity in generated audio. For instance, existing methods like VALL-E, which condition acoustic token generation on text transcriptions, often suffer from content inaccuracies and elevated word error rates (WER) due to semantic misinterpretations of acoustic tokens, resulting in word skipping and errors. To overcome these issues, we propose a straightforward yet effective approach called X-Codec. X-Codec incorporates semantic features from a pre-trained semantic encoder before the Residual Vector Quantization (RVQ) stage and introduces a semantic reconstruction loss after RVQ. By enhancing the semantic ability of the codec, X-Codec significantly reduces WER in speech synthesis tasks and extends these benefits to non-speech applications, including music and sound generation. Our experiments in text-to-speech, music continuation, and text-to-sound tasks demonstrate that integrating semantic information substantially improves the overall performance of language models in audio generation.*

Demos can be found in this [post](https://x-codec-audio.github.io/).
Model cards:
- [xcodec-hubert-librispeech](https://huggingface.co/hf-audio/xcodec-hubert-librispeech) (for speech)
- [xcodec-wavlm-mls](https://huggingface.co/hf-audio/xcodec-wavlm-mls) (for speech)
- [xcodec-wavlm-more-data](https://huggingface.co/hf-audio/xcodec-wavlm-more-data) (for speech)
- [xcodec-hubert-general](https://huggingface.co/hf-audio/xcodec-hubert-general) (for general audio)
- [xcodec-hubert-general-balanced](https://huggingface.co/hf-audio/xcodec-hubert-general-balanced) (for general audio)

This model was contributed by [Manal El Aidouni](https://huggingface.co/Manel). The original code can be found [here](https://github.com/zhenye234/xcodec) and original checkpoints for the five different models [here](https://github.com/zhenye234/xcodec?tab=readme-ov-file#available-models).

This model was contributed by [Manal El Aidouni](https://huggingface.co/Manel). The original code can be found [here](https://github.com/zhenye234/xcodec) and original checkpoint [here](https://huggingface.co/ZhenYe234/xcodec/blob/main/xcodec_speech_hubert_librispeech.pth).

Demos can be found on this [page](https://x-codec-audio.github.io/).


## Usage example
Expand All @@ -51,13 +57,16 @@ from transformers import XcodecModel, AutoFeatureExtractor
dummy_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# load model and feature extractor
model = XcodecModel.from_pretrained("Manel/X-Codec")
feature_extractor = AutoFeatureExtractor.from_pretrained("Manel/X-Codec")
model_id = "hf-audio/xcodec-hubert-librispeech"
model = XcodecModel.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

# load audio sample
dummy_dataset = dummy_dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
audio_sample = dummy_dataset[-1]["audio"]["array"]
inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")

# encode and decode
encoder_outputs = model.encode(inputs["input_values"])
decoder_outputs = model.decode(encoder_outputs.audio_codes)
audio_values = decoder_outputs.audio_values
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/feature_extraction_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
("wavlm", "Wav2Vec2FeatureExtractor"),
("whisper", "WhisperFeatureExtractor"),
("xclip", "CLIPFeatureExtractor"),
("xcodec", "EncodecFeatureExtractor"),
("xcodec", "DacFeatureExtractor"),
("yolos", "YolosFeatureExtractor"),
]
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/xcodec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
100 changes: 50 additions & 50 deletions src/transformers/models/xcodec/configuration_xcodec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@

import numpy as np

from transformers import DacConfig, HubertConfig
from transformers import AutoConfig, DacConfig, HubertConfig, WavLMConfig

from ...configuration_utils import PretrainedConfig
from ...utils import logging
Expand All @@ -41,14 +41,8 @@ class XcodecConfig(PretrainedConfig):
Args:
target_bandwidths (`List[float]`, *optional*, defaults to `[0.5, 1, 1.5, 2, 4]`):
The range of different bandwidths (in kbps) the model can encode audio with.
audio_channels (`int`, *optional*, defaults to 1):
Number of channels in the audio data. Either 1 for mono or 2 for stereo.
sample_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio waveform should be digitalized, in hertz (Hz).
input_channels (`int`, *optional*, defaults to 768):
Number of channels of the input to the first convolution in the semantic encoder.
encoder_channels (`int`, *optional*, defaults to 768):
Number of hidden channels in each semantic encoder block.
kernel_size (`int`, *optional*, defaults to 3):
Kernel size for the initial semantic convolution.
channel_ratios (`List[float]`, *optional*, defaults to `[1, 1]`):
Expand All @@ -59,38 +53,26 @@ class XcodecConfig(PretrainedConfig):
Dilation factors for the residual units in semantic blocks.
unit_kernel_size (`int`, *optional*, defaults to 3):
Kernel size inside each ResidualUnit in semantic blocks.
decoder_channels (`int`, *optional*, defaults to 768):
Number of hidden channels in each semantic decoder block.
output_channels (`int`, *optional*, defaults to 768):
Number of output channels in the semantic decoder.
codebook_size (`int`, *optional*, defaults to 1024):
Number of entries in each residual quantizer’s codebook.
num_quantizers (`int`, *optional*, defaults to 8):
Number of sequential quantizers (codebooks) in the RVQ stack.
codebook_dim (`int`, *optional*, defaults to 1024):
Dimensionality of each codebook vector.
Number of entries in each residual quantizer's codebook.
codebook_dim (`int`, *optional*):
Dimensionality of each codebook vector. Defaults to sum of hidden size of acoustic and semantic models.
initializer_range (`float`, *optional*, defaults to 0.02):
Standard deviation of the truncated normal initializer for all weight matrices.
hidden_dim (`int`, *optional*, defaults to 1024):
Dimensionality of the joint acoustic+semantic FC layer.
intermediate_dim (`int`, *optional*, defaults to 768):
Dimensionality of the next FC layer in the decoder path.
output_dim (`int`, *optional*, defaults to 256):
Dimensionality of the final FC layer before feeding into the acoustic decoder.
acoustic_model_config (`Union[Dict, DacConfig]`, *optional*):
An instance of the configuration for the acoustic (DAC) model.
semantic_model_config (`Union[Dict, HubertConfig]`, *optional*):
semantic_model_config (`Union[Dict, HubertConfig, WavLMConfig]`, *optional*):
An instance of the configuration object for the semantic (HuBERT) model.

Example:

```python
>>> from transformers import XcodecModel, XcodecConfig

>>> # Initializing a " " style configuration
>>> # Initializing configuration
>>> configuration = XcodecConfig()

>>> # Initializing a model (with random weights) from the " " style configuration
>>> # Initializing a model (with random weights) from the configuration
>>> model = XcodecModel(configuration)

>>> # Accessing the model configuration
Expand All @@ -101,30 +83,21 @@ class XcodecConfig(PretrainedConfig):

sub_configs = {
"acoustic_model_config": DacConfig,
"semantic_model_config": HubertConfig,
"semantic_model_config": AutoConfig,
}

def __init__(
self,
target_bandwidths: Optional[list[float]] = None,
audio_channels: int = 1,
sample_rate: int = 16000,
input_channels: int = 768,
encoder_channels: int = 768,
kernel_size: int = 3,
channel_ratios: list[float] = [1, 1],
strides: list[int] = [1, 1],
block_dilations: list[int] = [1, 1],
unit_kernel_size: int = 3,
decoder_channels: int = 768,
output_channels: int = 768,
codebook_size: int = 1024,
num_quantizers: int = 8,
codebook_dim: int = 1024,
codebook_dim: Optional[int] = None,
initializer_range: float = 0.02,
hidden_dim: int = 1024,
intermediate_dim: int = 768,
output_dim: int = 256,
acoustic_model_config: Union[dict, DacConfig] = None,
semantic_model_config: Union[dict, HubertConfig] = None,
**kwargs,
Expand All @@ -134,6 +107,8 @@ def __init__(
if acoustic_model_config is None:
self.acoustic_model_config = DacConfig(
encoder_hidden_size=64,
# NOTE: original DAC uses [2, 4, 8, 8] `downsampling ratios`, namely reverse of `upsampling_ratios`
# (not sure if intentional by Xcodec but we keep it)
downsampling_ratios=[8, 5, 4, 2],
decoder_hidden_size=1024,
upsampling_ratios=[8, 5, 4, 2],
Expand All @@ -143,44 +118,69 @@ def __init__(
self.acoustic_model_config = DacConfig(**acoustic_model_config)
elif isinstance(acoustic_model_config, DacConfig):
self.acoustic_model_config = acoustic_model_config
else:
raise ValueError(
f"acoustic_model_config must be a dict or DacConfig instance, but got {type(acoustic_model_config)}"
)

if semantic_model_config is None:
self.semantic_model_config = HubertConfig()
elif isinstance(semantic_model_config, dict):
self.semantic_model_config = HubertConfig(**semantic_model_config)
elif isinstance(semantic_model_config, HubertConfig):
if "_name_or_path" in semantic_model_config:
# If the config is a path, load it using AutoConfig
self.semantic_model_config = AutoConfig.from_pretrained(semantic_model_config["_name_or_path"])
else:
# assume HubertConfig as probably created from scratch
logger.warning(
"Could not determine semantic model type from config architecture. Defaulting to `HubertConfig`."
)
self.semantic_model_config = HubertConfig(**semantic_model_config)
elif isinstance(semantic_model_config, WavLMConfig) or isinstance(semantic_model_config, HubertConfig):
self.semantic_model_config = semantic_model_config
else:
raise ValueError(
f"semantic_model_config must be a dict, HubertConfig, or WavLMConfig instance, but got {type(semantic_model_config)}"
)

if target_bandwidths is None:
target_bandwidths = [0.5, 1, 1.5, 2, 4]

self.target_bandwidths = target_bandwidths
self.audio_channels = audio_channels
self.sample_rate = sample_rate
self.input_channels = input_channels
self.encoder_channels = encoder_channels
self.kernel_size = kernel_size
self.channel_ratios = channel_ratios
self.strides = strides
self.block_dilations = block_dilations
self.unit_kernel_size = unit_kernel_size
self.decoder_channels = decoder_channels
self.output_channels = output_channels
self.codebook_size = codebook_size
self.num_quantizers = num_quantizers
self.codebook_dim = codebook_dim
self.initializer_range = initializer_range
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.output_dim = output_dim
if codebook_dim is None:
codebook_dim = self.acoustic_model_config.hidden_size + self.semantic_model_config.hidden_size
self.codebook_dim = codebook_dim

@property
def frame_rate(self) -> int:
return math.ceil(self.sample_rate / np.prod(self.acoustic_model_config.upsampling_ratios))
return math.ceil(self.sample_rate / self.hop_length)

@property
def semantic_hidden_size(self) -> int:
return self.semantic_model_config.hidden_size

@property
def hop_length(self) -> int:
return int(np.prod(self.acoustic_model_config.downsampling_ratios))

@property
def codebook_nbits(self) -> int:
return math.ceil(math.log2(self.codebook_size))

@property
def hidden_size(self) -> int:
return self.acoustic_model_config.hidden_size + self.semantic_model_config.hidden_size

@property
def num_quantizers(self) -> int:
return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * self.codebook_nbits))


__all__ = ["XcodecConfig"]
Loading