@@ -898,50 +898,114 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     return output
 
 
-def positional_embedding_as_linen(*, embedding_dims: int, max_wavelength: int = _MAX_WAVELENGTH):
+def positional_embedding_as_linen(
+    *,
+    embedding_dims: int,
+    max_wavelength: int = _MAX_WAVELENGTH,
+    cast_as_fprop_dtype: bool = False,
+    fprop_dtype: DType = jnp.bfloat16,
+):
   """Initializes the PositionalEmbedding module and returns it as a Linen module.
 
   Args:
     embedding_dims: The dimension of the embeddings.
     max_wavelength: The maximum wavelength for the sinusoidal positional embeddings.
+    cast_as_fprop_dtype: Whether to cast output to fprop_dtype.
+    fprop_dtype: The dtype of the output when cast_as_fprop_dtype is True.
   """
   return nnx_wrappers.to_linen(
       PositionalEmbedding,
       embedding_dims=embedding_dims,
       max_wavelength=max_wavelength,
+      cast_as_fprop_dtype=cast_as_fprop_dtype,
+      fprop_dtype=fprop_dtype,
       metadata_fn=variable_to_logically_partitioned,
   )
 
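For orientation (not part of this diff): a minimal sketch of how the wrapped Linen module could be used, assuming nnx_wrappers.to_linen returns a standard Linen module with the usual init/apply semantics; the dimensions here are illustrative.

# Illustrative sketch, not part of this diff. Assumes nnx_wrappers.to_linen
# yields a standard Linen module with the usual init/apply semantics.
import jax

pos_emb_layer = positional_embedding_as_linen(embedding_dims=64, cast_as_fprop_dtype=True)
variables = pos_emb_layer.init(jax.random.PRNGKey(0), 16)  # seq_len = 16
pos_emb = pos_emb_layer.apply(variables, 16)               # shape (16, 64), dtype bfloat16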
 
 @dataclasses.dataclass(repr=False)
 class PositionalEmbedding(nnx.Module):
-  """A layer that adds sinusoidal positional embeddings to the input."""
+  """Sinusoidal positional embeddings supporting both uniform and per-batch positions.
+
+  This module computes sinusoidal positional embeddings and supports two use cases:
+
+  1. Uniform positions across the batch: all batch elements share the same position
+     sequence. Pass position as a 1D array of shape (seq_len,), or None for the
+     sequential positions [0, 1, 2, ...]. Returns (seq_len, embedding_dims); the
+     caller broadcasts to the batch.
+     Example: pos_emb = layer(seq_len)               # Sequential positions
+              pos_emb = layer(seq_len, position_1d)  # Custom 1D positions
+
+  2. Per-batch positions (packed sequences): each batch element has its own positions.
+     Pass position as a 2D array of shape (batch, seq_len).
+     Returns (batch, seq_len, embedding_dims).
+     Example: pos_emb = layer(seq_len, position_2d)
+
+  The uniform case is also more efficient: sin/cos are computed once for the shared
+  positions and broadcast, rather than per batch element.
+  """
 
   #: The dimension of the embeddings.
   embedding_dims: int
   #: The maximum wavelength for the sinusoidal positional embeddings.
   max_wavelength: int = _MAX_WAVELENGTH
-
+  #: Whether to cast output to fprop_dtype.
+  cast_as_fprop_dtype: bool = False
+  #: The dtype of the output when cast_as_fprop_dtype is True.
+  fprop_dtype: DType = jnp.bfloat16
   #: RNG state passed in by nnx.bridge.to_linen, not used in this module.
   rngs: nnx.Rngs = None  # Not used in PositionalEmbedding but passed in by nnx.bridge.to_linen
 
-  def __call__(
-      self,  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
-      input_embedding: jax.Array,
-      position: jax.Array,
-  ) -> jax.Array:
+  def _compute_embeddings(self, position: Array) -> Array:
+    """Compute sinusoidal embeddings for the given positions.
+
+    Args:
+      position: Either (seq_len,) for the efficient path or (batch, seq_len) for the full path.
+
+    Returns:
+      Embeddings of shape (seq_len, embedding_dims) or (batch, seq_len, embedding_dims).
+    """
     num_timescales = self.embedding_dims // 2
     log_timescale_increment = jnp.log(float(self.max_wavelength)) / jnp.maximum(
         jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1
     )
     inv_timescales = jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
-    position = position[:, :, jnp.newaxis]
-    inv_timescales = inv_timescales[jnp.newaxis, jnp.newaxis, :]
-    scaled_time = position * inv_timescales
+
+    if position.ndim == 1:
+      # Use the same positions for the whole batch when position is (seq_len,).
+      scaled_time = position[:, jnp.newaxis] * inv_timescales[jnp.newaxis, :]
+    else:
+      # Per-batch path when position is (batch, seq_len).
+      position = position[:, :, jnp.newaxis]
+      inv_timescales = inv_timescales[jnp.newaxis, jnp.newaxis, :]
+      scaled_time = position * inv_timescales
+
     signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)
-    # signal = jnp.pad(signal, [[0, jnp.mod(self.embedding_dims, 2)]])
-    position_embedding = signal.astype(jnp.float32)
-    return input_embedding + position_embedding
+
+    if self.cast_as_fprop_dtype:
+      return signal.astype(self.fprop_dtype)
+    else:
+      return signal.astype(jnp.float32)
+
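A standalone check (not part of the diff) of the timescale construction above: the log-space recurrence is equivalent to the closed form max_wavelength ** (-i / (num_timescales - 1)), so the first channel has unit inverse timescale and the last has 1/max_wavelength. The concrete dimensions are illustrative.

# Standalone check of the timescale math above (illustrative only).
# inv_timescales[i] should equal max_wavelength ** (-i / (num_timescales - 1)).
import jax.numpy as jnp

embedding_dims, max_wavelength = 8, 10_000
num_timescales = embedding_dims // 2
log_increment = jnp.log(float(max_wavelength)) / jnp.maximum(num_timescales - 1, 1)
inv_timescales = jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_increment)
closed_form = float(max_wavelength) ** -(jnp.arange(num_timescales, dtype=jnp.float32) / (num_timescales - 1))
assert jnp.allclose(inv_timescales, closed_form)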
+  def __call__(
+      self,
+      seq_len: int,
+      position: Array | None = None,
+  ) -> Array:
+    """Compute positional embeddings.
+
+    Args:
+      seq_len: Sequence length for which to compute embeddings.
+      position: Optional position array. If None, uses the sequential positions
+        [0, 1, 2, ...]. Shape can be (seq_len,), or (batch, seq_len) for packed
+        sequences.
+
+    Returns:
+      Positional embeddings of shape (seq_len, embedding_dims), or
+      (batch, seq_len, embedding_dims) if position has a batch dimension.
+    """
+    if position is None:
+      position = jnp.arange(seq_len, dtype=jnp.float32)
+
+    return self._compute_embeddings(position)
 
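Both call paths from the class docstring, as a minimal sketch (not part of this diff); it assumes the dataclass can be constructed directly, as NNX dataclass modules allow, and the dimensions are illustrative.

# Illustrative sketch of both call paths (not part of this diff). Assumes the
# dataclass can be constructed directly, as NNX dataclass modules allow.
import jax.numpy as jnp

layer = PositionalEmbedding(embedding_dims=64)

# Uniform positions: one (seq_len, embedding_dims) table, broadcast by the caller.
uniform = layer(8)                          # sequential positions [0..7]
shifted = layer(8, jnp.arange(3.0, 11.0))   # custom 1D positions [3..10]

# Packed sequences: per-batch 2D positions yield (batch, seq_len, embedding_dims).
packed_pos = jnp.array([[0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0],
                        [0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 3.0, 4.0]])
packed = layer(8, packed_pos)
assert uniform.shape == (8, 64) and packed.shape == (2, 8, 64)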
 
 def llama_vision_rotary_embedding_as_linen(