 from typing import Union, Tuple, Optional
 import re
 import torch
-import torch.utils.dlpack
 from jax import dlpack
 import jax.numpy as jnp
 from flax import nnx
@@ -154,58 +153,9 @@ def __call__(self, x):
 
         return base_out + lora_out
 
-
-# -----------------------------------------------------------------------------
-# Helper: The "Discovery" Logic (Graph Transformation)
-# -----------------------------------------------------------------------------
-
-def inject_lora(
-    model: nnx.Module,
-    rank: int,
-    rngs: nnx.Rngs,
-    scale: float = 1.0,
-    target_linear: bool = True,
-    target_conv: bool = False
-):
-    """
-    Traverses the NNX model graph and replaces target layers with LoRA wrappers.
-    This modifies the 'model' object in-place.
-    """
-    for path, module in nnx.iter_graph(model):
-        # If path is like ('block', 'linear'), attr_name is 'linear'
-        # and parent is model.block
-        attr_name = path[-1]
-        if len(path) == 1:
-            parent = model
-        else:
-            parent = model
-            for key in path[:-1]:
-                parent = getattr(parent, key)
-
-        # Handle Linear Layers
-        if target_linear and isinstance(module, nnx.Linear):
-            # Do not wrap if it's already wrapped (sanity check)
-            if isinstance(parent, BaseLoRALayer):
-                continue
-
-            print(f"Injecting LoRA (Linear) at {'.'.join([str(p) for p in path])}")
-            wrapper = LoRALinear(base_layer=module, rank=rank, scale=scale, rngs=rngs)
-            setattr(parent, attr_name, wrapper)
-
-        # Handle Conv Layers
-        elif target_conv and isinstance(module, nnx.Conv):
-            if isinstance(parent, BaseLoRALayer):
-                continue
-
-            print(f"Injecting LoRA (Conv) at {'.'.join([str(p) for p in path])}")
-            wrapper = LoRAConv(base_layer=module, rank=rank, scale=scale, rngs=rngs)
-            setattr(parent, attr_name, wrapper)
-
-    return model
-
 def _to_jax_array(v):
     if isinstance(v, torch.Tensor):
-        return dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(v))
+        return dlpack.from_dlpack(v)
     return jnp.array(v)
 
 def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):