Skip to content

Commit 7d40985

Browse files
[ENG-641] Allow reverse relation creation (#623)
* allow reverse label creation * cur progress * allow working for multiple same-label relation * address PR comments * address PR comment * Update apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com> * format and address PR comment --------- Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
1 parent fc82a41 commit 7d40985

2 files changed

Lines changed: 196 additions & 20 deletions

File tree

apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationTool.tsx

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,14 +361,17 @@ export const createAllRelationShapeTools = (
361361
);
362362

363363
const relation = discourseContext.relations[name].find(
364-
(r) => r.source === target?.type,
364+
(r) => r.source === target?.type || r.destination === target?.type,
365365
);
366366
if (relation) {
367367
this.shapeType = relation.id;
368368
} else {
369-
const acceptableTypes = discourseContext.relations[name].map(
370-
(r) => discourseContext.nodes[r.source].text,
371-
);
369+
const acceptableTypes = discourseContext.relations[name]
370+
.flatMap((r) => [
371+
discourseContext.nodes[r.source]?.text,
372+
discourseContext.nodes[r.destination]?.text,
373+
])
374+
.filter(Boolean);
372375
const uniqueTypes = [...new Set(acceptableTypes)];
373376
this.cancelAndWarn(
374377
`Starting node must be one of ${uniqueTypes.join(", ")}`,

apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx

Lines changed: 189 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ import {
6565
} from "./helpers";
6666
import { createReifiedRelation } from "~/utils/createReifiedBlock";
6767
import { getStoredRelationsEnabled } from "~/utils/storedRelations";
68+
import type { DiscourseRelation } from "~/utils/getDiscourseRelations";
6869
import { discourseContext, isPageUid } from "~/components/canvas/Tldraw";
6970
import getPageUidByPageTitle from "roamjs-components/queries/getPageUidByPageTitle";
7071

@@ -624,41 +625,77 @@ export const createAllRelationShapeUtils = (
624625
const relations = Object.values(discourseContext.relations).flat();
625626
const relation = relations.find((r) => r.id === arrow.type);
626627
if (!relation) return;
627-
const possibleTargets = discourseContext.relations[relation.label]
628-
.filter((r) => r.source === relation.source)
629-
.map((r) => r.destination);
630628

631-
if (!possibleTargets.includes(target.type)) {
632-
const uniqueTargets = [...new Set(possibleTargets)];
633-
const uniqueTargetTexts = uniqueTargets.map(
634-
(t) => discourseContext.nodes[t].text,
629+
const sourceNodeType = source.type;
630+
const targetNodeType = target.type;
631+
632+
// Check all relations with the same label for a match
633+
const {
634+
isDirect,
635+
isReverse,
636+
matchingRelation: foundRelation,
637+
} = this.checkConnectionTypeAcrossLabel(
638+
relation.label,
639+
sourceNodeType,
640+
targetNodeType,
641+
);
642+
const matchingRelation = foundRelation ?? relation;
643+
644+
if (!isDirect && !isReverse) {
645+
const validTargets = this.getValidTargetTypes(
646+
relation.label,
647+
sourceNodeType,
648+
);
649+
const uniqueTargetTexts = validTargets.map(
650+
(t) => discourseContext.nodes[t]?.text || t,
635651
);
636652
return deleteAndWarn(
637653
`Target node must be of type ${uniqueTargetTexts.join(", ")}`,
638654
);
639655
}
640-
if (arrow.type !== target.type) {
641-
editor.updateShapes([{ id: arrow.id, type: target.type }]);
656+
657+
// If we found a matching relation with a different ID, switch to it
658+
if (matchingRelation.id !== arrow.type) {
659+
// Get bindings before updating the shape type
660+
const existingBindings = editor.getBindingsFromShape(
661+
arrow,
662+
arrow.type,
663+
);
664+
// Update the shape type
665+
editor.updateShapes([{ id: arrow.id, type: matchingRelation.id }]);
666+
// Update bindings to use the new relation type
667+
for (const binding of existingBindings) {
668+
editor.updateBinding({
669+
...binding,
670+
type: matchingRelation.id,
671+
});
672+
}
642673
}
643674
if (getStoredRelationsEnabled()) {
644675
const sourceAsDNS = asDiscourseNodeShape(source, editor);
645676
const targetAsDNS = asDiscourseNodeShape(target, editor);
646677

647-
if (sourceAsDNS && targetAsDNS)
678+
if (sourceAsDNS && targetAsDNS) {
679+
const isOriginal = isDirect;
648680
await createReifiedRelation({
649-
sourceUid: sourceAsDNS.props.uid,
650-
destinationUid: targetAsDNS.props.uid,
651-
relationBlockUid: relation.id,
681+
sourceUid: isOriginal
682+
? sourceAsDNS.props.uid
683+
: targetAsDNS.props.uid,
684+
destinationUid: isOriginal
685+
? targetAsDNS.props.uid
686+
: sourceAsDNS.props.uid,
687+
relationBlockUid: matchingRelation.id,
652688
});
653-
else {
689+
} else {
654690
void internalError({
655691
error: "attempt to create a relation between non discourse nodes",
656692
type: "Canvas create relation",
657693
});
658694
}
659695
} else {
660-
const { triples, label: relationLabel } = relation;
661-
const isOriginal = arrow.props.text === relationLabel;
696+
const { triples } = matchingRelation;
697+
const isOriginal = isDirect;
698+
662699
const newTriples = triples
663700
.map((t) => {
664701
if (/is a/i.test(t[1])) {
@@ -845,6 +882,33 @@ export const createAllRelationShapeUtils = (
845882
return update;
846883
}
847884

885+
// Validate target node type compatibility before creating binding
886+
if (
887+
target.type !== "arrow" &&
888+
otherBinding &&
889+
target.id !== otherBinding.toId &&
890+
(!currentBinding || target.id !== currentBinding.toId)
891+
) {
892+
const sourceNodeId = otherBinding.toId;
893+
const sourceNode = this.editor.getShape(sourceNodeId);
894+
const targetNodeType = target.type;
895+
const sourceNodeType = sourceNode?.type;
896+
897+
if (sourceNodeType && targetNodeType && shape.type) {
898+
const isValidConnection = this.isValidNodeConnection(
899+
sourceNodeType,
900+
targetNodeType,
901+
shape.type,
902+
);
903+
904+
if (!isValidConnection) {
905+
removeArrowBinding(this.editor, shape, handleId);
906+
update.props![handleId] = { x: handle.x, y: handle.y };
907+
return update;
908+
}
909+
}
910+
}
911+
848912
// we've got a target! the handle is being dragged over a shape, bind to it
849913

850914
const targetGeometry = this.editor.getShapeGeometry(target);
@@ -921,6 +985,42 @@ export const createAllRelationShapeUtils = (
921985
this.editor.setHintingShapes([target.id]);
922986

923987
const newBindings = getArrowBindings(this.editor, shape);
988+
989+
// Check if both ends are bound and determine the correct text based on direction
990+
if (newBindings.start && newBindings.end) {
991+
const relations = Object.values(discourseContext.relations).flat();
992+
const relation = relations.find((r) => r.id === shape.type);
993+
994+
if (relation) {
995+
const startNode = this.editor.getShape(newBindings.start.toId);
996+
const endNode = this.editor.getShape(newBindings.end.toId);
997+
998+
if (startNode && endNode) {
999+
const startNodeType = startNode.type;
1000+
const endNodeType = endNode.type;
1001+
1002+
const { isReverse, matchingRelation } =
1003+
this.checkConnectionTypeAcrossLabel(
1004+
relation.label,
1005+
startNodeType,
1006+
endNodeType,
1007+
);
1008+
1009+
const effectiveRelation = matchingRelation ?? relation;
1010+
1011+
const newText =
1012+
isReverse && effectiveRelation.complement
1013+
? effectiveRelation.complement
1014+
: effectiveRelation.label;
1015+
1016+
if (shape.props.text !== newText) {
1017+
update.props = update.props || {};
1018+
update.props.text = newText;
1019+
}
1020+
}
1021+
}
1022+
}
1023+
9241024
if (
9251025
newBindings.start &&
9261026
newBindings.end &&
@@ -1600,6 +1700,79 @@ export class BaseDiscourseRelationUtil extends ShapeUtil<DiscourseRelationShape>
16001700
];
16011701
}
16021702

1703+
checkConnectionType(
1704+
relation: { source: string; destination: string },
1705+
sourceNodeType: string,
1706+
targetNodeType: string,
1707+
): { isDirect: boolean; isReverse: boolean } {
1708+
const isDirect =
1709+
sourceNodeType === relation.source &&
1710+
targetNodeType === relation.destination;
1711+
1712+
const isReverse =
1713+
sourceNodeType === relation.destination &&
1714+
targetNodeType === relation.source;
1715+
1716+
return { isDirect, isReverse };
1717+
}
1718+
1719+
checkConnectionTypeAcrossLabel(
1720+
label: string,
1721+
sourceNodeType: string,
1722+
targetNodeType: string,
1723+
): {
1724+
isDirect: boolean;
1725+
isReverse: boolean;
1726+
matchingRelation: DiscourseRelation | null;
1727+
} {
1728+
const relationsWithLabel = discourseContext.relations[label];
1729+
if (!relationsWithLabel) {
1730+
return { isDirect: false, isReverse: false, matchingRelation: null };
1731+
}
1732+
1733+
for (const rel of relationsWithLabel) {
1734+
const { isDirect, isReverse } = this.checkConnectionType(
1735+
rel,
1736+
sourceNodeType,
1737+
targetNodeType,
1738+
);
1739+
if (isDirect || isReverse) {
1740+
return { isDirect, isReverse, matchingRelation: rel };
1741+
}
1742+
}
1743+
1744+
return { isDirect: false, isReverse: false, matchingRelation: null };
1745+
}
1746+
1747+
getValidTargetTypes(label: string, sourceNodeType: string): string[] {
1748+
const relationsWithLabel = discourseContext.relations[label];
1749+
if (!relationsWithLabel) return [];
1750+
1751+
const targets = new Set<string>();
1752+
for (const rel of relationsWithLabel) {
1753+
if (rel.source === sourceNodeType) targets.add(rel.destination);
1754+
if (rel.destination === sourceNodeType) targets.add(rel.source);
1755+
}
1756+
return [...targets];
1757+
}
1758+
1759+
isValidNodeConnection(
1760+
sourceNodeType: string,
1761+
targetNodeType: string,
1762+
relationId: string,
1763+
): boolean {
1764+
const relations = Object.values(discourseContext.relations).flat();
1765+
const relation = relations.find((r) => r.id === relationId);
1766+
if (!relation) return false;
1767+
1768+
const { isDirect, isReverse } = this.checkConnectionTypeAcrossLabel(
1769+
relation.label,
1770+
sourceNodeType,
1771+
targetNodeType,
1772+
);
1773+
return isDirect || isReverse;
1774+
}
1775+
16031776
component(shape: DiscourseRelationShape) {
16041777
// eslint-disable-next-line react-hooks/rules-of-hooks
16051778
// const theme = useDefaultColorTheme();

0 commit comments

Comments
 (0)