Skip to content

Commit b4fd8ac

Browse files
Merge pull request #3096 from AI-Hypercomputer:qwen-checkpoint
PiperOrigin-RevId: 868817033
2 parents ef90f2d + 77b061e commit b4fd8ac

4 files changed

Lines changed: 364 additions & 10 deletions

File tree

src/MaxText/layers/encoders.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,14 @@ def _setup_vision_encoder_layers(self):
6464
def __call__(self, input_images, deterministic=False):
6565
# vision encoder output, frozen params in many cases
6666
encoder = getattr(self, self.encoder_name)
67-
embeddings = encoder(input_images, deterministic=deterministic)
67+
encoder_output = encoder(input_images, deterministic=deterministic)
68+
69+
deep_feats = None
70+
if isinstance(encoder_output, tuple):
71+
embeddings = encoder_output[0]
72+
deep_feats = encoder_output[1]
73+
else:
74+
embeddings = encoder_output
6875

6976
if self.config.freeze_vision_encoder_params:
7077
embeddings = jax.lax.stop_gradient(embeddings)
@@ -73,7 +80,7 @@ def __call__(self, input_images, deterministic=False):
7380
projector = getattr(self, self.projector_name)
7481
embeddings = projector(embeddings)
7582

76-
return embeddings
83+
return embeddings, deep_feats
7784

7885

7986
class AudioEncoder(nnx.Module):

src/MaxText/layers/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ def __call__(
154154
audio_embeddings = None
155155

156156
if self.config.use_multimodal and encoder_images is not None:
157-
image_embeddings = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
157+
# qwen3-omni-30b-a3b returns deep features from the vision encoder.
158+
image_embeddings, _ = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
158159
bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)
159160

160161
if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
@@ -459,7 +460,7 @@ def __call__(
459460
bidirectional_mask = None
460461
image_embeddings = None
461462
if self.config.use_multimodal and encoder_images is not None:
462-
image_embeddings = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
463+
image_embeddings, _ = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
463464
bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)
464465

465466
audio_embeddings = None

0 commit comments

Comments
 (0)