@@ -119,27 +119,6 @@ def __init__(
         )
 
     def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
-        # OPTIMIZATION: Spatial Partitioning during execution
-        if self.mesh is not None and "context" in self.mesh.axis_names:
-            height = x.shape[2]
-            width = x.shape[3]
-            num_context_devices = self.mesh.shape["context"]
-
-            shard_axis = "context" if (height % num_context_devices == 0) else None
-            shard_width_axis = None
-            if shard_axis is None and width % num_context_devices == 0:
-                shard_width_axis = "context"
-
-            x = jax.lax.with_sharding_constraint(
-                x, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
-            )
-
-            # Debug logging
-            if shard_axis or shard_width_axis:
-                jax.debug.print(
-                    "Spatial sharding applied: height_axis={}, width_axis={} for shape {}",
-                    shard_axis, shard_width_axis, x.shape
-                )
 
         current_padding = list(self._causal_padding)  # Mutable copy
         padding_needed = self._depth_padding_before
@@ -165,6 +144,20 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
             x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
         else:
             x_padded = x
+
+        if self.mesh is not None and "context" in self.mesh.axis_names:
+            height = x_padded.shape[2]
+            width = x_padded.shape[3]
+            num_context_devices = self.mesh.shape["context"]
+
+            shard_axis = "context" if (height % num_context_devices == 0) else None
+            shard_width_axis = None
+            if shard_axis is None and width % num_context_devices == 0:
+                shard_width_axis = "context"
+
+            x_padded = jax.lax.with_sharding_constraint(
+                x_padded, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
+            )
 
         out = self.conv(x_padded)
         return out
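
Note on the change: moving the `with_sharding_constraint` below the padding step means the `height % num_context_devices` and `width % num_context_devices` divisibility checks now run against the *padded* shape, so an input that only becomes evenly divisible after padding still gets sharded across the "context" axis. Below is a minimal, self-contained sketch of the same constrain-if-divisible logic, assuming a mesh with "data" and "context" axes and a 5D layout whose dims 2 and 3 are height and width; the `shard_spatial` helper and the mesh setup are illustrative, not part of this PR:

```python
# Illustrative sketch only -- mirrors the divisibility-then-constrain logic
# in the diff above; shard_spatial and the mesh layout here are assumptions.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_spatial(x_padded: jax.Array, mesh: Mesh) -> jax.Array:
    """Shard height (dim 2) or width (dim 3) over the "context" mesh axis,
    but only when that dim divides evenly across the context devices."""
    if mesh is None or "context" not in mesh.axis_names:
        return x_padded
    num_context_devices = mesh.shape["context"]
    height, width = x_padded.shape[2], x_padded.shape[3]
    shard_axis = "context" if height % num_context_devices == 0 else None
    shard_width_axis = (
        "context" if shard_axis is None and width % num_context_devices == 0 else None
    )
    spec = PartitionSpec("data", None, shard_axis, shard_width_axis, None)
    return jax.lax.with_sharding_constraint(x_padded, NamedSharding(mesh, spec))


# Example: a 1-way "data" x N-way "context" mesh over all local devices.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("data", "context"))


@jax.jit
def f(x):
    # If height (8) divides the context size, dim 2 is sharded; otherwise
    # width is tried, and if neither divides, the constraint is a no-op.
    return shard_spatial(x, mesh) * 2.0


x = jnp.ones((1, 4, 8, 8, 16))
print(f(x).sharding)
```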