
Commit 56bcd76

Merge pull request #2952 from AI-Hypercomputer:aireen/qwen-audio

PiperOrigin-RevId: 860210039
Parents: 4bcee99 + 0084322

18 files changed: 1,646 additions & 133 deletions

pytest.ini

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ addopts =
     --ignore=tests/unit/moba_vs_reference_test.py
     --ignore=tests/unit/offline_engine_test.py
     --ignore=tests/unit/profiler_test.py
-    --ignore=tests/unit/qwen3_embedding_vs_reference_test.py
+    --ignore=tests/unit/qwen3_omni_layers_test.py
     --ignore=tests/unit/qwen3_next_vs_reference_test.py
 markers =
     tpu_only: marks tests to be run on TPUs only

src/MaxText/configs/base.yml

Lines changed: 23 additions & 1 deletion
@@ -942,7 +942,9 @@ temperature_tuning: False

 # Multimodal flags
 use_multimodal: False
+use_audio: False
 freeze_vision_encoder_params: True
+freeze_audio_encoder_params: True
 dtype_mm: "float32" # Data type for multimodal model's vision encoder
 remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options.
 image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config

@@ -980,7 +982,27 @@ temporal_patch_size_for_vit: 2
 num_position_embeddings_for_vit: 1024
 deepstack_visual_indexes_for_vit: []

-# Subslice shape in the form of "x,y,z" when using pathways (single controller).
+### Audio encoder configs (Qwen3-OmniMoe)
+d_model_for_audio: 256
+encoder_attention_heads_for_audio: 4
+encoder_ffn_dim_for_audio: 512
+encoder_layers_for_audio: 2
+attention_dropout_for_audio: 0.0
+activation_dropout_for_audio: 0.0
+activation_function_for_audio: "gelu"
+num_mel_bins_for_audio: 128
+max_source_positions_for_audio: 1500
+scale_embedding_for_audio: True
+n_window_for_audio: 50
+n_window_infer_for_audio: 800
+conv_chunksize_for_audio: 500
+downsample_hidden_size_for_audio: 256
+output_dim_for_audio: 512
+num_conv_layers_for_audio: 3
+max_timescale_for_audio: 10000.0
+max_sample_len_for_audio: 10000
+
+# Subslice shape in the form of "x,y,z" when using pathways (single controller).
 # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
 subslice_shape: ""

src/MaxText/configs/models/qwen3-omni-30b-a3b.yml

Lines changed: 18 additions & 0 deletions
@@ -56,3 +56,21 @@ num_position_embeddings_for_vit: 2304
 deepstack_visual_indexes_for_vit: [8, 16, 24]

 use_multimodal: true
+use_audio: true
+# Audio Encoder Configuration (need to set use_audio=true to enable)
+# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py
+d_model_for_audio: 1280
+encoder_layers_for_audio: 32
+encoder_attention_heads_for_audio: 20
+encoder_ffn_dim_for_audio: 5120
+max_source_positions_for_audio: 1500
+num_mel_bins_for_audio: 128
+downsample_hidden_size_for_audio: 480
+output_dim_for_audio: 2048
+attention_dropout_for_audio: 0.0
+n_window_for_audio: 50
+n_window_infer_for_audio: 400
+conv_chunksize_for_audio: 500
+num_conv_layers_for_audio: 3
+max_timescale_for_audio: 10000.0
+max_sample_len_for_audio: 10000
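
Model YAMLs layer on top of base.yml, so this file only restates the audio keys it changes; anything omitted (for example activation_function_for_audio) keeps the base default. A minimal sketch of the assumed override semantics:

# Assumed semantics: model-config keys win over base.yml defaults.
base = {"d_model_for_audio": 256, "encoder_layers_for_audio": 2,
        "activation_function_for_audio": "gelu"}
model = {"d_model_for_audio": 1280, "encoder_layers_for_audio": 32}
effective = {**base, **model}
assert effective["d_model_for_audio"] == 1280                 # from the model YAML
assert effective["activation_function_for_audio"] == "gelu"   # inherited default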

src/MaxText/configs/types.py

Lines changed: 26 additions & 0 deletions
@@ -1360,6 +1360,8 @@ class MultimodalGeneral(BaseModel):

   use_multimodal: bool = Field(False, description="Enable multimodal capabilities.")
   freeze_vision_encoder_params: bool = Field(True, description="Freeze the parameters of the vision encoder.")
+  freeze_audio_encoder_params: bool = Field(True, description="Freeze the parameters of the audio encoder.")
+  use_audio: bool = Field(False, description="Enable audio encoder for multimodal models.")
   image_size_for_vit: int = Field(896, description="Input image size for the Vision Transformer.")
   image_path: PathStr = Field("", description="Path to an image for decoding.")
   image_placeholder: str = Field("<|image|>", description="Placeholder string for images in text prompts.")

@@ -1408,6 +1410,29 @@ class VisionProjector(BaseModel):
   projector_dropout_for_vit: float = Field(0.0, description="Dropout rate for the vision projector.")


+class AudioEncoder(BaseModel):
+  """Configuration for the Audio Encoder in a multimodal model."""
+
+  d_model_for_audio: int = Field(256, description="Model dimension for the audio encoder.")
+  encoder_attention_heads_for_audio: int = Field(4, description="Number of attention heads in the audio encoder.")
+  encoder_ffn_dim_for_audio: int = Field(512, description="Feed-forward network dimension for the audio encoder.")
+  encoder_layers_for_audio: int = Field(2, description="Number of encoder layers for audio.")
+  attention_dropout_for_audio: float = Field(0.0, description="Attention dropout rate for audio encoder.")
+  activation_dropout_for_audio: float = Field(0.0, description="Activation dropout rate for audio encoder.")
+  activation_function_for_audio: str = Field("gelu", description="Activation function for audio encoder.")
+  num_mel_bins_for_audio: int = Field(128, description="Number of mel-frequency bins for audio input.")
+  max_source_positions_for_audio: int = Field(1500, description="Maximum source positions for audio encoder.")
+  scale_embedding_for_audio: bool = Field(True, description="Whether to scale embeddings in audio encoder.")
+  n_window_for_audio: int = Field(50, description="Window size for audio processing.")
+  n_window_infer_for_audio: int = Field(800, description="Window size for audio inference.")
+  conv_chunksize_for_audio: int = Field(500, description="Chunk size for convolutional layers in audio encoder.")
+  downsample_hidden_size_for_audio: int = Field(256, description="Hidden size for downsampling in audio encoder.")
+  output_dim_for_audio: int = Field(512, description="Output dimension for audio encoder.")
+  num_conv_layers_for_audio: int = Field(3, description="Number of convolutional layers in audio encoder.")
+  max_timescale_for_audio: float = Field(10000.0, description="Maximum timescale for audio positional encoding.")
+  max_sample_len_for_audio: int = Field(10000, description="Maximum sample length for audio input.")
+
+
 class Debug(BaseModel):
   """Configuration for debugging options."""

@@ -1722,6 +1747,7 @@ class MaxTextConfig(
     MultimodalGeneral,
     VisionTower,
     VisionProjector,
+    AudioEncoder,
     # Derived
     DerivedValues,
 ):
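
MaxTextConfig is assembled by multiple inheritance over many small Pydantic models, so registering AudioEncoder as a base is all it takes for the new *_for_audio keys to be parsed and validated. A stripped-down sketch of that composition pattern (field set reduced for brevity):

from pydantic import BaseModel, Field

class MultimodalGeneral(BaseModel):
  use_audio: bool = Field(False, description="Enable audio encoder for multimodal models.")

class AudioEncoder(BaseModel):
  d_model_for_audio: int = Field(256, description="Model dimension for the audio encoder.")
  output_dim_for_audio: int = Field(512, description="Output dimension for audio encoder.")

class MaxTextConfig(MultimodalGeneral, AudioEncoder):
  pass

cfg = MaxTextConfig(use_audio=True, d_model_for_audio=1280)
assert cfg.output_dim_for_audio == 512  # default supplied by the AudioEncoder mixin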

src/MaxText/decode.py

Lines changed: 2 additions & 0 deletions
@@ -152,6 +152,8 @@ def main(argv: Sequence[str]) -> None:
         padded_tokens=tokens,
         images=processor_outputs.pixel_values if config.use_multimodal else None,
         image_masks=processor_outputs.pixel_mask if config.use_multimodal and "llama4" in config.model_name else None,
+        audio_values=processor_outputs.audio_values if config.use_audio else None,
+        audio_masks=processor_outputs.audio_mask if config.use_audio else None,
         true_length=true_length,
         rng=rng_prefill,
         slot=i,
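
For orientation, a hedged sketch of the processor hand-off the call above relies on; the field names follow the diff, but the container and shapes are illustrative assumptions, not taken from this PR:

from dataclasses import dataclass
import jax.numpy as jnp

@dataclass
class ProcessorOutputs:  # hypothetical stand-in for processor_outputs
  pixel_values: jnp.ndarray | None = None  # e.g. (num_images, H, W, C)
  pixel_mask: jnp.ndarray | None = None    # Llama4-style image mask
  audio_values: jnp.ndarray | None = None  # e.g. (batch, num_mel_bins, num_frames)
  audio_mask: jnp.ndarray | None = None    # True for real frames, False for padding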

src/MaxText/layers/decoders.py

Lines changed: 21 additions & 3 deletions
@@ -563,6 +563,8 @@ def _apply_embedding(
       image_embeddings=None,
       bidirectional_mask=None,
       image_masks=None,
+      audio_embeddings=None,
+      audio_masks=None,
   ):
     """Applies token and positional embeddings to the input tokens."""
     cfg = self.config

@@ -581,19 +583,30 @@ def _apply_embedding(
       ]:
         y = multimodal_utils.merge_mm_embeddings(
             text_embeddings=y,
-            vision_embeddings=image_embeddings,
+            multimodal_embeddings=image_embeddings,
             mask=bidirectional_mask,
-            image_masks=image_masks,
+            token_masks=image_masks,
         )
       # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed
       else:
         raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")

+    if audio_embeddings is not None and cfg.use_audio:
+      if cfg.model_name in ["qwen3-omni-30b-a3b"]:
+        y = multimodal_utils.merge_mm_embeddings(
+            text_embeddings=y,
+            multimodal_embeddings=audio_embeddings,
+            mask=audio_masks,
+            token_masks=None,
+        )
+      else:
+        raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}")
+
     y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
     y = y.astype(cfg.dtype)

     if cfg.use_untrainable_positional_embedding:
-      y = positional_embedding_as_linen(embedding_dims=cfg.base_emb_dim)(y, decoder_positions)
+      y += positional_embedding_as_linen(embedding_dims=cfg.base_emb_dim)(y.shape[1], decoder_positions)

     if cfg.trainable_position_size > 0:
       y += embed_as_linen(

@@ -673,6 +686,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi

     return logits

+  # TODO(aireenmei, Hengtaoguo): consolidate all multimodal inputs into a class as input to the encoder
   @nn.compact
   def __call__(
       self,

@@ -690,6 +704,8 @@
       image_masks: None | jnp.ndarray = None,
       kv_caches: list[jax.Array] | None = None,
       attention_metadata=None,
+      audio_embeddings: None | jnp.ndarray = None,
+      audio_masks: None | jnp.ndarray = None,
   ):
     cfg = self.config
     mesh = self.mesh

@@ -705,6 +721,8 @@
         image_embeddings,
         bidirectional_mask,
         image_masks,
+        audio_embeddings,
+        audio_masks,
     )

     policy = self.get_remat_policy()
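
merge_mm_embeddings — now taking generic multimodal_embeddings/token_masks arguments instead of vision-specific ones — splices encoder outputs into the token-embedding sequence at the positions flagged by a mask. A simplified, self-contained sketch of that mask-driven merge (not MaxText's actual implementation, which also handles token masks and packing):

import jax.numpy as jnp

def merge_by_mask(text_embeddings, mm_embeddings, mask):
  """Place the i-th multimodal embedding at the i-th True position of mask.

  text_embeddings: (batch, seq, dim); mm_embeddings: (batch, num_mm, dim);
  mask: (batch, seq) bool with at most num_mm True entries per row.
  """
  mm_slot = jnp.cumsum(mask, axis=1) - 1  # running index among True positions
  mm_slot = jnp.clip(mm_slot, 0, mm_embeddings.shape[1] - 1)
  gathered = jnp.take_along_axis(mm_embeddings, mm_slot[..., None], axis=1)
  return jnp.where(mask[..., None], gathered, text_embeddings)

text = jnp.zeros((1, 5, 2))
audio = jnp.ones((1, 2, 2))
mask = jnp.array([[False, True, True, False, False]])
merged = merge_by_mask(text, audio, mask)  # positions 1-2 now hold audio embeddings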

src/MaxText/layers/embeddings.py

Lines changed: 78 additions & 14 deletions
@@ -898,50 +898,114 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     return output


-def positional_embedding_as_linen(*, embedding_dims: int, max_wavelength: int = _MAX_WAVELENGTH):
+def positional_embedding_as_linen(
+    *,
+    embedding_dims: int,
+    max_wavelength: int = _MAX_WAVELENGTH,
+    cast_as_fprop_dtype: bool = False,
+    fprop_dtype: DType = jnp.bfloat16,
+):
   """Initializes the PositionalEmbedding module and returns it as a Linen module.

   Args:
     embedding_dims: The dimension of the embeddings.
     max_wavelength: The maximum wavelength for the sinusoidal positional embeddings.
+    cast_as_fprop_dtype: Whether to cast the output to fprop_dtype.
+    fprop_dtype: The dtype of the output when cast_as_fprop_dtype is True.
   """
   return nnx_wrappers.to_linen(
       PositionalEmbedding,
       embedding_dims=embedding_dims,
       max_wavelength=max_wavelength,
+      cast_as_fprop_dtype=cast_as_fprop_dtype,
+      fprop_dtype=fprop_dtype,
       metadata_fn=variable_to_logically_partitioned,
   )


 @dataclasses.dataclass(repr=False)
 class PositionalEmbedding(nnx.Module):
-  """A layer that adds sinusoidal positional embeddings to the input."""
+  """Sinusoidal positional embeddings supporting both uniform and per-batch positions.
+
+  This module computes sinusoidal positional embeddings and supports two use cases:
+
+  1. Uniform positions across batch: All batch elements share the same position sequence.
+     Pass position as a 1D array (seq_len,) or None for sequential [0,1,2,...].
+     Returns (seq_len, embedding_dims); the caller broadcasts to batch.
+     Example: pos_emb = layer(seq_len)              # Sequential positions
+              pos_emb = layer(seq_len, position_1d) # Custom 1D positions
+
+  2. Per-batch positions (packed sequences): Each batch element has different positions.
+     Pass position as a 2D array (batch, seq_len).
+     Returns (batch, seq_len, embedding_dims).
+     Example: pos_emb = layer(seq_len, position_2d)
+
+  The uniform case is also more efficient, since sin/cos are computed once and
+  broadcast rather than per batch element.
+  """

   #: The dimension of the embeddings.
   embedding_dims: int
   #: The maximum wavelength for the sinusoidal positional embeddings.
   max_wavelength: int = _MAX_WAVELENGTH
-
+  #: Whether to cast the output to fprop_dtype.
+  cast_as_fprop_dtype: bool = False
+  #: The dtype of the output when cast_as_fprop_dtype is True.
+  fprop_dtype: DType = jnp.bfloat16
   #: RNG state passed in by nnx.bridge.to_linen, not used in this module.
   rngs: nnx.Rngs = None  # Not used in PositionalEmbedding but passed in by nnx.bridge.to_linen

-  def __call__(
-      self,  # pytype: disable=signature-mismatch # overriding-parameter-count-checks
-      input_embedding: jax.Array,
-      position: jax.Array,
-  ) -> jax.Array:
+  def _compute_embeddings(self, position: Array) -> Array:
+    """Computes sinusoidal embeddings for the given positions.
+
+    Args:
+      position: Either (seq_len,) for the efficient path or (batch, seq_len) for the full path.
+
+    Returns:
+      Embeddings of shape (seq_len, embedding_dims) or (batch, seq_len, embedding_dims).
+    """
     num_timescales = self.embedding_dims // 2
     log_timescale_increment = jnp.log(float(self.max_wavelength)) / jnp.maximum(
         jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1
     )
     inv_timescales = jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
-    position = position[:, :, jnp.newaxis]
-    inv_timescales = inv_timescales[jnp.newaxis, jnp.newaxis, :]
-    scaled_time = position * inv_timescales
+
+    if position.ndim == 1:
+      # Use the same positions for the whole batch when position is (seq_len,).
+      scaled_time = position[:, jnp.newaxis] * inv_timescales[jnp.newaxis, :]
+    else:
+      # When position is (batch, seq_len).
+      position = position[:, :, jnp.newaxis]
+      inv_timescales = inv_timescales[jnp.newaxis, jnp.newaxis, :]
+      scaled_time = position * inv_timescales
+
     signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)
-    # signal = jnp.pad(signal, [[0, jnp.mod(self.embedding_dims, 2)]])
-    position_embedding = signal.astype(jnp.float32)
-    return input_embedding + position_embedding
+
+    if self.cast_as_fprop_dtype:
+      return signal.astype(self.fprop_dtype)
+    else:
+      return signal.astype(jnp.float32)
+
+  def __call__(
+      self,
+      seq_len: int,
+      position: Array | None = None,
+  ) -> Array:
+    """Computes positional embeddings.
+
+    Args:
+      seq_len: Sequence length for computing embeddings.
+      position: Optional position array. If None, uses sequential [0,1,2,...].
+        Shape can be (seq_len,) or (batch, seq_len) for packed sequences.
+
+    Returns:
+      Positional embeddings of shape (seq_len, embedding_dims) or
+      (batch, seq_len, embedding_dims) if position has a batch dimension.
+    """
+    if position is None:
+      position = jnp.arange(seq_len, dtype=jnp.float32)
+
+    return self._compute_embeddings(position)


 def llama_vision_rotary_embedding_as_linen(
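
The refactor changes PositionalEmbedding from "add embeddings to an input" to "return the embedding table", which is why the decoder now does y += layer(y.shape[1], positions). A standalone sketch of the underlying sinusoid, using the same formula as _compute_embeddings; here position[..., None] covers both the 1D and 2D paths shown above:

import jax.numpy as jnp

def sinusoidal(position, embedding_dims, max_wavelength=10_000):
  num_timescales = embedding_dims // 2
  log_inc = jnp.log(float(max_wavelength)) / max(num_timescales - 1, 1)
  inv_timescales = jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_inc)
  scaled_time = position[..., None] * inv_timescales  # broadcasts for 1D or 2D positions
  return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)

emb_1d = sinusoidal(jnp.arange(8, dtype=jnp.float32), 16)   # (8, 16): shared across batch
emb_2d = sinusoidal(jnp.zeros((2, 8), jnp.float32), 16)     # (2, 8, 16): per-batch (packed)

In the uniform case the (seq_len, embedding_dims) result broadcasts against a (batch, seq_len, embedding_dims) activation under +=, so the table is computed once per step rather than once per batch element.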

src/MaxText/layers/encoders.py

Lines changed: 53 additions & 0 deletions
@@ -76,6 +76,43 @@ def __call__(self, input_images, deterministic=False):
     return embeddings


+class AudioEncoder(nnx.Module):
+  """Audio encoder to encode audio features into soft tokens."""
+
+  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs):
+    self.config = config
+    self.mesh = mesh
+    self.rngs = rngs
+    self.encoder_name, self.projector_name = self._setup_audio_encoder_layers()
+
+  def _setup_audio_encoder_layers(self):
+    """Sets up the model-specific audio encoder layers and instantiates the NNX modules."""
+    if self.config.model_name in ["qwen3-omni-30b-a3b"]:
+      from MaxText.layers import qwen3  # pylint: disable=import-outside-toplevel
+
+      encoder_name = "Qwen3OmniAudioEncoder_0"
+      projector_name = "Qwen3OmniAudioProjector_0"
+      setattr(self, encoder_name, qwen3.Qwen3OmniAudioEncoder(config=self.config, mesh=self.mesh, rngs=self.rngs))
+      setattr(self, projector_name, qwen3.Qwen3OmniAudioProjector(config=self.config, rngs=self.rngs))
+      return encoder_name, projector_name
+    else:
+      raise ValueError(f"No AudioEncoder implemented for {self.config.model_name} yet")
+
+  def __call__(self, input_audio, deterministic=False):
+    # Audio encoder output (convs + encoder; outputs before the projector).
+    encoder = getattr(self, self.encoder_name)
+    embeddings = encoder(input_audio, deterministic=deterministic)
+
+    if self.config.freeze_audio_encoder_params:
+      embeddings = jax.lax.stop_gradient(embeddings)
+
+    # Audio projector layer.
+    projector = getattr(self, self.projector_name)
+    embeddings = projector(embeddings)
+
+    return embeddings
+
+
 def vision_encoder_as_linen(
     config: Config,
     mesh: Mesh,

@@ -90,3 +127,19 @@ def vision_encoder_as_linen(
       metadata_fn=initializers.variable_to_logically_partitioned,
   )
   return module
+
+
+def audio_encoder_as_linen(
+    config: Config,
+    mesh: Mesh,
+):
+  """Creates an AudioEncoder module."""
+  module = nnx_wrappers.to_linen(
+      AudioEncoder,
+      config=config,
+      mesh=mesh,
+      name="audio_encoder",
+      abstract_init=False,
+      metadata_fn=initializers.variable_to_logically_partitioned,
+  )
+  return module