
Commit 287552b

Fix

1 parent: d913afd
1 file changed: 1 addition & 51 deletions

src/maxdiffusion/models/lora_nnx.py
```diff
@@ -17,7 +17,6 @@
 from typing import Union, Tuple, Optional
 import re
 import torch
-import torch.utils.dlpack
 from jax import dlpack
 import jax.numpy as jnp
 from flax import nnx
@@ -154,58 +153,9 @@ def __call__(self, x):
 
     return base_out + lora_out
 
-
-# -----------------------------------------------------------------------------
-# Helper: The "Discovery" Logic (Graph Transformation)
-# -----------------------------------------------------------------------------
-
-def inject_lora(
-    model: nnx.Module,
-    rank: int,
-    rngs: nnx.Rngs,
-    scale: float = 1.0,
-    target_linear: bool = True,
-    target_conv: bool = False
-):
-  """
-  Traverses the NNX model graph and replaces target layers with LoRA wrappers.
-  This modifies the 'model' object in-place.
-  """
-  for path, module in nnx.iter_graph(model):
-    # If path is like ('block', 'linear'), attr_name is 'linear'
-    # and parent is model.block
-    attr_name = path[-1]
-    if len(path) == 1:
-      parent = model
-    else:
-      parent = model
-      for key in path[:-1]:
-        parent = getattr(parent, key)
-
-    # Handle Linear Layers
-    if target_linear and isinstance(module, nnx.Linear):
-      # Do not wrap if it's already wrapped (sanity check)
-      if isinstance(parent, BaseLoRALayer):
-        continue
-
-      print(f"Injecting LoRA (Linear) at {'.'.join([str(p) for p in path])}")
-      wrapper = LoRALinear(base_layer=module, rank=rank, scale=scale, rngs=rngs)
-      setattr(parent, attr_name, wrapper)
-
-    # Handle Conv Layers
-    elif target_conv and isinstance(module, nnx.Conv):
-      if isinstance(parent, BaseLoRALayer):
-        continue
-
-      print(f"Injecting LoRA (Conv) at {'.'.join([str(p) for p in path])}")
-      wrapper = LoRAConv(base_layer=module, rank=rank, scale=scale, rngs=rngs)
-      setattr(parent, attr_name, wrapper)
-
-  return model
-
 def _to_jax_array(v):
   if isinstance(v, torch.Tensor):
-    return dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(v))
+    return dlpack.from_dlpack(v)
   return jnp.array(v)
 
 def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
```
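
The functional change is in `_to_jax_array`: instead of first exporting the tensor to a DLPack capsule via `torch.utils.dlpack.to_dlpack`, the torch tensor is now passed to `jax.dlpack.from_dlpack` directly. Recent JAX versions accept any object implementing the `__dlpack__` protocol (which `torch.Tensor` does as of PyTorch 1.10) and deprecate the capsule path, so both the extra hop and the `torch.utils.dlpack` import become unnecessary, which is why the import is dropped in the same commit. A minimal sketch of the new conversion path, assuming such JAX/PyTorch versions:

```python
# Sketch of the simplified conversion; assumes a JAX version whose
# jax.dlpack.from_dlpack accepts objects implementing the __dlpack__
# protocol, and PyTorch >= 1.10, where torch.Tensor implements it.
import torch
import jax.numpy as jnp
from jax import dlpack


def _to_jax_array(v):
  if isinstance(v, torch.Tensor):
    # Hand the tensor to JAX directly; no intermediate capsule from
    # torch.utils.dlpack.to_dlpack is needed.
    return dlpack.from_dlpack(v)
  return jnp.array(v)


t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
x = _to_jax_array(t)
print(x.shape, x.dtype)  # (2, 3) float32
```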
