Skip to content

Commit 7059427

Browse files
committed
fix
1 parent 9331602 commit 7059427

1 file changed

Lines changed: 22 additions & 44 deletions

File tree

src/maxdiffusion/tests/ltx2_pipeline_test.py

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -91,52 +91,30 @@ def setUp(self):
9191

9292
try:
9393
from maxdiffusion.pipelines.ltx2.ltx2_pipeline import MaxTextGemma3FeatureExtractor
94-
from MaxText import common_types
95-
# Partial mock for config if needed
96-
class MockConfig:
97-
vocab_size = 32000
98-
emb_dim = 16
99-
num_layers = 5
100-
num_heads = 2
101-
head_dim = 8
102-
mlp_dim = 32
103-
dtype = jnp.float32
104-
weights_dtype = jnp.float32
105-
weight_dtype = jnp.float32
106-
use_iota_embed = False
107-
num_decoder_layers = 5 # Match num_layers
108-
normalization_layer_epsilon = 1e-6
109-
scan_layers = False
110-
param_scan_axis = 1
111-
max_prefill_predict_length = 512
112-
per_device_batch_size = 1
113-
max_target_length = 1024
114-
rope_min_timescale = 1
115-
rope_max_timescale = 10000
116-
rope_type = "interleaved"
117-
rope_embedding_dims = 16
118-
rope_use_scale = False
119-
model_name = "gemma3-4b" # Use a valid model name
120-
base_emb_dim = 16
121-
base_num_query_heads = 2
122-
head_dim = 8
123-
num_query_heads = 2
124-
num_kv_heads = 1
125-
dropout_rate = 0.0
126-
float32_qk_product = False
127-
float32_logits = False
128-
sliding_window_size = 128
129-
attn_logits_soft_cap = 50.0
130-
use_post_attn_norm = True
131-
attention = "dot_product" # attention_kernel
132-
quantization = "" # for configure_kv_quant
133-
quantize_kvcache = False
134-
decoder_block = common_types.DecoderBlockType.GEMMA3
135-
use_chunked_prefill = False
136-
attention_type = "dot_product"
94+
from MaxText.configs import types
95+
from MaxText import pyconfig
96+
97+
# Use real MaxText Config to avoid missing attributes
98+
raw_config = types.MaxTextConfig(
99+
vocab_size=32000,
100+
model_name="gemma3-4b",
101+
base_emb_dim=16,
102+
base_num_query_heads=2,
103+
base_num_kv_heads=1,
104+
head_dim=8,
105+
num_decoder_layers=5,
106+
max_prefill_predict_length=512,
107+
max_target_length=1024,
108+
per_device_batch_size=1,
109+
dtype=jnp.float32,
110+
weight_dtype=jnp.float32,
111+
rope_embedding_dims=16,
112+
# Add other overrides as needed for the test to be lightweight
113+
)
114+
config = pyconfig.HyperParameters(raw_config)
137115

138116
self.text_encoder = MaxTextGemma3FeatureExtractor(
139-
config=MockConfig(),
117+
config=config,
140118
mesh=self.mesh,
141119
quant=None,
142120
rngs=self.rng

0 commit comments

Comments (0)