@@ -65,6 +65,7 @@ import {
6565} from "./helpers" ;
6666import { createReifiedRelation } from "~/utils/createReifiedBlock" ;
6767import { getStoredRelationsEnabled } from "~/utils/storedRelations" ;
68+ import type { DiscourseRelation } from "~/utils/getDiscourseRelations" ;
6869import { discourseContext , isPageUid } from "~/components/canvas/Tldraw" ;
6970import getPageUidByPageTitle from "roamjs-components/queries/getPageUidByPageTitle" ;
7071
@@ -624,41 +625,77 @@ export const createAllRelationShapeUtils = (
624625 const relations = Object . values ( discourseContext . relations ) . flat ( ) ;
625626 const relation = relations . find ( ( r ) => r . id === arrow . type ) ;
626627 if ( ! relation ) return ;
627- const possibleTargets = discourseContext . relations [ relation . label ]
628- . filter ( ( r ) => r . source === relation . source )
629- . map ( ( r ) => r . destination ) ;
630628
631- if ( ! possibleTargets . includes ( target . type ) ) {
632- const uniqueTargets = [ ...new Set ( possibleTargets ) ] ;
633- const uniqueTargetTexts = uniqueTargets . map (
634- ( t ) => discourseContext . nodes [ t ] . text ,
629+ const sourceNodeType = source . type ;
630+ const targetNodeType = target . type ;
631+
632+ // Check all relations with the same label for a match
633+ const {
634+ isDirect,
635+ isReverse,
636+ matchingRelation : foundRelation ,
637+ } = this . checkConnectionTypeAcrossLabel (
638+ relation . label ,
639+ sourceNodeType ,
640+ targetNodeType ,
641+ ) ;
642+ const matchingRelation = foundRelation ?? relation ;
643+
644+ if ( ! isDirect && ! isReverse ) {
645+ const validTargets = this . getValidTargetTypes (
646+ relation . label ,
647+ sourceNodeType ,
648+ ) ;
649+ const uniqueTargetTexts = validTargets . map (
650+ ( t ) => discourseContext . nodes [ t ] ?. text || t ,
635651 ) ;
636652 return deleteAndWarn (
637653 `Target node must be of type ${ uniqueTargetTexts . join ( ", " ) } ` ,
638654 ) ;
639655 }
640- if ( arrow . type !== target . type ) {
641- editor . updateShapes ( [ { id : arrow . id , type : target . type } ] ) ;
656+
657+ // If we found a matching relation with a different ID, switch to it
658+ if ( matchingRelation . id !== arrow . type ) {
659+ // Get bindings before updating the shape type
660+ const existingBindings = editor . getBindingsFromShape (
661+ arrow ,
662+ arrow . type ,
663+ ) ;
664+ // Update the shape type
665+ editor . updateShapes ( [ { id : arrow . id , type : matchingRelation . id } ] ) ;
666+ // Update bindings to use the new relation type
667+ for ( const binding of existingBindings ) {
668+ editor . updateBinding ( {
669+ ...binding ,
670+ type : matchingRelation . id ,
671+ } ) ;
672+ }
642673 }
643674 if ( getStoredRelationsEnabled ( ) ) {
644675 const sourceAsDNS = asDiscourseNodeShape ( source , editor ) ;
645676 const targetAsDNS = asDiscourseNodeShape ( target , editor ) ;
646677
647- if ( sourceAsDNS && targetAsDNS )
678+ if ( sourceAsDNS && targetAsDNS ) {
679+ const isOriginal = isDirect ;
648680 await createReifiedRelation ( {
649- sourceUid : sourceAsDNS . props . uid ,
650- destinationUid : targetAsDNS . props . uid ,
651- relationBlockUid : relation . id ,
681+ sourceUid : isOriginal
682+ ? sourceAsDNS . props . uid
683+ : targetAsDNS . props . uid ,
684+ destinationUid : isOriginal
685+ ? targetAsDNS . props . uid
686+ : sourceAsDNS . props . uid ,
687+ relationBlockUid : matchingRelation . id ,
652688 } ) ;
653- else {
689+ } else {
654690 void internalError ( {
655691 error : "attempt to create a relation between non discourse nodes" ,
656692 type : "Canvas create relation" ,
657693 } ) ;
658694 }
659695 } else {
660- const { triples, label : relationLabel } = relation ;
661- const isOriginal = arrow . props . text === relationLabel ;
696+ const { triples } = matchingRelation ;
697+ const isOriginal = isDirect ;
698+
662699 const newTriples = triples
663700 . map ( ( t ) => {
664701 if ( / i s a / i. test ( t [ 1 ] ) ) {
@@ -845,6 +882,33 @@ export const createAllRelationShapeUtils = (
845882 return update ;
846883 }
847884
885+ // Validate target node type compatibility before creating binding
886+ if (
887+ target . type !== "arrow" &&
888+ otherBinding &&
889+ target . id !== otherBinding . toId &&
890+ ( ! currentBinding || target . id !== currentBinding . toId )
891+ ) {
892+ const sourceNodeId = otherBinding . toId ;
893+ const sourceNode = this . editor . getShape ( sourceNodeId ) ;
894+ const targetNodeType = target . type ;
895+ const sourceNodeType = sourceNode ?. type ;
896+
897+ if ( sourceNodeType && targetNodeType && shape . type ) {
898+ const isValidConnection = this . isValidNodeConnection (
899+ sourceNodeType ,
900+ targetNodeType ,
901+ shape . type ,
902+ ) ;
903+
904+ if ( ! isValidConnection ) {
905+ removeArrowBinding ( this . editor , shape , handleId ) ;
906+ update . props ! [ handleId ] = { x : handle . x , y : handle . y } ;
907+ return update ;
908+ }
909+ }
910+ }
911+
848912 // we've got a target! the handle is being dragged over a shape, bind to it
849913
850914 const targetGeometry = this . editor . getShapeGeometry ( target ) ;
@@ -921,6 +985,42 @@ export const createAllRelationShapeUtils = (
921985 this . editor . setHintingShapes ( [ target . id ] ) ;
922986
923987 const newBindings = getArrowBindings ( this . editor , shape ) ;
988+
989+ // Check if both ends are bound and determine the correct text based on direction
990+ if ( newBindings . start && newBindings . end ) {
991+ const relations = Object . values ( discourseContext . relations ) . flat ( ) ;
992+ const relation = relations . find ( ( r ) => r . id === shape . type ) ;
993+
994+ if ( relation ) {
995+ const startNode = this . editor . getShape ( newBindings . start . toId ) ;
996+ const endNode = this . editor . getShape ( newBindings . end . toId ) ;
997+
998+ if ( startNode && endNode ) {
999+ const startNodeType = startNode . type ;
1000+ const endNodeType = endNode . type ;
1001+
1002+ const { isReverse, matchingRelation } =
1003+ this . checkConnectionTypeAcrossLabel (
1004+ relation . label ,
1005+ startNodeType ,
1006+ endNodeType ,
1007+ ) ;
1008+
1009+ const effectiveRelation = matchingRelation ?? relation ;
1010+
1011+ const newText =
1012+ isReverse && effectiveRelation . complement
1013+ ? effectiveRelation . complement
1014+ : effectiveRelation . label ;
1015+
1016+ if ( shape . props . text !== newText ) {
1017+ update . props = update . props || { } ;
1018+ update . props . text = newText ;
1019+ }
1020+ }
1021+ }
1022+ }
1023+
9241024 if (
9251025 newBindings . start &&
9261026 newBindings . end &&
@@ -1600,6 +1700,79 @@ export class BaseDiscourseRelationUtil extends ShapeUtil<DiscourseRelationShape>
16001700 ] ;
16011701 }
16021702
1703+ checkConnectionType (
1704+ relation : { source : string ; destination : string } ,
1705+ sourceNodeType : string ,
1706+ targetNodeType : string ,
1707+ ) : { isDirect : boolean ; isReverse : boolean } {
1708+ const isDirect =
1709+ sourceNodeType === relation . source &&
1710+ targetNodeType === relation . destination ;
1711+
1712+ const isReverse =
1713+ sourceNodeType === relation . destination &&
1714+ targetNodeType === relation . source ;
1715+
1716+ return { isDirect, isReverse } ;
1717+ }
1718+
1719+ checkConnectionTypeAcrossLabel (
1720+ label : string ,
1721+ sourceNodeType : string ,
1722+ targetNodeType : string ,
1723+ ) : {
1724+ isDirect : boolean ;
1725+ isReverse : boolean ;
1726+ matchingRelation : DiscourseRelation | null ;
1727+ } {
1728+ const relationsWithLabel = discourseContext . relations [ label ] ;
1729+ if ( ! relationsWithLabel ) {
1730+ return { isDirect : false , isReverse : false , matchingRelation : null } ;
1731+ }
1732+
1733+ for ( const rel of relationsWithLabel ) {
1734+ const { isDirect, isReverse } = this . checkConnectionType (
1735+ rel ,
1736+ sourceNodeType ,
1737+ targetNodeType ,
1738+ ) ;
1739+ if ( isDirect || isReverse ) {
1740+ return { isDirect, isReverse, matchingRelation : rel } ;
1741+ }
1742+ }
1743+
1744+ return { isDirect : false , isReverse : false , matchingRelation : null } ;
1745+ }
1746+
1747+ getValidTargetTypes ( label : string , sourceNodeType : string ) : string [ ] {
1748+ const relationsWithLabel = discourseContext . relations [ label ] ;
1749+ if ( ! relationsWithLabel ) return [ ] ;
1750+
1751+ const targets = new Set < string > ( ) ;
1752+ for ( const rel of relationsWithLabel ) {
1753+ if ( rel . source === sourceNodeType ) targets . add ( rel . destination ) ;
1754+ if ( rel . destination === sourceNodeType ) targets . add ( rel . source ) ;
1755+ }
1756+ return [ ...targets ] ;
1757+ }
1758+
1759+ isValidNodeConnection (
1760+ sourceNodeType : string ,
1761+ targetNodeType : string ,
1762+ relationId : string ,
1763+ ) : boolean {
1764+ const relations = Object . values ( discourseContext . relations ) . flat ( ) ;
1765+ const relation = relations . find ( ( r ) => r . id === relationId ) ;
1766+ if ( ! relation ) return false ;
1767+
1768+ const { isDirect, isReverse } = this . checkConnectionTypeAcrossLabel (
1769+ relation . label ,
1770+ sourceNodeType ,
1771+ targetNodeType ,
1772+ ) ;
1773+ return isDirect || isReverse ;
1774+ }
1775+
16031776 component ( shape : DiscourseRelationShape ) {
16041777 // eslint-disable-next-line react-hooks/rules-of-hooks
16051778 // const theme = useDefaultColorTheme();
0 commit comments