Skip to content

Commit 0c3aa31

Browse files
authored
[PERF] Make NXGraphCompiler avoid slowpath tensor creation - resolves #48 (#53)
1 parent 3d6344b commit 0c3aa31

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

qek/data/graphs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ class NXGraphCompiler(BaseGraphCompiler[NXWithPos]):
519519
def ingest(self, graph: NXWithPos, device: pl.devices.Device, id: int) -> BaseGraph:
520520
pyg = pyg_utils.from_networkx(graph.graph)
521521
pyg.y = graph.target
522-
positions = [graph.positions[node] for node in graph.graph.nodes()]
522+
positions = np.array([graph.positions[node] for node in graph.graph.nodes()])
523523
pyg.pos = torch.tensor(positions, dtype=torch.float)
524524

525525
return BaseGraph(id=id, device=device, data=pyg)

0 commit comments

Comments
 (0)