From da35f2bd75e43fd48d11a9a74f60ee01c84a41d1 Mon Sep 17 00:00:00 2001 From: alex Date: Wed, 8 May 2024 13:37:31 +0100 Subject: [PATCH] Bindings (#3326) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First draft of the new bindings API. We'll follow this up with some API refinements, tests, documentation, and examples. Bindings are a new record type for establishing relationships between two shapes so they can update at the same time. ### Change Type - [x] `sdk` — Changes the tldraw SDK - [x] `feature` — New feature ### Release Notes #### Breaking changes - The `start` and `end` properties on `TLArrowShape` no longer have `type: point | binding`. Instead, they're always a point, which may be out of date if a binding exists. To check for & retrieve arrow bindings, use `getArrowBindings(editor, shape)` instead. - `getArrowTerminalsInArrowSpace` must be passed a `TLArrowBindings` as a third argument: `getArrowTerminalsInArrowSpace(editor, shape, getArrowBindings(editor, shape))` - The following types have been renamed: - `ShapeProps` -> `RecordProps` - `ShapePropsType` -> `RecordPropsType` - `TLShapePropsMigrations` -> `TLPropsMigrations` - `SchemaShapeInfo` -> `SchemaPropsInfo` --------- Co-authored-by: David Sheldrick --- .../e2e/tests/export-snapshots.spec.ts | 12 +- .../PlayingCardShape/playing-card-util.tsx | 4 +- .../CardShape/card-shape-props.ts | 4 +- .../custom-shape/CustomShapeExample.tsx | 4 +- .../editable-shape/EditableShapeUtil.tsx | 4 +- .../src/examples/exploded/ExplodedExample.tsx | 2 + .../my-interactive-shape-util.tsx | 4 +- .../examples/popup-shape/PopupShapeUtil.tsx | 4 +- .../src/examples/slides/SlideShapeUtil.tsx | 4 +- .../SpeechBubble/SpeechBubbleUtil.tsx | 4 +- .../src/examples/sticker-bindings/README.md | 9 + .../sticker-bindings/StickerExample.tsx | 235 ++++ packages/editor/api-report.md | 174 ++- packages/editor/src/index.ts | 18 +- packages/editor/src/lib/TldrawEditor.tsx | 43 +- .../editor/src/lib/config/createTLStore.ts | 26 +- .../editor/src/lib/config/defaultBindings.ts | 19 + packages/editor/src/lib/editor/Editor.ts | 1067 +++++++++-------- .../src/lib/editor/bindings/BindingUtil.ts | 77 ++ .../editor/derivations/arrowBindingsIndex.ts | 141 --- .../editor/src/lib/editor/shapes/ShapeUtil.ts | 19 +- .../editor/shapes/shared/arrow/arrow-types.ts | 3 + .../shapes/shared/arrow/curved-arrow.ts | 17 +- .../lib/editor/shapes/shared/arrow/shared.ts | 154 ++- .../shapes/shared/arrow/straight-arrow.ts | 17 +- .../src/lib/editor/types/clipboard-types.ts | 3 +- .../src/lib/test/currentToolIdMask.test.ts | 1 + packages/editor/src/lib/test/user.test.ts | 1 + packages/store/api-report.md | 7 +- packages/store/src/index.ts | 1 + packages/store/src/lib/migrate.ts | 9 +- packages/tldraw/api-report.md | 69 +- packages/tldraw/src/index.ts | 1 + packages/tldraw/src/lib/Tldraw.tsx | 9 + packages/tldraw/src/lib/TldrawImage.tsx | 15 +- .../lib/bindings/arrow/ArrowBindingUtil.ts | 224 ++++ .../tldraw/src/lib/defaultBindingUtils.ts | 5 + .../lib/shapes/arrow/ArrowShapeTool.test.ts | 163 +-- .../lib/shapes/arrow/ArrowShapeUtil.test.ts | 293 ++--- .../src/lib/shapes/arrow/ArrowShapeUtil.tsx | 258 ++-- .../tldraw/src/lib/shapes/arrow/arrowLabel.ts | 4 +- .../lib/shapes/arrow/toolStates/Pointing.ts | 46 +- .../SelectTool/childStates/DraggingHandle.tsx | 28 +- .../SelectTool/childStates/PointingHandle.ts | 7 +- .../hooks/clipboard/pasteExcalidrawContent.ts | 2 + .../tldraw/src/lib/ui/hooks/menu-hooks.ts | 22 +- .../buildFromV1Document.test.ts.snap | 419 +++++++ .../utils/tldr/buildFromV1Document.test.ts | 44 + .../src/lib/utils/tldr/buildFromV1Document.ts | 32 +- .../tldr/test-fixtures/arrow-binding.tldr | 1 + .../test-fixtures/exact-arrow-binding.tldr | 1 + .../incorrect-arrow-binding.tldr | 1 + packages/tldraw/src/test/TestEditor.ts | 10 +- .../src/test/arrowBindingsIndex.test.tsx | 295 ----- .../tldraw/src/test/arrows-megabus.test.tsx | 270 ++--- packages/tldraw/src/test/cleanup.test.ts | 44 +- .../src/test/commands/clipboard.test.ts | 122 +- .../src/test/commands/deleteShapes.test.ts | 65 +- .../test/commands/moveShapesToPage.test.ts | 2 +- packages/tldraw/src/test/duplicate.test.ts | 216 ++-- packages/tldraw/src/test/flipShapes.test.ts | 147 ++- packages/tldraw/src/test/frames.test.ts | 26 +- .../tldraw/src/test/getCulledShapes.test.tsx | 73 +- packages/tldraw/src/test/groups.test.ts | 77 +- packages/tldraw/src/test/select.test.tsx | 8 +- .../tldraw/src/test/selection-omnibus.test.ts | 4 +- packages/tldraw/src/test/translating.test.ts | 126 +- packages/tlschema/api-report.md | 231 ++-- .../src/__tests__/migrationTestUtils.ts | 14 + .../tlschema/src/bindings/TLArrowBinding.ts | 39 + .../tlschema/src/bindings/TLBaseBinding.ts | 41 + packages/tlschema/src/createTLSchema.ts | 26 +- packages/tlschema/src/index.ts | 38 +- packages/tlschema/src/migrations.test.ts | 230 ++++ packages/tlschema/src/records/TLBinding.ts | 111 ++ packages/tlschema/src/records/TLRecord.ts | 2 + packages/tlschema/src/records/TLShape.ts | 150 +-- packages/tlschema/src/recordsWithProps.ts | 147 +++ packages/tlschema/src/shapes/TLArrowShape.ts | 134 ++- packages/tlschema/src/shapes/TLBaseShape.ts | 12 +- .../tlschema/src/shapes/TLBookmarkShape.ts | 11 +- packages/tlschema/src/shapes/TLDrawShape.ts | 11 +- packages/tlschema/src/shapes/TLEmbedShape.ts | 11 +- packages/tlschema/src/shapes/TLFrameShape.ts | 5 +- packages/tlschema/src/shapes/TLGeoShape.ts | 11 +- packages/tlschema/src/shapes/TLGroupShape.ts | 5 +- .../tlschema/src/shapes/TLHighlightShape.ts | 5 +- packages/tlschema/src/shapes/TLImageShape.ts | 11 +- packages/tlschema/src/shapes/TLLineShape.ts | 11 +- packages/tlschema/src/shapes/TLNoteShape.ts | 11 +- packages/tlschema/src/shapes/TLTextShape.ts | 11 +- packages/tlschema/src/shapes/TLVideoShape.ts | 11 +- packages/tlsync/src/test/FuzzEditor.ts | 9 +- packages/tlsync/src/test/TLSyncRoom.test.ts | 4 +- packages/tlsync/src/test/syncFuzz.test.ts | 15 +- 95 files changed, 4087 insertions(+), 2446 deletions(-) create mode 100644 apps/examples/src/examples/sticker-bindings/README.md create mode 100644 apps/examples/src/examples/sticker-bindings/StickerExample.tsx create mode 100644 packages/editor/src/lib/config/defaultBindings.ts create mode 100644 packages/editor/src/lib/editor/bindings/BindingUtil.ts delete mode 100644 packages/editor/src/lib/editor/derivations/arrowBindingsIndex.ts create mode 100644 packages/tldraw/src/lib/bindings/arrow/ArrowBindingUtil.ts create mode 100644 packages/tldraw/src/lib/defaultBindingUtils.ts create mode 100644 packages/tldraw/src/lib/utils/tldr/__snapshots__/buildFromV1Document.test.ts.snap create mode 100644 packages/tldraw/src/lib/utils/tldr/buildFromV1Document.test.ts create mode 100644 packages/tldraw/src/lib/utils/tldr/test-fixtures/arrow-binding.tldr create mode 100644 packages/tldraw/src/lib/utils/tldr/test-fixtures/exact-arrow-binding.tldr create mode 100644 packages/tldraw/src/lib/utils/tldr/test-fixtures/incorrect-arrow-binding.tldr delete mode 100644 packages/tldraw/src/test/arrowBindingsIndex.test.tsx create mode 100644 packages/tlschema/src/bindings/TLArrowBinding.ts create mode 100644 packages/tlschema/src/bindings/TLBaseBinding.ts create mode 100644 packages/tlschema/src/records/TLBinding.ts create mode 100644 packages/tlschema/src/recordsWithProps.ts diff --git a/apps/examples/e2e/tests/export-snapshots.spec.ts b/apps/examples/e2e/tests/export-snapshots.spec.ts index 8ac0cfd92..5246f5e23 100644 --- a/apps/examples/e2e/tests/export-snapshots.spec.ts +++ b/apps/examples/e2e/tests/export-snapshots.spec.ts @@ -69,8 +69,8 @@ test.describe('Export snapshots', () => { fill: fill, arrowheadStart: 'square', arrowheadEnd: 'dot', - start: { type: 'point', x: 0, y: 0 }, - end: { type: 'point', x: 100, y: 100 }, + start: { x: 0, y: 0 }, + end: { x: 100, y: 100 }, bend: 20, }, }, @@ -149,8 +149,8 @@ test.describe('Export snapshots', () => { arrowheadStart: 'square', arrowheadEnd: 'arrow', font, - start: { type: 'point', x: 0, y: 0 }, - end: { type: 'point', x: 100, y: 100 }, + start: { x: 0, y: 0 }, + end: { x: 100, y: 100 }, bend: 20, text: 'test', }, @@ -167,8 +167,8 @@ test.describe('Export snapshots', () => { arrowheadStart: 'square', arrowheadEnd: 'arrow', font, - start: { type: 'point', x: 0, y: 0 }, - end: { type: 'point', x: 100, y: 100 }, + start: { x: 0, y: 0 }, + end: { x: 100, y: 100 }, bend: 20, text: 'test', }, diff --git a/apps/examples/src/examples/bounds-snapping-shape/PlayingCardShape/playing-card-util.tsx b/apps/examples/src/examples/bounds-snapping-shape/PlayingCardShape/playing-card-util.tsx index 77f521369..c42582172 100644 --- a/apps/examples/src/examples/bounds-snapping-shape/PlayingCardShape/playing-card-util.tsx +++ b/apps/examples/src/examples/bounds-snapping-shape/PlayingCardShape/playing-card-util.tsx @@ -2,8 +2,8 @@ import { BaseBoxShapeUtil, BoundsSnapGeometry, HTMLContainer, + RecordProps, Rectangle2d, - ShapeProps, T, TLBaseShape, } from 'tldraw' @@ -23,7 +23,7 @@ type IPlayingCard = TLBaseShape< export class PlayingCardUtil extends BaseBoxShapeUtil { // [2] static override type = 'PlayingCard' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, suit: T.string, diff --git a/apps/examples/src/examples/custom-config/CardShape/card-shape-props.ts b/apps/examples/src/examples/custom-config/CardShape/card-shape-props.ts index 88d55c551..9f078653b 100644 --- a/apps/examples/src/examples/custom-config/CardShape/card-shape-props.ts +++ b/apps/examples/src/examples/custom-config/CardShape/card-shape-props.ts @@ -1,8 +1,8 @@ -import { DefaultColorStyle, ShapeProps, T } from 'tldraw' +import { DefaultColorStyle, RecordProps, T } from 'tldraw' import { ICardShape } from './card-shape-types' // Validation for our custom card shape's props, using one of tldraw's default styles -export const cardShapeProps: ShapeProps = { +export const cardShapeProps: RecordProps = { w: T.number, h: T.number, color: DefaultColorStyle, diff --git a/apps/examples/src/examples/custom-shape/CustomShapeExample.tsx b/apps/examples/src/examples/custom-shape/CustomShapeExample.tsx index 7660a8282..d3e9e247a 100644 --- a/apps/examples/src/examples/custom-shape/CustomShapeExample.tsx +++ b/apps/examples/src/examples/custom-shape/CustomShapeExample.tsx @@ -1,8 +1,8 @@ import { Geometry2d, HTMLContainer, + RecordProps, Rectangle2d, - ShapeProps, ShapeUtil, T, TLBaseShape, @@ -28,7 +28,7 @@ type ICustomShape = TLBaseShape< export class MyShapeUtil extends ShapeUtil { // [a] static override type = 'my-custom-shape' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, text: T.string, diff --git a/apps/examples/src/examples/editable-shape/EditableShapeUtil.tsx b/apps/examples/src/examples/editable-shape/EditableShapeUtil.tsx index a0e13d7f1..9cb009424 100644 --- a/apps/examples/src/examples/editable-shape/EditableShapeUtil.tsx +++ b/apps/examples/src/examples/editable-shape/EditableShapeUtil.tsx @@ -1,7 +1,7 @@ import { BaseBoxShapeUtil, HTMLContainer, - ShapeProps, + RecordProps, T, TLBaseShape, TLOnEditEndHandler, @@ -23,7 +23,7 @@ type IMyEditableShape = TLBaseShape< export class EditableShapeUtil extends BaseBoxShapeUtil { static override type = 'my-editable-shape' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, animal: T.number, diff --git a/apps/examples/src/examples/exploded/ExplodedExample.tsx b/apps/examples/src/examples/exploded/ExplodedExample.tsx index 8a8fa7812..37a438987 100644 --- a/apps/examples/src/examples/exploded/ExplodedExample.tsx +++ b/apps/examples/src/examples/exploded/ExplodedExample.tsx @@ -9,6 +9,7 @@ import { TldrawSelectionBackground, TldrawSelectionForeground, TldrawUi, + defaultBindingUtils, defaultEditorAssetUrls, defaultShapeTools, defaultShapeUtils, @@ -45,6 +46,7 @@ export default function ExplodedExample() { { static override type = 'my-interactive-shape' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, checked: T.boolean, diff --git a/apps/examples/src/examples/popup-shape/PopupShapeUtil.tsx b/apps/examples/src/examples/popup-shape/PopupShapeUtil.tsx index 184552e79..3d83ea4db 100644 --- a/apps/examples/src/examples/popup-shape/PopupShapeUtil.tsx +++ b/apps/examples/src/examples/popup-shape/PopupShapeUtil.tsx @@ -3,7 +3,7 @@ import { useEffect, useRef, useState } from 'react' import { BaseBoxShapeUtil, HTMLContainer, - ShapeProps, + RecordProps, T, TLBaseShape, stopEventPropagation, @@ -20,7 +20,7 @@ type IMyPopupShape = TLBaseShape< export class PopupShapeUtil extends BaseBoxShapeUtil { static override type = 'my-popup-shape' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, animal: T.number, diff --git a/apps/examples/src/examples/slides/SlideShapeUtil.tsx b/apps/examples/src/examples/slides/SlideShapeUtil.tsx index e2c01542b..58aa9c5ca 100644 --- a/apps/examples/src/examples/slides/SlideShapeUtil.tsx +++ b/apps/examples/src/examples/slides/SlideShapeUtil.tsx @@ -1,9 +1,9 @@ import { useCallback } from 'react' import { Geometry2d, + RecordProps, Rectangle2d, SVGContainer, - ShapeProps, ShapeUtil, T, TLBaseShape, @@ -24,7 +24,7 @@ export type SlideShape = TLBaseShape< export class SlideShapeUtil extends ShapeUtil { static override type = 'slide' as const - static override props: ShapeProps = { + static override props: RecordProps = { w: T.number, h: T.number, } diff --git a/apps/examples/src/examples/speech-bubble/SpeechBubble/SpeechBubbleUtil.tsx b/apps/examples/src/examples/speech-bubble/SpeechBubble/SpeechBubbleUtil.tsx index 1c0d3e1f0..45f9183ac 100644 --- a/apps/examples/src/examples/speech-bubble/SpeechBubble/SpeechBubbleUtil.tsx +++ b/apps/examples/src/examples/speech-bubble/SpeechBubble/SpeechBubbleUtil.tsx @@ -8,7 +8,7 @@ import { Geometry2d, LABEL_FONT_SIZES, Polygon2d, - ShapePropsType, + RecordPropsType, ShapeUtil, T, TEXT_PROPS, @@ -52,7 +52,7 @@ export const speechBubbleShapeProps = { tail: vecModelValidator, } -export type SpeechBubbleShapeProps = ShapePropsType +export type SpeechBubbleShapeProps = RecordPropsType export type SpeechBubbleShape = TLBaseShape<'speech-bubble', SpeechBubbleShapeProps> export class SpeechBubbleUtil extends ShapeUtil { diff --git a/apps/examples/src/examples/sticker-bindings/README.md b/apps/examples/src/examples/sticker-bindings/README.md new file mode 100644 index 000000000..70c946e58 --- /dev/null +++ b/apps/examples/src/examples/sticker-bindings/README.md @@ -0,0 +1,9 @@ +--- +title: Sticker (bindings) +component: ./StickerExample.tsx +category: shapes/tools +--- + +A sticker shape, using bindings to attach shapes to one and other + +--- diff --git a/apps/examples/src/examples/sticker-bindings/StickerExample.tsx b/apps/examples/src/examples/sticker-bindings/StickerExample.tsx new file mode 100644 index 000000000..b5a9a664f --- /dev/null +++ b/apps/examples/src/examples/sticker-bindings/StickerExample.tsx @@ -0,0 +1,235 @@ +import { + BindingOnShapeChangeOptions, + BindingOnShapeDeleteOptions, + BindingUtil, + Box, + DefaultToolbar, + DefaultToolbarContent, + RecordProps, + Rectangle2d, + ShapeUtil, + StateNode, + TLBaseBinding, + TLBaseShape, + TLEventHandlers, + TLOnTranslateEndHandler, + TLOnTranslateStartHandler, + TLUiComponents, + TLUiOverrides, + Tldraw, + TldrawUiMenuItem, + VecModel, + createShapeId, + invLerp, + lerp, + useIsToolSelected, + useTools, +} from 'tldraw' + +// eslint-disable-next-line @typescript-eslint/ban-types +type StickerShape = TLBaseShape<'sticker', {}> + +const offsetX = -16 +const offsetY = -26 +class StickerShapeUtil extends ShapeUtil { + static override type = 'sticker' as const + static override props: RecordProps = {} + + override getDefaultProps() { + return {} + } + + override canBind = () => false + override canEdit = () => false + override canResize = () => false + override hideRotateHandle = () => true + override isAspectRatioLocked = () => true + + override getGeometry() { + return new Rectangle2d({ + width: 32, + height: 32, + x: offsetX, + y: offsetY, + isFilled: true, + }) + } + + override component() { + return ( +
+ ❤️ +
+ ) + } + + override indicator() { + return + } + + override onTranslateStart: TLOnTranslateStartHandler = (shape) => { + const bindings = this.editor.getBindingsFromShape(shape, 'sticker') + this.editor.deleteBindings(bindings) + } + + override onTranslateEnd: TLOnTranslateEndHandler = (initial, sticker) => { + const pageAnchor = this.editor.getShapePageTransform(sticker).applyToPoint({ x: 0, y: 0 }) + const target = this.editor.getShapeAtPoint(pageAnchor, { + hitInside: true, + filter: (shape) => shape.id !== sticker.id, + }) + + if (!target) return + + const targetBounds = Box.ZeroFix(this.editor.getShapeGeometry(target)!.bounds) + const pointInTargetSpace = this.editor.getPointInShapeSpace(target, pageAnchor) + + const anchor = { + x: invLerp(targetBounds.minX, targetBounds.maxX, pointInTargetSpace.x), + y: invLerp(targetBounds.minY, targetBounds.maxY, pointInTargetSpace.y), + } + + this.editor.createBinding({ + type: 'sticker', + fromId: sticker.id, + toId: target.id, + props: { + anchor, + }, + }) + } +} + +type StickerBinding = TLBaseBinding< + 'sticker', + { + anchor: VecModel + } +> +class StickerBindingUtil extends BindingUtil { + static override type = 'sticker' as const + + override getDefaultProps() { + return { + anchor: { x: 0.5, y: 0.5 }, + } + } + + // when the shape we're stuck to changes, update the sticker's position + override onAfterChangeToShape({ + binding, + shapeAfter, + }: BindingOnShapeChangeOptions): void { + const sticker = this.editor.getShape(binding.fromId)! + + const shapeBounds = this.editor.getShapeGeometry(shapeAfter)!.bounds + const shapeAnchor = { + x: lerp(shapeBounds.minX, shapeBounds.maxX, binding.props.anchor.x), + y: lerp(shapeBounds.minY, shapeBounds.maxY, binding.props.anchor.y), + } + const pageAnchor = this.editor.getShapePageTransform(shapeAfter).applyToPoint(shapeAnchor) + + const stickerParentAnchor = this.editor + .getShapeParentTransform(sticker) + .invert() + .applyToPoint(pageAnchor) + + this.editor.updateShape({ + id: sticker.id, + type: 'sticker', + x: stickerParentAnchor.x, + y: stickerParentAnchor.y, + }) + } + + // when the thing we're stuck to is deleted, delete the sticker too + override onBeforeDeleteToShape({ binding }: BindingOnShapeDeleteOptions): void { + const sticker = this.editor.getShape(binding.fromId) + if (sticker) this.editor.deleteShape(sticker.id) + } +} + +class StickerTool extends StateNode { + static override id = 'sticker' + + override onEnter = () => { + this.editor.setCursor({ type: 'cross', rotation: 0 }) + } + + override onPointerDown: TLEventHandlers['onPointerDown'] = (info) => { + const { currentPagePoint } = this.editor.inputs + const stickerId = createShapeId() + this.editor.mark(`creating:${stickerId}`) + this.editor.createShape({ + id: stickerId, + type: 'sticker', + x: currentPagePoint.x, + y: currentPagePoint.y, + }) + this.editor.setSelectedShapes([stickerId]) + this.editor.setCurrentTool('select.translating', { + ...info, + target: 'shape', + shape: this.editor.getShape(stickerId), + isCreating: true, + onInteractionEnd: 'sticker', + onCreate: () => { + this.editor.setCurrentTool('sticker') + }, + }) + } +} + +const overrides: TLUiOverrides = { + tools(editor, schema) { + schema['sticker'] = { + id: 'sticker', + label: 'Sticker', + icon: 'heart-icon', + kbd: 'p', + onSelect: () => { + editor.setCurrentTool('sticker') + }, + } + return schema + }, +} + +const components: TLUiComponents = { + Toolbar: (...props) => { + const sticker = useTools().sticker + const isStickerSelected = useIsToolSelected(sticker) + return ( + + + + + ) + }, +} + +export default function StickerExample() { + return ( +
+ { + ;(window as any).editor = editor + }} + shapeUtils={[StickerShapeUtil]} + bindingUtils={[StickerBindingUtil]} + tools={[StickerTool]} + overrides={overrides} + components={components} + /> +
+ ) +} diff --git a/packages/editor/api-report.md b/packages/editor/api-report.md index 611ff56ef..e4262d7ab 100644 --- a/packages/editor/api-report.md +++ b/packages/editor/api-report.md @@ -30,22 +30,27 @@ import { default as React_2 } from 'react'; import * as React_3 from 'react'; import { ReactElement } from 'react'; import { ReactNode } from 'react'; +import { RecordProps } from '@tldraw/tlschema'; import { RecordsDiff } from '@tldraw/store'; import { SerializedSchema } from '@tldraw/store'; import { SerializedStore } from '@tldraw/store'; -import { ShapeProps } from '@tldraw/tlschema'; import { Signal } from '@tldraw/state'; import { Store } from '@tldraw/store'; import { StoreSchema } from '@tldraw/store'; import { StoreSnapshot } from '@tldraw/store'; import { StyleProp } from '@tldraw/tlschema'; import { StylePropValue } from '@tldraw/tlschema'; +import { TLArrowBinding } from '@tldraw/tlschema'; +import { TLArrowBindingProps } from '@tldraw/tlschema'; import { TLArrowShape } from '@tldraw/tlschema'; import { TLArrowShapeArrowheadStyle } from '@tldraw/tlschema'; import { TLAsset } from '@tldraw/tlschema'; import { TLAssetId } from '@tldraw/tlschema'; import { TLAssetPartial } from '@tldraw/tlschema'; import { TLBaseShape } from '@tldraw/tlschema'; +import { TLBinding } from '@tldraw/tlschema'; +import { TLBindingId } from '@tldraw/tlschema'; +import { TLBindingPartial } from '@tldraw/tlschema'; import { TLBookmarkAsset } from '@tldraw/tlschema'; import { TLCamera } from '@tldraw/tlschema'; import { TLCursor } from '@tldraw/tlschema'; @@ -61,14 +66,15 @@ import { TLInstancePresence } from '@tldraw/tlschema'; import { TLPage } from '@tldraw/tlschema'; import { TLPageId } from '@tldraw/tlschema'; import { TLParentId } from '@tldraw/tlschema'; +import { TLPropsMigrations } from '@tldraw/tlschema'; import { TLRecord } from '@tldraw/tlschema'; import { TLScribble } from '@tldraw/tlschema'; import { TLShape } from '@tldraw/tlschema'; import { TLShapeId } from '@tldraw/tlschema'; import { TLShapePartial } from '@tldraw/tlschema'; -import { TLShapePropsMigrations } from '@tldraw/tlschema'; import { TLStore } from '@tldraw/tlschema'; import { TLStoreProps } from '@tldraw/tlschema'; +import { TLUnknownBinding } from '@tldraw/tlschema'; import { TLUnknownShape } from '@tldraw/tlschema'; import { TLVideoAsset } from '@tldraw/tlschema'; import { track } from '@tldraw/state'; @@ -170,6 +176,77 @@ export abstract class BaseBoxShapeUtil extends Sha onResize: TLOnResizeHandler; } +// @public (undocumented) +export interface BindingOnChangeOptions { + // (undocumented) + bindingAfter: Binding; + // (undocumented) + bindingBefore: Binding; +} + +// @public (undocumented) +export interface BindingOnCreateOptions { + // (undocumented) + binding: Binding; +} + +// @public (undocumented) +export interface BindingOnDeleteOptions { + // (undocumented) + binding: Binding; +} + +// @public (undocumented) +export interface BindingOnShapeChangeOptions { + // (undocumented) + binding: Binding; + // (undocumented) + shapeAfter: TLShape; + // (undocumented) + shapeBefore: TLShape; +} + +// @public (undocumented) +export interface BindingOnShapeDeleteOptions { + // (undocumented) + binding: Binding; + // (undocumented) + shape: TLShape; +} + +// @public (undocumented) +export abstract class BindingUtil { + constructor(editor: Editor); + // (undocumented) + editor: Editor; + abstract getDefaultProps(): Partial; + // (undocumented) + static migrations?: TLPropsMigrations; + // (undocumented) + onAfterChange?(options: BindingOnChangeOptions): void; + // (undocumented) + onAfterChangeFromShape?(options: BindingOnShapeChangeOptions): void; + // (undocumented) + onAfterChangeToShape?(options: BindingOnShapeChangeOptions): void; + // (undocumented) + onAfterCreate?(options: BindingOnCreateOptions): void; + // (undocumented) + onAfterDelete?(options: BindingOnDeleteOptions): void; + // (undocumented) + onBeforeChange?(options: BindingOnChangeOptions): Binding | void; + // (undocumented) + onBeforeCreate?(options: BindingOnCreateOptions): Binding | void; + // (undocumented) + onBeforeDelete?(options: BindingOnDeleteOptions): void; + // (undocumented) + onBeforeDeleteFromShape?(options: BindingOnShapeDeleteOptions): void; + // (undocumented) + onBeforeDeleteToShape?(options: BindingOnShapeDeleteOptions): void; + // (undocumented) + static props?: RecordProps; + static type: string; +} + // @public export interface BoundsSnapGeometry { points?: VecModel[]; @@ -374,6 +451,9 @@ export const coreShapes: readonly [typeof GroupShapeUtil]; // @public export function counterClockwiseAngleDist(a0: number, a1: number): number; +// @internal +export function createOrUpdateArrowBinding(editor: Editor, arrow: TLArrowShape | TLShapeId, target: TLShape | TLShapeId, props: TLArrowBindingProps): void; + // @public export function createSessionStateSnapshotSignal(store: TLStore): Signal; @@ -590,7 +670,7 @@ export class Edge2d extends Geometry2d { // @public (undocumented) export class Editor extends EventEmitter { - constructor({ store, user, shapeUtils, tools, getContainer, cameraOptions, initialState, inferDarkMode, }: TLEditorOptions); + constructor({ store, user, shapeUtils, bindingUtils, tools, getContainer, cameraOptions, initialState, inferDarkMode, }: TLEditorOptions); addOpenMenu(id: string): this; alignShapes(shapes: TLShape[] | TLShapeId[], operation: 'bottom' | 'center-horizontal' | 'center-vertical' | 'left' | 'right' | 'top'): this; animateShape(partial: null | TLShapePartial | undefined, opts?: Partial<{ @@ -621,6 +701,9 @@ export class Editor extends EventEmitter { bail(): this; bailToMark(id: string): this; batch(fn: () => void, opts?: TLHistoryBatchOptions): this; + bindingUtils: { + readonly [K in string]?: BindingUtil; + }; bringForward(shapes: TLShape[] | TLShapeId[]): this; bringToFront(shapes: TLShape[] | TLShapeId[]): this; cancel(): this; @@ -635,6 +718,10 @@ export class Editor extends EventEmitter { // @internal (undocumented) crash(error: unknown): this; createAssets(assets: TLAsset[]): this; + // (undocumented) + createBinding(partial: RequiredKeys): this; + // (undocumented) + createBindings(partials: RequiredKeys[]): this; // @internal (undocumented) createErrorAnnotations(origin: string, willCrashApp: 'unknown' | boolean): { extras: { @@ -652,6 +739,10 @@ export class Editor extends EventEmitter { createShape(shape: OptionalKeys, 'id'>): this; createShapes(shapes: OptionalKeys, 'id'>[]): this; deleteAssets(assets: TLAsset[] | TLAssetId[]): this; + // (undocumented) + deleteBinding(binding: TLBinding | TLBindingId): this; + // (undocumented) + deleteBindings(bindings: (TLBinding | TLBindingId)[]): this; deleteOpenMenu(id: string): this; deletePage(page: TLPage | TLPageId): this; deleteShape(id: TLShapeId): this; @@ -687,16 +778,28 @@ export class Editor extends EventEmitter { findCommonAncestor(shapes: TLShape[] | TLShapeId[], predicate?: (shape: TLShape) => boolean): TLShapeId | undefined; findShapeAncestor(shape: TLShape | TLShapeId, predicate: (parent: TLShape) => boolean): TLShape | undefined; flipShapes(shapes: TLShape[] | TLShapeId[], operation: 'horizontal' | 'vertical'): this; + // (undocumented) + getAllBindingsFromShape(shape: TLShape | TLShapeId): TLBinding[]; + // (undocumented) + getAllBindingsToShape(shape: TLShape | TLShapeId): TLBinding[]; getAncestorPageId(shape?: TLShape | TLShapeId): TLPageId | undefined; getArrowInfo(shape: TLArrowShape | TLShapeId): TLArrowInfo | undefined; - getArrowsBoundTo(shapeId: TLShapeId): { - arrowId: TLShapeId; - handleId: "end" | "start"; - }[]; + getArrowsBoundTo(shapeId: TLShapeId): TLArrowShape[]; getAsset(asset: TLAsset | TLAssetId): TLAsset | undefined; getAssetForExternalContent(info: TLExternalAssetContent): Promise; getAssets(): (TLBookmarkAsset | TLImageAsset | TLVideoAsset)[]; getBaseZoom(): number; + // (undocumented) + getBinding(id: TLBindingId): TLBinding | undefined; + // (undocumented) + getBindingsFromShape(shape: TLShape | TLShapeId, type: Binding['type']): Binding[]; + // (undocumented) + getBindingsToShape(shape: TLShape | TLShapeId, type: Binding['type']): Binding[]; + getBindingUtil(binding: S | TLBindingPartial): BindingUtil; + // (undocumented) + getBindingUtil(type: S['type']): BindingUtil; + // (undocumented) + getBindingUtil(type: T extends BindingUtil ? R['type'] : string): T; getCamera(): TLCamera; getCameraOptions(): TLCameraOptions; getCameraState(): "idle" | "moving"; @@ -783,6 +886,8 @@ export class Editor extends EventEmitter { getShapeLocalTransform(shape: TLShape | TLShapeId): Mat; getShapeMask(shape: TLShape | TLShapeId): undefined | VecLike[]; getShapeMaskedPageBounds(shape: TLShape | TLShapeId): Box | undefined; + // @internal + getShapeNearestSibling(siblingShape: TLShape, targetShape: TLShape | undefined): TLShape | undefined; getShapePageBounds(shape: TLShape | TLShapeId): Box | undefined; getShapePageTransform(shape: TLShape | TLShapeId): Mat; getShapeParent(shape?: TLShape | TLShapeId): TLShape | undefined; @@ -944,6 +1049,10 @@ export class Editor extends EventEmitter { // (undocumented) ungroupShapes(ids: TLShape[]): this; updateAssets(assets: TLAssetPartial[]): this; + // (undocumented) + updateBinding(partial: TLBindingPartial): this; + // (undocumented) + updateBindings(partials: (null | TLBindingPartial | undefined)[]): this; updateCurrentPageState(partial: Partial>, historyOptions?: TLHistoryBatchOptions): this; // (undocumented) _updateCurrentPageState: (partial: Partial>, historyOptions?: TLHistoryBatchOptions) => void; @@ -1088,7 +1197,10 @@ export abstract class Geometry2d { export function getArcMeasure(A: number, B: number, sweepFlag: number, largeArcFlag: number): number; // @public (undocumented) -export function getArrowTerminalsInArrowSpace(editor: Editor, shape: TLArrowShape): { +export function getArrowBindings(editor: Editor, shape: TLArrowShape): TLArrowBindings; + +// @public (undocumented) +export function getArrowTerminalsInArrowSpace(editor: Editor, shape: TLArrowShape, bindings: TLArrowBindings): { end: Vec; start: Vec; }; @@ -1184,11 +1296,11 @@ export class GroupShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLGroupShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onChildrenChange: TLOnChildrenChangeHandler; // (undocumented) - static props: ShapeProps; + static props: RecordProps; // (undocumented) static type: "group"; } @@ -1602,6 +1714,9 @@ export function refreshPage(): void; // @public (undocumented) export function releasePointerCapture(element: Element, event: PointerEvent | React_2.PointerEvent): void; +// @internal +export function removeArrowBinding(editor: Editor, arrow: TLArrowShape, terminal: 'end' | 'start'): void; + // @public (undocumented) export type RequiredKeys = Partial> & Pick; @@ -1681,6 +1796,7 @@ export abstract class ShapeUtil { constructor(editor: Editor); // @internal backgroundComponent?(shape: Shape): any; + canBeLaidOut: TLShapeUtilFlag; canBind: (_shape: Shape, _otherShape?: K) => boolean; canCrop: TLShapeUtilFlag; canDropShapes(shape: Shape, shapes: TLShape[]): boolean; @@ -1708,7 +1824,7 @@ export abstract class ShapeUtil { abstract indicator(shape: Shape): any; isAspectRatioLocked: TLShapeUtilFlag; // (undocumented) - static migrations?: LegacyMigrations | TLShapePropsMigrations; + static migrations?: LegacyMigrations | MigrationSequence | TLPropsMigrations; onBeforeCreate?: TLOnBeforeCreateHandler; onBeforeUpdate?: TLOnBeforeUpdateHandler; // @internal @@ -1733,7 +1849,7 @@ export abstract class ShapeUtil { onTranslateEnd?: TLOnTranslateEndHandler; onTranslateStart?: TLOnTranslateStartHandler; // (undocumented) - static props?: ShapeProps; + static props?: RecordProps; // @internal providesBackgroundForChildren(shape: Shape): boolean; toBackgroundSvg?(shape: Shape, ctx: SvgExportContext): null | Promise | ReactElement; @@ -1974,6 +2090,9 @@ export type TLAfterCreateHandler = (record: R, source: 'remo // @public (undocumented) export type TLAfterDeleteHandler = (record: R, source: 'remote' | 'user') => void; +// @public (undocumented) +export type TLAnyBindingUtilConstructor = TLBindingUtilConstructor; + // @public (undocumented) export type TLAnyShapeUtilConstructor = TLShapeUtilConstructor; @@ -1993,8 +2112,17 @@ export interface TLArcInfo { sweepFlag: number; } +// @public (undocumented) +export interface TLArrowBindings { + // (undocumented) + end: TLArrowBinding | undefined; + // (undocumented) + start: TLArrowBinding | undefined; +} + // @public (undocumented) export type TLArrowInfo = { + bindings: TLArrowBindings; bodyArc: TLArcInfo; end: TLArrowPoint; handleArc: TLArcInfo; @@ -2003,6 +2131,7 @@ export type TLArrowInfo = { middle: VecLike; start: TLArrowPoint; } | { + bindings: TLArrowBindings; end: TLArrowPoint; isStraight: true; isValid: boolean; @@ -2048,6 +2177,18 @@ export type TLBeforeCreateHandler = (record: R, source: 'rem // @public (undocumented) export type TLBeforeDeleteHandler = (record: R, source: 'remote' | 'user') => false | void; +// @public (undocumented) +export interface TLBindingUtilConstructor = BindingUtil> { + // (undocumented) + new (editor: Editor): U; + // (undocumented) + migrations?: TLPropsMigrations; + // (undocumented) + props?: RecordProps; + // (undocumented) + type: T['type']; +} + // @public (undocumented) export type TLBrushProps = { brush: BoxModel; @@ -2136,6 +2277,8 @@ export interface TLContent { // (undocumented) assets: TLAsset[]; // (undocumented) + bindings: TLBinding[] | undefined; + // (undocumented) rootShapeIds: TLShapeId[]; // (undocumented) schema: SerializedSchema; @@ -2159,6 +2302,7 @@ export const TldrawEditor: React_2.NamedExoticComponent; // @public export interface TldrawEditorBaseProps { autoFocus?: boolean; + bindingUtils?: readonly TLAnyBindingUtilConstructor[]; cameraOptions?: Partial; children?: ReactNode; className?: string; @@ -2191,6 +2335,7 @@ export type TLEditorComponents = Partial<{ // @public (undocumented) export interface TLEditorOptions { + bindingUtils: readonly TLBindingUtilConstructor[]; cameraOptions?: Partial; getContainer: () => HTMLElement; inferDarkMode?: boolean; @@ -2618,9 +2763,9 @@ export interface TLShapeUtilConstructor; + props?: RecordProps; // (undocumented) type: T['type']; } @@ -2656,6 +2801,7 @@ export type TLStoreOptions = { id?: string; initialData?: SerializedStore; } & ({ + bindingUtils?: readonly TLAnyBindingUtilConstructor[]; migrations?: readonly MigrationSequence[]; shapeUtils?: readonly TLAnyShapeUtilConstructor[]; } | { diff --git a/packages/editor/src/index.ts b/packages/editor/src/index.ts index e6090e951..2d7e4f027 100644 --- a/packages/editor/src/index.ts +++ b/packages/editor/src/index.ts @@ -104,6 +104,7 @@ export { type TLStoreOptions, } from './lib/config/createTLStore' export { createTLUser } from './lib/config/createTLUser' +export { type TLAnyBindingUtilConstructor } from './lib/config/defaultBindings' export { coreShapes, type TLAnyShapeUtilConstructor } from './lib/config/defaultShapes' export { ANIMATION_MEDIUM_MS, @@ -122,6 +123,15 @@ export { SVG_PADDING, } from './lib/constants' export { Editor, type TLEditorOptions, type TLResizeShapeOptions } from './lib/editor/Editor' +export { + BindingUtil, + type BindingOnChangeOptions, + type BindingOnCreateOptions, + type BindingOnDeleteOptions, + type BindingOnShapeChangeOptions, + type BindingOnShapeDeleteOptions, + type TLBindingUtilConstructor, +} from './lib/editor/bindings/BindingUtil' export { HistoryManager } from './lib/editor/managers/HistoryManager' export type { SideEffectManager, @@ -178,7 +188,13 @@ export { type TLArrowInfo, type TLArrowPoint, } from './lib/editor/shapes/shared/arrow/arrow-types' -export { getArrowTerminalsInArrowSpace } from './lib/editor/shapes/shared/arrow/shared' +export { + createOrUpdateArrowBinding, + getArrowBindings, + getArrowTerminalsInArrowSpace, + removeArrowBinding, + type TLArrowBindings, +} from './lib/editor/shapes/shared/arrow/shared' export { resizeBox, type ResizeBoxOptions } from './lib/editor/shapes/shared/resizeBox' export { BaseBoxShapeTool } from './lib/editor/tools/BaseBoxShapeTool/BaseBoxShapeTool' export { StateNode, type TLStateNodeConstructor } from './lib/editor/tools/StateNode' diff --git a/packages/editor/src/lib/TldrawEditor.tsx b/packages/editor/src/lib/TldrawEditor.tsx index a099118a6..565a73607 100644 --- a/packages/editor/src/lib/TldrawEditor.tsx +++ b/packages/editor/src/lib/TldrawEditor.tsx @@ -15,6 +15,7 @@ import classNames from 'classnames' import { OptionalErrorBoundary } from './components/ErrorBoundary' import { DefaultErrorFallback } from './components/default-components/DefaultErrorFallback' import { TLUser, createTLUser } from './config/createTLUser' +import { TLAnyBindingUtilConstructor } from './config/defaultBindings' import { TLAnyShapeUtilConstructor } from './config/defaultShapes' import { Editor } from './editor/Editor' import { TLStateNodeConstructor } from './editor/tools/StateNode' @@ -76,6 +77,11 @@ export interface TldrawEditorBaseProps { */ shapeUtils?: readonly TLAnyShapeUtilConstructor[] + /** + * An array of binding utils to use in the editor. + */ + bindingUtils?: readonly TLAnyBindingUtilConstructor[] + /** * An array of tools to add to the editor's state chart. */ @@ -141,6 +147,7 @@ declare global { } const EMPTY_SHAPE_UTILS_ARRAY = [] as const +const EMPTY_BINDING_UTILS_ARRAY = [] as const const EMPTY_TOOLS_ARRAY = [] as const /** @public */ @@ -163,6 +170,7 @@ export const TldrawEditor = memo(function TldrawEditor({ const withDefaults = { ...rest, shapeUtils: rest.shapeUtils ?? EMPTY_SHAPE_UTILS_ARRAY, + bindingUtils: rest.bindingUtils ?? EMPTY_BINDING_UTILS_ARRAY, tools: rest.tools ?? EMPTY_TOOLS_ARRAY, components, } @@ -203,12 +211,25 @@ export const TldrawEditor = memo(function TldrawEditor({ }) function TldrawEditorWithOwnStore( - props: Required + props: Required< + TldrawEditorProps & { store: undefined; user: TLUser }, + 'shapeUtils' | 'bindingUtils' | 'tools' + > ) { - const { defaultName, snapshot, initialData, shapeUtils, persistenceKey, sessionId, user } = props + const { + defaultName, + snapshot, + initialData, + shapeUtils, + bindingUtils, + persistenceKey, + sessionId, + user, + } = props const syncedStore = useLocalStore({ shapeUtils, + bindingUtils, initialData, persistenceKey, sessionId, @@ -225,7 +246,7 @@ const TldrawEditorWithLoadingStore = memo(function TldrawEditorBeforeLoading({ ...rest }: Required< TldrawEditorProps & { store: TLStoreWithStatus; user: TLUser }, - 'shapeUtils' | 'tools' + 'shapeUtils' | 'bindingUtils' | 'tools' >) { const container = useContainer() @@ -268,6 +289,7 @@ function TldrawEditorWithReadyStore({ store, tools, shapeUtils, + bindingUtils, user, initialState, autoFocus = true, @@ -278,7 +300,7 @@ function TldrawEditorWithReadyStore({ store: TLStore user: TLUser }, - 'shapeUtils' | 'tools' + 'shapeUtils' | 'bindingUtils' | 'tools' >) { const { ErrorFallback } = useEditorComponents() const container = useContainer() @@ -288,6 +310,7 @@ function TldrawEditorWithReadyStore({ const editor = new Editor({ store, shapeUtils, + bindingUtils, tools, getContainer: () => container, user, @@ -300,7 +323,17 @@ function TldrawEditorWithReadyStore({ return () => { editor.dispose() } - }, [container, shapeUtils, tools, store, user, initialState, inferDarkMode, cameraOptions]) + }, [ + container, + shapeUtils, + bindingUtils, + tools, + store, + user, + initialState, + inferDarkMode, + cameraOptions, + ]) const crashingError = useSyncExternalStore( useCallback( diff --git a/packages/editor/src/lib/config/createTLStore.ts b/packages/editor/src/lib/config/createTLStore.ts index 3361a18a5..f71511e35 100644 --- a/packages/editor/src/lib/config/createTLStore.ts +++ b/packages/editor/src/lib/config/createTLStore.ts @@ -1,13 +1,6 @@ import { HistoryEntry, MigrationSequence, SerializedStore, Store, StoreSchema } from '@tldraw/store' -import { - SchemaShapeInfo, - TLRecord, - TLStore, - TLStoreProps, - TLUnknownShape, - createTLSchema, -} from '@tldraw/tlschema' -import { TLShapeUtilConstructor } from '../editor/shapes/ShapeUtil' +import { SchemaPropsInfo, TLRecord, TLStore, TLStoreProps, createTLSchema } from '@tldraw/tlschema' +import { TLAnyBindingUtilConstructor, checkBindings } from './defaultBindings' import { TLAnyShapeUtilConstructor, checkShapesAndAddCore } from './defaultShapes' /** @public */ @@ -16,7 +9,11 @@ export type TLStoreOptions = { defaultName?: string id?: string } & ( - | { shapeUtils?: readonly TLAnyShapeUtilConstructor[]; migrations?: readonly MigrationSequence[] } + | { + shapeUtils?: readonly TLAnyShapeUtilConstructor[] + migrations?: readonly MigrationSequence[] + bindingUtils?: readonly TLAnyBindingUtilConstructor[] + } | { schema?: StoreSchema } ) @@ -41,9 +38,12 @@ export function createTLStore({ rest.schema : // we need a schema createTLSchema({ - shapes: currentPageShapesToShapeMap( + shapes: utilsToMap( checkShapesAndAddCore('shapeUtils' in rest && rest.shapeUtils ? rest.shapeUtils : []) ), + bindings: utilsToMap( + checkBindings('bindingUtils' in rest && rest.bindingUtils ? rest.bindingUtils : []) + ), migrations: 'migrations' in rest ? rest.migrations : [], }) @@ -57,9 +57,9 @@ export function createTLStore({ }) } -function currentPageShapesToShapeMap(shapeUtils: TLShapeUtilConstructor[]) { +function utilsToMap(utils: T[]) { return Object.fromEntries( - shapeUtils.map((s): [string, SchemaShapeInfo] => [ + utils.map((s): [string, SchemaPropsInfo] => [ s.type, { props: s.props, diff --git a/packages/editor/src/lib/config/defaultBindings.ts b/packages/editor/src/lib/config/defaultBindings.ts new file mode 100644 index 000000000..19d2f7e09 --- /dev/null +++ b/packages/editor/src/lib/config/defaultBindings.ts @@ -0,0 +1,19 @@ +import { TLBindingUtilConstructor } from '../editor/bindings/BindingUtil' + +/** @public */ +export type TLAnyBindingUtilConstructor = TLBindingUtilConstructor + +export function checkBindings(customBindings: readonly TLAnyBindingUtilConstructor[]) { + const bindings = [] as TLAnyBindingUtilConstructor[] + + const addedCustomBindingTypes = new Set() + for (const customBinding of customBindings) { + if (addedCustomBindingTypes.has(customBinding.type)) { + throw new Error(`Binding type "${customBinding.type}" is defined more than once`) + } + bindings.push(customBinding) + addedCustomBindingTypes.add(customBinding.type) + } + + return bindings +} diff --git a/packages/editor/src/lib/editor/Editor.ts b/packages/editor/src/lib/editor/Editor.ts index 949ffecaf..2ff5e8925 100644 --- a/packages/editor/src/lib/editor/Editor.ts +++ b/packages/editor/src/lib/editor/Editor.ts @@ -1,15 +1,25 @@ import { EMPTY_ARRAY, atom, computed, transact } from '@tldraw/state' -import { ComputedCache, RecordType, StoreSnapshot } from '@tldraw/store' +import { + ComputedCache, + RecordType, + StoreSnapshot, + UnknownRecord, + reverseRecordsDiff, +} from '@tldraw/store' import { CameraRecordType, InstancePageStateRecordType, PageRecordType, StyleProp, StylePropValue, + TLArrowBinding, TLArrowShape, TLAsset, TLAssetId, TLAssetPartial, + TLBinding, + TLBindingId, + TLBindingPartial, TLCursor, TLCursorType, TLDOCUMENT_ID, @@ -31,8 +41,10 @@ import { TLShapeId, TLShapePartial, TLStore, + TLUnknownBinding, TLUnknownShape, TLVideoAsset, + createBindingId, createShapeId, getShapePropKeysByStyle, isPageId, @@ -42,8 +54,10 @@ import { IndexKey, JsonObject, PerformanceTracker, + Result, annotateError, assert, + assertExists, compact, dedupe, exhaustiveSwitchError, @@ -63,6 +77,7 @@ import { EventEmitter } from 'eventemitter3' import { flushSync } from 'react-dom' import { createRoot } from 'react-dom/client' import { TLUser, createTLUser } from '../config/createTLUser' +import { checkBindings } from '../config/defaultBindings' import { checkShapesAndAddCore } from '../config/defaultShapes' import { ANIMATION_MEDIUM_MS, @@ -89,7 +104,7 @@ import { STYLUS_ERASER_BUTTON, } from '../constants' import { Box, BoxLike } from '../primitives/Box' -import { Mat, MatLike, MatModel } from '../primitives/Mat' +import { Mat, MatLike } from '../primitives/Mat' import { Vec, VecLike } from '../primitives/Vec' import { EASINGS } from '../primitives/easings' import { Geometry2d } from '../primitives/geometry/Geometry2d' @@ -104,7 +119,7 @@ import { getIncrementedName } from '../utils/getIncrementedName' import { getReorderingShapesChanges } from '../utils/reorderShapes' import { applyRotationToSnapshotShapes, getRotationSnapshot } from '../utils/rotation' import { uniqueId } from '../utils/uniqueId' -import { arrowBindingsIndex } from './derivations/arrowBindingsIndex' +import { BindingUtil, TLBindingUtilConstructor } from './bindings/BindingUtil' import { notVisibleShapes } from './derivations/notVisibleShapes' import { parentsToChildren } from './derivations/parentsToChildren' import { deriveShapeIdsInCurrentPage } from './derivations/shapeIdsInCurrentPage' @@ -121,7 +136,7 @@ import { UserPreferencesManager } from './managers/UserPreferencesManager' import { ShapeUtil, TLResizeMode, TLShapeUtilConstructor } from './shapes/ShapeUtil' import { TLArrowInfo } from './shapes/shared/arrow/arrow-types' import { getCurvedArrowInfo } from './shapes/shared/arrow/curved-arrow' -import { getArrowTerminalsInArrowSpace, getIsArrowStraight } from './shapes/shared/arrow/shared' +import { getArrowBindings, getIsArrowStraight } from './shapes/shared/arrow/shared' import { getStraightArrowInfo } from './shapes/shared/arrow/straight-arrow' import { RootState } from './tools/RootState' import { StateNode, TLStateNodeConstructor } from './tools/StateNode' @@ -167,6 +182,10 @@ export interface TLEditorOptions { * An array of shapes to use in the editor. These will be used to create and manage shapes in the editor. */ shapeUtils: readonly TLShapeUtilConstructor[] + /** + * An array of bindings to use in the editor. These will be used to create and manage bindings in the editor. + */ + bindingUtils: readonly TLBindingUtilConstructor[] /** * An array of tools to use in the editor. These will be used to handle events and manage user interactions in the editor. */ @@ -200,6 +219,7 @@ export class Editor extends EventEmitter { store, user, shapeUtils, + bindingUtils, tools, getContainer, cameraOptions, @@ -262,6 +282,14 @@ export class Editor extends EventEmitter { this.shapeUtils = _shapeUtils this.styleProps = _styleProps + const allBindingUtils = checkBindings(bindingUtils) + const _bindingUtils = {} as Record> + for (const Util of allBindingUtils) { + const util = new Util(this) + _bindingUtils[Util.type] = util + } + this.bindingUtils = _bindingUtils + // Tools. // Accept tools from constructor parameters which may not conflict with the root note's default or // "baked in" tools, select and zoom. @@ -279,119 +307,6 @@ export class Editor extends EventEmitter { const invalidParents = new Set() - const reparentArrow = (arrowId: TLArrowShape['id']) => { - const arrow = this.getShape(arrowId) - if (!arrow) return - const { start, end } = arrow.props - const startShape = start.type === 'binding' ? this.getShape(start.boundShapeId) : undefined - const endShape = end.type === 'binding' ? this.getShape(end.boundShapeId) : undefined - - const parentPageId = this.getAncestorPageId(arrow) - if (!parentPageId) return - - let nextParentId: TLParentId - if (startShape && endShape) { - // if arrow has two bindings, always parent arrow to closest common ancestor of the bindings - nextParentId = this.findCommonAncestor([startShape, endShape]) ?? parentPageId - } else if (startShape || endShape) { - const bindingParentId = (startShape || endShape)?.parentId - // If the arrow and the shape that it is bound to have the same parent, then keep that parent - if (bindingParentId && bindingParentId === arrow.parentId) { - nextParentId = arrow.parentId - } else { - // if arrow has one binding, keep arrow on its own page - nextParentId = parentPageId - } - } else { - return - } - - if (nextParentId && nextParentId !== arrow.parentId) { - this.reparentShapes([arrowId], nextParentId) - } - - const reparentedArrow = this.getShape(arrowId) - if (!reparentedArrow) throw Error('no reparented arrow') - - const startSibling = this.getShapeNearestSibling(reparentedArrow, startShape) - const endSibling = this.getShapeNearestSibling(reparentedArrow, endShape) - - let highestSibling: TLShape | undefined - - if (startSibling && endSibling) { - highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling - } else if (startSibling && !endSibling) { - highestSibling = startSibling - } else if (endSibling && !startSibling) { - highestSibling = endSibling - } else { - return - } - - let finalIndex: IndexKey - - const higherSiblings = this.getSortedChildIdsForParent(highestSibling.parentId) - .map((id) => this.getShape(id)!) - .filter((sibling) => sibling.index > highestSibling!.index) - - if (higherSiblings.length) { - // there are siblings above the highest bound sibling, we need to - // insert between them. - - // if the next sibling is also a bound arrow though, we can end up - // all fighting for the same indexes. so lets find the next - // non-arrow sibling... - const nextHighestNonArrowSibling = higherSiblings.find( - (sibling) => sibling.type !== 'arrow' - ) - - if ( - // ...then, if we're above the last shape we want to be above... - reparentedArrow.index > highestSibling.index && - // ...but below the next non-arrow sibling... - (!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index) - ) { - // ...then we're already in the right place. no need to update! - return - } - - // otherwise, we need to find the index between the highest sibling - // we want to be above, and the next highest sibling we want to be - // below: - finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index) - } else { - // if there are no siblings above us, we can just get the next index: - finalIndex = getIndexAbove(highestSibling.index) - } - - if (finalIndex !== reparentedArrow.index) { - this.updateShapes([{ id: arrowId, type: 'arrow', index: finalIndex }]) - } - } - - const unbindArrowTerminal = (arrow: TLArrowShape, handleId: 'start' | 'end') => { - const { x, y } = getArrowTerminalsInArrowSpace(this, arrow)[handleId] - this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }]) - } - - const arrowDidUpdate = (arrow: TLArrowShape) => { - // if the shape is an arrow and its bound shape is on another page - // or was deleted, unbind it - for (const handle of ['start', 'end'] as const) { - const terminal = arrow.props[handle] - if (terminal.type !== 'binding') continue - const boundShape = this.getShape(terminal.boundShapeId) - const isShapeInSamePageAsArrow = - this.getAncestorPageId(arrow) === this.getAncestorPageId(boundShape) - if (!boundShape || !isShapeInSamePageAsArrow) { - unbindArrowTerminal(arrow, handle) - } - } - - // always check the arrow parents - reparentArrow(arrow.id) - } - const cleanupInstancePageState = ( prevPageState: TLInstancePageState, shapesNoLongerInPage: Set @@ -463,39 +378,56 @@ export class Editor extends EventEmitter { this.disposables.add( this.sideEffects.register({ shape: { - afterCreate: (record) => { - if (this.isShapeOfType(record, 'arrow')) { - arrowDidUpdate(record) + afterChange: (shapeBefore, shapeAfter) => { + for (const binding of this.getAllBindingsFromShape(shapeAfter)) { + this.getBindingUtil(binding).onAfterChangeFromShape?.({ + binding, + shapeBefore, + shapeAfter, + }) } - }, - afterChange: (prev, next) => { - if (this.isShapeOfType(next, 'arrow')) { - arrowDidUpdate(next) + for (const binding of this.getAllBindingsToShape(shapeAfter)) { + this.getBindingUtil(binding).onAfterChangeToShape?.({ + binding, + shapeBefore, + shapeAfter, + }) } - // if the shape's parent changed and it is bound to an arrow, update the arrow's parent - if (prev.parentId !== next.parentId) { - const reparentBoundArrows = (id: TLShapeId) => { - const boundArrows = this._getArrowBindingsIndex().get()[id] - if (boundArrows?.length) { - for (const arrow of boundArrows) { - reparentArrow(arrow.arrowId) - } + // if the shape's parent changed and it has a binding, update the binding + if (shapeBefore.parentId !== shapeAfter.parentId) { + const notifyBindingAncestryChange = (id: TLShapeId) => { + const descendantShape = this.getShape(id) + if (!descendantShape) return + + for (const binding of this.getAllBindingsFromShape(descendantShape)) { + this.getBindingUtil(binding).onAfterChangeFromShape?.({ + binding, + shapeBefore: descendantShape, + shapeAfter: descendantShape, + }) + } + for (const binding of this.getAllBindingsToShape(descendantShape)) { + this.getBindingUtil(binding).onAfterChangeToShape?.({ + binding, + shapeBefore: descendantShape, + shapeAfter: descendantShape, + }) } } - reparentBoundArrows(next.id) - this.visitDescendants(next.id, reparentBoundArrows) + notifyBindingAncestryChange(shapeAfter.id) + this.visitDescendants(shapeAfter.id, notifyBindingAncestryChange) } // if this shape moved to a new page, clean up any previous page's instance state - if (prev.parentId !== next.parentId && isPageId(next.parentId)) { - const allMovingIds = new Set([prev.id]) - this.visitDescendants(prev.id, (id) => { + if (shapeBefore.parentId !== shapeAfter.parentId && isPageId(shapeAfter.parentId)) { + const allMovingIds = new Set([shapeBefore.id]) + this.visitDescendants(shapeBefore.id, (id) => { allMovingIds.add(id) }) for (const instancePageState of this.getPageStates()) { - if (instancePageState.pageId === next.parentId) continue + if (instancePageState.pageId === shapeAfter.parentId) continue const nextPageState = cleanupInstancePageState(instancePageState, allMovingIds) if (nextPageState) { @@ -504,29 +436,32 @@ export class Editor extends EventEmitter { } } - if (prev.parentId && isShapeId(prev.parentId)) { - invalidParents.add(prev.parentId) + if (shapeBefore.parentId && isShapeId(shapeBefore.parentId)) { + invalidParents.add(shapeBefore.parentId) } - if (next.parentId !== prev.parentId && isShapeId(next.parentId)) { - invalidParents.add(next.parentId) + if (shapeAfter.parentId !== shapeBefore.parentId && isShapeId(shapeAfter.parentId)) { + invalidParents.add(shapeAfter.parentId) } }, - beforeDelete: (record) => { + beforeDelete: (shape) => { // if the deleted shape has a parent shape make sure we call it's onChildrenChange callback - if (record.parentId && isShapeId(record.parentId)) { - invalidParents.add(record.parentId) + if (shape.parentId && isShapeId(shape.parentId)) { + invalidParents.add(shape.parentId) } - // clean up any arrows bound to this shape - const bindings = this._getArrowBindingsIndex().get()[record.id] - if (bindings?.length) { - for (const { arrowId, handleId } of bindings) { - const arrow = this.getShape(arrowId) - if (!arrow) continue - unbindArrowTerminal(arrow, handleId) - } + + const deleteBindingIds: TLBindingId[] = [] + for (const binding of this.getAllBindingsFromShape(shape)) { + this.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape }) + deleteBindingIds.push(binding.id) } - const deletedIds = new Set([record.id]) + for (const binding of this.getAllBindingsToShape(shape)) { + this.getBindingUtil(binding).onBeforeDeleteToShape?.({ binding, shape }) + deleteBindingIds.push(binding.id) + } + this.deleteBindings(deleteBindingIds) + + const deletedIds = new Set([shape.id]) const updates = compact( this.getPageStates().map((pageState) => { return cleanupInstancePageState(pageState, deletedIds) @@ -538,6 +473,33 @@ export class Editor extends EventEmitter { } }, }, + binding: { + beforeCreate: (binding) => { + const next = this.getBindingUtil(binding).onBeforeCreate?.({ binding }) + if (next) return next + return binding + }, + afterCreate: (binding) => { + this.getBindingUtil(binding).onAfterCreate?.({ binding }) + }, + beforeChange: (bindingBefore, bindingAfter) => { + const updated = this.getBindingUtil(bindingAfter).onBeforeChange?.({ + bindingBefore, + bindingAfter, + }) + if (updated) return updated + return bindingAfter + }, + afterChange: (bindingBefore, bindingAfter) => { + this.getBindingUtil(bindingAfter).onAfterChange?.({ bindingBefore, bindingAfter }) + }, + beforeDelete: (binding) => { + this.getBindingUtil(binding).onBeforeDelete?.({ binding }) + }, + afterDelete: (binding) => { + this.getBindingUtil(binding).onAfterDelete?.({ binding }) + }, + }, page: { afterCreate: (record) => { const cameraId = CameraRecordType.createId(record.id) @@ -806,6 +768,41 @@ export class Editor extends EventEmitter { return shapeUtil } + /* ------------------- Binding Utils ------------------ */ + /** + * A map of shape utility classes (TLShapeUtils) by shape type. + * + * @public + */ + bindingUtils: { readonly [K in string]?: BindingUtil } + + /** + * Get a binding util from a binding itself. + * + * @example + * ```ts + * const util = editor.getBindingUtil(myArrowBinding) + * const util = editor.getBindingUtil('arrow') + * const util = editor.getBindingUtil(myArrowBinding) + * const util = editor.getBindingUtil(TLArrowBinding)('arrow') + * ``` + * + * @param binding - A binding, binding partial, or binding type. + * + * @public + */ + getBindingUtil(binding: S | TLBindingPartial): BindingUtil + getBindingUtil(type: S['type']): BindingUtil + getBindingUtil( + type: T extends BindingUtil ? R['type'] : string + ): T + getBindingUtil(arg: string | { type: string }) { + const type = typeof arg === 'string' ? arg : arg.type + const bindingUtil = getOwnProperty(this.bindingUtils, type) + assert(bindingUtil, `No binding util found for type "${type}"`) + return bindingUtil + } + /* --------------------- History -------------------- */ /** @@ -927,12 +924,6 @@ export class Editor extends EventEmitter { /* --------------------- Arrows --------------------- */ // todo: move these to tldraw or replace with a bindings API - /** @internal */ - @computed - private _getArrowBindingsIndex() { - return arrowBindingsIndex(this) - } - /** * Get all arrows bound to a shape. * @@ -941,15 +932,19 @@ export class Editor extends EventEmitter { * @public */ getArrowsBoundTo(shapeId: TLShapeId) { - return this._getArrowBindingsIndex().get()[shapeId] || EMPTY_ARRAY + const ids = new Set( + this.getBindingsToShape(shapeId, 'arrow').map((b) => b.fromId) + ) + return compact(Array.from(ids, (id) => this.getShape(id))) } @computed private getArrowInfoCache() { return this.store.createComputedCache('arrow infoCache', (shape) => { + const bindings = getArrowBindings(this, shape) return getIsArrowStraight(shape) - ? getStraightArrowInfo(this, shape) - : getCurvedArrowInfo(this, shape) + ? getStraightArrowInfo(this, shape, bindings) + : getCurvedArrowInfo(this, shape, bindings) }) } @@ -4667,7 +4662,7 @@ export class Editor extends EventEmitter { * * @internal */ - private getShapeNearestSibling( + getShapeNearestSibling( siblingShape: TLShape, targetShape: TLShape | undefined ): TLShape | undefined { @@ -4935,7 +4930,7 @@ export class Editor extends EventEmitter { } /** - * Get the shape ids of all descendants of the given shapes (including the shapes themselves). + * Get the shape ids of all descendants of the given shapes (including the shapes themselves). IDs are returned in z-index order. * * @param ids - The ids of the shapes to get descendants of. * @@ -4944,21 +4939,14 @@ export class Editor extends EventEmitter { * @public */ getShapeAndDescendantIds(ids: TLShapeId[]): Set { - const idsToInclude = new Set() - - const idsToCheck = [...ids] - - while (idsToCheck.length > 0) { - const id = idsToCheck.pop() - if (!id) break - if (idsToInclude.has(id)) continue - idsToInclude.add(id) - for (const childId of this.getSortedChildIdsForParent(id)) { - idsToCheck.push(childId) - } + const shapeIds = new Set() + for (const shape of ids.map((id) => this.getShape(id)!).sort(sortByIndex)) { + shapeIds.add(shape.id) + this.visitDescendants(shape, (descendantId) => { + shapeIds.add(descendantId) + }) } - - return idsToInclude + return shapeIds } /** @@ -5042,6 +5030,99 @@ export class Editor extends EventEmitter { return match } + /* -------------------- Bindings -------------------- */ + + getBinding(id: TLBindingId): TLBinding | undefined { + return this.store.get(id) as TLBinding | undefined + } + + // TODO(alex) #bindings - cache `allBindings` getters and derive type-specific ones from them + getBindingsFromShape( + shape: TLShape | TLShapeId, + type: Binding['type'] + ): Binding[] { + const id = typeof shape === 'string' ? shape : shape.id + return this.store.query.exec('binding', { + fromId: { eq: id }, + type: { eq: type }, + }) as Binding[] + } + getBindingsToShape( + shape: TLShape | TLShapeId, + type: Binding['type'] + ): Binding[] { + const id = typeof shape === 'string' ? shape : shape.id + return this.store.query.exec('binding', { + toId: { eq: id }, + type: { eq: type }, + }) as Binding[] + } + getAllBindingsFromShape(shape: TLShape | TLShapeId): TLBinding[] { + const id = typeof shape === 'string' ? shape : shape.id + return this.store.query.exec('binding', { + fromId: { eq: id }, + }) + } + getAllBindingsToShape(shape: TLShape | TLShapeId): TLBinding[] { + const id = typeof shape === 'string' ? shape : shape.id + return this.store.query.exec('binding', { + toId: { eq: id }, + }) + } + + createBindings(partials: RequiredKeys[]) { + const bindings = partials.map((partial) => { + const util = this.getBindingUtil(partial.type) + const defaultProps = util.getDefaultProps() + return this.store.schema.types.binding.create({ + ...partial, + id: partial.id ?? createBindingId(), + props: { + ...defaultProps, + ...partial.props, + }, + }) + }) + this.store.put(bindings) + return this + } + createBinding(partial: RequiredKeys) { + return this.createBindings([partial]) + } + + updateBindings(partials: (TLBindingPartial | null | undefined)[]) { + const updated: TLBinding[] = [] + + for (const partial of partials) { + if (!partial) continue + + const current = this.getBinding(partial.id) + if (!current) continue + + const updatedBinding = applyPartialToRecordWithProps(current, partial) + if (updatedBinding === current) continue + + updated.push(updatedBinding) + } + + this.store.put(updated) + + return this + } + + updateBinding(partial: TLBindingPartial) { + return this.updateBindings([partial]) + } + + deleteBindings(bindings: (TLBinding | TLBindingId)[]) { + const ids = bindings.map((binding) => (typeof binding === 'string' ? binding : binding.id)) + this.store.remove(ids) + return this + } + deleteBinding(binding: TLBinding | TLBindingId) { + return this.deleteBindings([binding]) + } + /* -------------------- Commands -------------------- */ /** @@ -5076,24 +5157,24 @@ export class Editor extends EventEmitter { let workingShape = initialShape const util = this.getShapeUtil(initialShape) - workingShape = applyPartialToShape( + workingShape = applyPartialToRecordWithProps( workingShape, util.onTranslateStart?.(workingShape) ?? undefined ) - workingShape = applyPartialToShape(workingShape, { + workingShape = applyPartialToRecordWithProps(workingShape, { id: initialShape.id, type: initialShape.type, x: newShapeCoords.x, y: newShapeCoords.y, }) - workingShape = applyPartialToShape( + workingShape = applyPartialToRecordWithProps( workingShape, util.onTranslate?.(initialShape, workingShape) ?? undefined ) - workingShape = applyPartialToShape( + workingShape = applyPartialToRecordWithProps( workingShape, util.onTranslateEnd?.(initialShape, workingShape) ?? undefined ) @@ -5151,162 +5232,92 @@ export class Editor extends EventEmitter { * @public */ duplicateShapes(shapes: TLShapeId[] | TLShape[], offset?: VecLike): this { - const ids = - typeof shapes[0] === 'string' - ? (shapes as TLShapeId[]) - : (shapes as TLShape[]).map((s) => s.id) - - if (ids.length <= 0) return this - - const maxShapesReached = - shapes.length + this.getCurrentPageShapeIds().size > MAX_SHAPES_PER_PAGE - - if (maxShapesReached) { - alertMaxShapes(this) - return this - } - - const initialIds = new Set(ids) - const idsToCreate: TLShapeId[] = [] - const idsToCheck = [...ids] - - while (idsToCheck.length > 0) { - const id = idsToCheck.pop() - if (!id) break - idsToCreate.push(id) - this.getSortedChildIdsForParent(id).forEach((childId) => idsToCheck.push(childId)) - } - - idsToCreate.reverse() - - const idsMap = new Map(idsToCreate.map((id) => [id, createShapeId()])) - - const shapesToCreate = compact( - idsToCreate.map((id) => { - const shape = this.getShape(id) - - if (!shape) { - return null - } - - const createId = idsMap.get(id)! - - let ox = 0 - let oy = 0 - - if (offset && initialIds.has(id)) { - const parentTransform = this.getShapeParentTransform(shape) - const vec = new Vec(offset.x, offset.y).rot(-parentTransform!.rotation()) - ox = vec.x - oy = vec.y - } - - const parentId = shape.parentId ?? this.getCurrentPageId() - const siblings = this.getSortedChildIdsForParent(parentId) - const currentIndex = siblings.indexOf(shape.id) - const siblingAboveId = siblings[currentIndex + 1] - const siblingAbove = siblingAboveId ? this.getShape(siblingAboveId) : null - - const index = siblingAbove - ? getIndexBetween(shape.index, siblingAbove.index) - : getIndexAbove(shape.index) - - let newShape: TLShape = structuredClone(shape) - - if ( - this.isShapeOfType(shape, 'arrow') && - this.isShapeOfType(newShape, 'arrow') - ) { - const info = this.getArrowInfo(shape) - let newStartShapeId: TLShapeId | undefined = undefined - let newEndShapeId: TLShapeId | undefined = undefined - - if (shape.props.start.type === 'binding') { - newStartShapeId = idsMap.get(shape.props.start.boundShapeId) - - if (!newStartShapeId) { - if (info?.isValid) { - const { x, y } = info.start.point - newShape.props.start = { - type: 'point', - x, - y, - } - } else { - const { start } = getArrowTerminalsInArrowSpace(this, shape) - newShape.props.start = { - type: 'point', - x: start.x, - y: start.y, - } - } - } - } - - if (shape.props.end.type === 'binding') { - newEndShapeId = idsMap.get(shape.props.end.boundShapeId) - if (!newEndShapeId) { - if (info?.isValid) { - const { x, y } = info.end.point - newShape.props.end = { - type: 'point', - x, - y, - } - } else { - const { end } = getArrowTerminalsInArrowSpace(this, shape) - newShape.props.start = { - type: 'point', - x: end.x, - y: end.y, - } - } - } - } - - const infoAfter = getIsArrowStraight(newShape) - ? getStraightArrowInfo(this, newShape) - : getCurvedArrowInfo(this, newShape) - - if (info?.isValid && infoAfter?.isValid && !getIsArrowStraight(shape)) { - const mpA = Vec.Med(info.start.handle, info.end.handle) - const distA = Vec.Dist(info.middle, mpA) - const distB = Vec.Dist(infoAfter.middle, mpA) - if (newShape.props.bend < 0) { - newShape.props.bend += distB - distA - } else { - newShape.props.bend -= distB - distA - } - } - - if (newShape.props.start.type === 'binding' && newStartShapeId) { - newShape.props.start.boundShapeId = newStartShapeId - } - - if (newShape.props.end.type === 'binding' && newEndShapeId) { - newShape.props.end.boundShapeId = newEndShapeId - } - } - - newShape = { ...newShape, id: createId, x: shape.x + ox, y: shape.y + oy, index } - - return newShape - }) - ) - - shapesToCreate.forEach((shape) => { - if (isShapeId(shape.parentId)) { - if (idsMap.has(shape.parentId)) { - shape.parentId = idsMap.get(shape.parentId)! - } - } - }) - this.history.batch(() => { - const ids = shapesToCreate.map((s) => s.id) + const ids = + typeof shapes[0] === 'string' + ? (shapes as TLShapeId[]) + : (shapes as TLShape[]).map((s) => s.id) + + if (ids.length <= 0) return this + + const initialIds = new Set(ids) + const shapeIdSet = this.getShapeAndDescendantIds(ids) + + const orderedShapeIds = [...shapeIdSet].reverse() + const shapeIds = new Map() + for (const shapeId of shapeIdSet) { + shapeIds.set(shapeId, createShapeId()) + } + + const { shapesToCreate, bindingsToCreate } = withoutBindingsToUnrelatedShapes( + this, + shapeIdSet, + (bindingIdsToMaintain) => { + const bindingsToCreate: TLBinding[] = [] + for (const originalId of bindingIdsToMaintain) { + const originalBinding = this.getBinding(originalId) + if (!originalBinding) continue + + const duplicatedId = createBindingId() + bindingsToCreate.push({ + ...originalBinding, + id: duplicatedId, + fromId: assertExists(shapeIds.get(originalBinding.fromId)), + toId: assertExists(shapeIds.get(originalBinding.toId)), + }) + } + + const shapesToCreate: TLShape[] = [] + for (const originalId of orderedShapeIds) { + const duplicatedId = assertExists(shapeIds.get(originalId)) + const originalShape = this.getShape(originalId) + if (!originalShape) continue + + let ox = 0 + let oy = 0 + + if (offset && initialIds.has(originalId)) { + const parentTransform = this.getShapeParentTransform(originalShape) + const vec = new Vec(offset.x, offset.y).rot(-parentTransform!.rotation()) + ox = vec.x + oy = vec.y + } + + const parentId = originalShape.parentId + const siblings = this.getSortedChildIdsForParent(parentId) + const currentIndex = siblings.indexOf(originalShape.id) + const siblingAboveId = siblings[currentIndex + 1] + const siblingAbove = siblingAboveId ? this.getShape(siblingAboveId) : null + + const index = siblingAbove + ? getIndexBetween(originalShape.index, siblingAbove.index) + : getIndexAbove(originalShape.index) + + shapesToCreate.push({ + ...originalShape, + id: duplicatedId, + x: originalShape.x + ox, + y: originalShape.y + oy, + index, + parentId: shapeIds.get(originalShape.parentId as TLShapeId) ?? originalShape.parentId, + }) + } + + return { shapesToCreate, bindingsToCreate } + } + ) + + const maxShapesReached = + shapesToCreate.length + this.getCurrentPageShapeIds().size > MAX_SHAPES_PER_PAGE + + if (maxShapesReached) { + alertMaxShapes(this) + return + } this.createShapes(shapesToCreate) - this.setSelectedShapes(ids) + this.createBindings(bindingsToCreate) + this.setSelectedShapes(compact(ids.map((id) => shapeIds.get(id)))) if (offset !== undefined) { // If we've offset the duplicated shapes, check to see whether their new bounds is entirely @@ -5628,21 +5639,13 @@ export class Editor extends EventEmitter { : (shapes as TLShape[]).map((s) => s.id) if (this.getInstanceState().isReadonly) return this - const shapesToStack = compact( - ids - .map((id) => this.getShape(id)) // always fresh shapes - .filter((shape) => { - if (!shape) return false + const shapesToStack = ids + .map((id) => this.getShape(id)) // always fresh shapes + .filter((shape): shape is TLShape => { + if (!shape) return false - if (this.isShapeOfType(shape, 'arrow')) { - if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') { - return false - } - } - - return true - }) - ) + return this.getShapeUtil(shape).canBeLaidOut(shape) + }) const len = shapesToStack.length @@ -5774,21 +5777,13 @@ export class Editor extends EventEmitter { if (this.getInstanceState().isReadonly) return this if (ids.length < 2) return this - const shapesToPack = compact( - ids - .map((id) => this.getShape(id)) // always fresh shapes - .filter((shape) => { - if (!shape) return false + const shapesToPack = ids + .map((id) => this.getShape(id)) // always fresh shapes + .filter((shape): shape is TLShape => { + if (!shape) return false - if (this.isShapeOfType(shape, 'arrow')) { - if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') { - return false - } - } - - return true - }) - ) + return this.getShapeUtil(shape).canBeLaidOut(shape) + }) const shapePageBounds: Record = {} const nextShapePageBounds: Record = {} @@ -6969,7 +6964,7 @@ export class Editor extends EventEmitter { // Get the updated version of the shape // If the update had no effect, we'll skip this update - updated = applyPartialToShape(shape, partial) + updated = applyPartialToRecordWithProps(shape, partial) if (updated === shape) continue //if any shape has an onBeforeUpdate handler, call it and, if the handler returns a @@ -7466,134 +7461,63 @@ export class Editor extends EventEmitter { if (!ids) return if (ids.length === 0) return - const pageTransforms: Record = {} + const shapeIds = this.getShapeAndDescendantIds(ids) - let shapesForContent = dedupe( - ids - .map((id) => this.getShape(id)!) - .sort(sortByIndex) - .flatMap((shape) => { - const allShapes = [shape] - this.visitDescendants(shape.id, (descendant) => { - allShapes.push(this.getShape(descendant)!) + return withoutBindingsToUnrelatedShapes(this, shapeIds, (bindingIdsToKeep) => { + const bindings: TLBinding[] = [] + for (const id of bindingIdsToKeep) { + const binding = this.getBinding(id) + if (!binding) continue + bindings.push(binding) + } + + const rootShapeIds: TLShapeId[] = [] + const shapes: TLShape[] = [] + for (const shapeId of shapeIds) { + const shape = this.getShape(shapeId) + if (!shape) continue + + const isRootShape = !shapeIds.has(shape.parentId as TLShapeId) + if (isRootShape) { + // Need to get page point and rotation of the shape because shapes in + // groups use local position/rotation + const pageTransform = this.getShapePageTransform(shape.id)! + const pagePoint = pageTransform.point() + shapes.push({ + ...shape, + x: pagePoint.x, + y: pagePoint.y, + rotation: pageTransform.rotation(), + parentId: this.getCurrentPageId(), }) - return allShapes - }) - ) - - shapesForContent = shapesForContent.map((shape) => { - pageTransforms[shape.id] = this.getShapePageTransform(shape.id)! - - shape = structuredClone(shape) as typeof shape - - if (this.isShapeOfType(shape, 'arrow')) { - const startBindingId = - shape.props.start.type === 'binding' ? shape.props.start.boundShapeId : undefined - - const endBindingId = - shape.props.end.type === 'binding' ? shape.props.end.boundShapeId : undefined - - const info = this.getArrowInfo(shape) - - if (shape.props.start.type === 'binding') { - if (!shapesForContent.some((s) => s.id === startBindingId)) { - // Uh oh, the arrow's bound-to shape isn't among the shapes - // that we're getting the content for. We should try to adjust - // the arrow so that it appears in the place it would be - if (info?.isValid) { - const { x, y } = info.start.point - shape.props.start = { - type: 'point', - x, - y, - } - } else { - const { start } = getArrowTerminalsInArrowSpace(this, shape) - shape.props.start = { - type: 'point', - x: start.x, - y: start.y, - } - } - } - } - - if (shape.props.end.type === 'binding') { - if (!shapesForContent.some((s) => s.id === endBindingId)) { - if (info?.isValid) { - const { x, y } = info.end.point - shape.props.end = { - type: 'point', - x, - y, - } - } else { - const { end } = getArrowTerminalsInArrowSpace(this, shape) - shape.props.end = { - type: 'point', - x: end.x, - y: end.y, - } - } - } - } - - const infoAfter = getIsArrowStraight(shape) - ? getStraightArrowInfo(this, shape) - : getCurvedArrowInfo(this, shape) - - if (info?.isValid && infoAfter?.isValid && !getIsArrowStraight(shape)) { - const mpA = Vec.Med(info.start.handle, info.end.handle) - const distA = Vec.Dist(info.middle, mpA) - const distB = Vec.Dist(infoAfter.middle, mpA) - if (shape.props.bend < 0) { - shape.props.bend += distB - distA - } else { - shape.props.bend -= distB - distA - } - } - - return shape - } - - return shape - }) - - const rootShapeIds: TLShapeId[] = [] - - shapesForContent.forEach((shape) => { - if (shapesForContent.find((s) => s.id === shape.parentId) === undefined) { - // Need to get page point and rotation of the shape because shapes in - // groups use local position/rotation - - const pageTransform = this.getShapePageTransform(shape.id)! - const pagePoint = pageTransform.point() - const pageRotation = pageTransform.rotation() - shape.x = pagePoint.x - shape.y = pagePoint.y - shape.rotation = pageRotation - shape.parentId = this.getCurrentPageId() - - rootShapeIds.push(shape.id) - } - }) - - const assetsSet = new Set() - - shapesForContent.forEach((shape) => { - if ('assetId' in shape.props) { - if (shape.props.assetId !== null) { - assetsSet.add(shape.props.assetId) + rootShapeIds.push(shape.id) + } else { + shapes.push(shape) } } - }) - return { - shapes: shapesForContent, - rootShapeIds, - schema: this.store.schema.serialize(), - assets: compact(Array.from(assetsSet).map((id) => this.getAsset(id))), - } + const assets: TLAsset[] = [] + const seenAssetIds = new Set() + for (const shape of shapes) { + if (!('assetId' in shape.props)) continue + + const assetId = shape.props.assetId + if (!assetId || seenAssetIds.has(assetId)) continue + + seenAssetIds.add(assetId) + const asset = this.getAsset(assetId) + if (!asset) continue + assets.push(asset) + } + + return { + schema: this.store.schema.serialize(), + shapes, + rootShapeIds, + bindings, + assets, + } + }) } /** @@ -7629,15 +7553,19 @@ export class Editor extends EventEmitter { const currentPageId = this.getCurrentPageId() const { rootShapeIds } = content - // We need to collect the migrated shapes and assets + // We need to collect the migrated records const assets: TLAsset[] = [] const shapes: TLShape[] = [] + const bindings: TLBinding[] = [] // Let's treat the content as a store, and then migrate that store. const store: StoreSnapshot = { store: { ...Object.fromEntries(content.assets.map((asset) => [asset.id, asset] as const)), - ...Object.fromEntries(content.shapes.map((asset) => [asset.id, asset] as const)), + ...Object.fromEntries(content.shapes.map((shape) => [shape.id, shape] as const)), + ...Object.fromEntries( + content.bindings?.map((bindings) => [bindings.id, bindings] as const) ?? [] + ), }, schema: content.schema, } @@ -7655,11 +7583,24 @@ export class Editor extends EventEmitter { shapes.push(record) break } + case 'binding': { + bindings.push(record) + break + } } } - // Ok, we've got our migrated shapes and assets, now we can continue! - const idMap = new Map(shapes.map((shape) => [shape.id, createShapeId()])) + // Ok, we've got our migrated records, now we can continue! + const shapeIdMap = new Map( + preserveIds + ? shapes.map((shape) => [shape.id, shape.id]) + : shapes.map((shape) => [shape.id, createShapeId()]) + ) + const bindingIdMap = new Map( + preserveIds + ? bindings.map((binding) => [binding.id, binding.id]) + : bindings.map((binding) => [binding.id, createBindingId()]) + ) // By default, the paste parent will be the current page. let pasteParentId = this.getCurrentPageId() as TLPageId | TLShapeId @@ -7724,7 +7665,7 @@ export class Editor extends EventEmitter { } if (!isDuplicating) { - isDuplicating = idMap.has(pasteParentId) + isDuplicating = shapeIdMap.has(pasteParentId) } if (isDuplicating) { @@ -7735,20 +7676,13 @@ export class Editor extends EventEmitter { const rootShapes: TLShape[] = [] - const newShapes: TLShape[] = shapes.map((shape): TLShape => { - let newShape: TLShape + const newShapes: TLShape[] = shapes.map((oldShape): TLShape => { + const newId = shapeIdMap.get(oldShape.id)! - if (preserveIds) { - newShape = structuredClone(shape) - idMap.set(shape.id, shape.id) - } else { - const id = idMap.get(shape.id)! + // Create the new shape (new except for the id) + const newShape = { ...oldShape, id: newId } - // Create the new shape (new except for the id) - newShape = structuredClone({ ...shape, id }) - } - - if (rootShapeIds.includes(shape.id)) { + if (rootShapeIds.includes(oldShape.id)) { newShape.parentId = currentPageId rootShapes.push(newShape) } @@ -7757,8 +7691,8 @@ export class Editor extends EventEmitter { // If the child's parent is among the putting shapes, then assign // it to the new parent's id. - if (idMap.has(newShape.parentId)) { - newShape.parentId = idMap.get(shape.parentId)! + if (shapeIdMap.has(newShape.parentId)) { + newShape.parentId = shapeIdMap.get(oldShape.parentId)! } else { rootShapeIds.push(newShape.id) // newShape.parentId = pasteParentId @@ -7766,25 +7700,6 @@ export class Editor extends EventEmitter { index = getIndexAbove(index) } - if (this.isShapeOfType(newShape, 'arrow')) { - if (newShape.props.start.type === 'binding') { - const mappedId = idMap.get(newShape.props.start.boundShapeId) - newShape.props.start = mappedId - ? { ...newShape.props.start, boundShapeId: mappedId } - : // this shouldn't happen, if you copy an arrow but not it's bound shape it should - // convert the binding to a point at the time of copying - { type: 'point', x: 0, y: 0 } - } - if (newShape.props.end.type === 'binding') { - const mappedId = idMap.get(newShape.props.end.boundShapeId) - newShape.props.end = mappedId - ? { ...newShape.props.end, boundShapeId: mappedId } - : // this shouldn't happen, if you copy an arrow but not it's bound shape it should - // convert the binding to a point at the time of copying - { type: 'point', x: 0, y: 0 } - } - } - return newShape }) @@ -7796,6 +7711,15 @@ export class Editor extends EventEmitter { return this } + const newBindings = bindings.map( + (oldBinding): TLBinding => ({ + ...oldBinding, + id: assertExists(bindingIdMap.get(oldBinding.id)), + fromId: assertExists(shapeIdMap.get(oldBinding.fromId)), + toId: assertExists(shapeIdMap.get(oldBinding.toId)), + }) + ) + // These are all the assets we need to create const assetsToCreate: TLAsset[] = [] @@ -7856,6 +7780,7 @@ export class Editor extends EventEmitter { // Create the shapes with root shapes as children of the page this.createShapes(newShapes) + this.createBindings(newBindings) if (select) { this.select(...rootShapes.map((s) => s.id)) @@ -8752,7 +8677,9 @@ function alertMaxShapes(editor: Editor, pageId = editor.getCurrentPageId()) { editor.emit('max-shapes', { name, pageId, count: MAX_SHAPES_PER_PAGE }) } -function applyPartialToShape(prev: T, partial?: TLShapePartial): T { +function applyPartialToRecordWithProps< + T extends UnknownRecord & { type: string; props: object; meta: object }, +>(prev: T, partial?: Partial & { props?: Partial }): T { if (!partial) return prev let next = null as null | T const entries = Object.entries(partial) @@ -8797,6 +8724,92 @@ function pushShapeWithDescendants(editor: Editor, id: TLShapeId, result: TLShape } } +/** + * Run `callback` in a world where all bindings from the shapes in `shapeIds` to shapes not in + * `shapeIds` are removed. This is useful when you want to duplicate/copy shapes without worrying + * about bindings that might be pointing to shapes that are not being duplicated. + * + * The callback is given the set of bindings that should be maintained. + */ +function withoutBindingsToUnrelatedShapes( + editor: Editor, + shapeIds: Set, + callback: (bindingsWithBoth: Set) => T +): T { + const bindingsWithBoth = new Set() + const bindingsWithoutFrom = new Set() + const bindingsWithoutTo = new Set() + + for (const shapeId of shapeIds) { + const shape = editor.getShape(shapeId) + if (!shape) continue + + for (const binding of editor.getAllBindingsFromShape(shapeId)) { + if (shapeIds.has(binding.toId)) { + // if we have both sides of the binding, we want to recreate it + bindingsWithBoth.add(binding.id) + } else { + // otherwise, if we only have one side, we need to record that and duplicate + // the shape as if the one it's bound to has been deleted + bindingsWithoutTo.add(binding.id) + } + } + for (const binding of editor.getAllBindingsToShape(shapeId)) { + if (shapeIds.has(binding.fromId)) { + bindingsWithBoth.add(binding.id) + } else { + bindingsWithoutFrom.add(binding.id) + } + } + } + + let result!: Result + + editor.history.ignore(() => { + const changes = editor.store.extractingChanges(() => { + const bindingsToRemove: TLBindingId[] = [] + + for (const bindingId of bindingsWithoutFrom) { + const binding = editor.getBinding(bindingId) + if (!binding) continue + + const shape = editor.getShape(binding.fromId) + if (!shape) continue + + editor.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape }) + bindingsToRemove.push(binding.id) + } + + for (const bindingId of bindingsWithoutTo) { + const binding = editor.getBinding(bindingId) + if (!binding) continue + + const shape = editor.getShape(binding.toId) + if (!shape) continue + + editor.getBindingUtil(binding).onBeforeDeleteToShape?.({ binding, shape }) + bindingsToRemove.push(binding.id) + } + + editor.deleteBindings(bindingsToRemove) + + try { + result = Result.ok(callback(bindingsWithBoth)) + } catch (error) { + result = Result.err(error) + } + }) + + editor.store.applyDiff(reverseRecordsDiff(changes)) + }) + + if (result.ok) { + return result.value + } else { + throw result.error + } +} + function getCameraFitXFitY(editor: Editor, cameraOptions: TLCameraOptions) { if (!cameraOptions.constraints) throw Error('Should have constraints here') const { diff --git a/packages/editor/src/lib/editor/bindings/BindingUtil.ts b/packages/editor/src/lib/editor/bindings/BindingUtil.ts new file mode 100644 index 000000000..3289e408c --- /dev/null +++ b/packages/editor/src/lib/editor/bindings/BindingUtil.ts @@ -0,0 +1,77 @@ +import { RecordProps, TLPropsMigrations, TLShape, TLUnknownBinding } from '@tldraw/tlschema' +import { Editor } from '../Editor' + +/** @public */ +export interface TLBindingUtilConstructor< + T extends TLUnknownBinding, + U extends BindingUtil = BindingUtil, +> { + new (editor: Editor): U + type: T['type'] + props?: RecordProps + migrations?: TLPropsMigrations +} + +/** @public */ +export interface BindingOnCreateOptions { + binding: Binding +} + +/** @public */ +export interface BindingOnChangeOptions { + bindingBefore: Binding + bindingAfter: Binding +} + +/** @public */ +export interface BindingOnDeleteOptions { + binding: Binding +} + +/** @public */ +export interface BindingOnShapeChangeOptions { + binding: Binding + shapeBefore: TLShape + shapeAfter: TLShape +} + +/** @public */ +export interface BindingOnShapeDeleteOptions { + binding: Binding + shape: TLShape +} + +/** @public */ +export abstract class BindingUtil { + constructor(public editor: Editor) {} + static props?: RecordProps + static migrations?: TLPropsMigrations + + /** + * The type of the binding util, which should match the binding's type. + * + * @public + */ + static type: string + + /** + * Get the default props for a binding. + * + * @public + */ + abstract getDefaultProps(): Partial + + // self lifecycle hooks + onBeforeCreate?(options: BindingOnCreateOptions): Binding | void + onAfterCreate?(options: BindingOnCreateOptions): void + onBeforeChange?(options: BindingOnChangeOptions): Binding | void + onAfterChange?(options: BindingOnChangeOptions): void + onBeforeDelete?(options: BindingOnDeleteOptions): void + onAfterDelete?(options: BindingOnDeleteOptions): void + + onAfterChangeFromShape?(options: BindingOnShapeChangeOptions): void + onAfterChangeToShape?(options: BindingOnShapeChangeOptions): void + + onBeforeDeleteFromShape?(options: BindingOnShapeDeleteOptions): void + onBeforeDeleteToShape?(options: BindingOnShapeDeleteOptions): void +} diff --git a/packages/editor/src/lib/editor/derivations/arrowBindingsIndex.ts b/packages/editor/src/lib/editor/derivations/arrowBindingsIndex.ts deleted file mode 100644 index 553c4a88e..000000000 --- a/packages/editor/src/lib/editor/derivations/arrowBindingsIndex.ts +++ /dev/null @@ -1,141 +0,0 @@ -import { Computed, RESET_VALUE, computed, isUninitialized } from '@tldraw/state' -import { TLArrowShape, TLShape, TLShapeId } from '@tldraw/tlschema' -import { Editor } from '../Editor' - -type TLArrowBindingsIndex = Record< - TLShapeId, - undefined | { arrowId: TLShapeId; handleId: 'start' | 'end' }[] -> - -export const arrowBindingsIndex = (editor: Editor): Computed => { - const { store } = editor - const shapeHistory = store.query.filterHistory('shape') - const arrowQuery = store.query.records('shape', () => ({ type: { eq: 'arrow' as const } })) - function fromScratch() { - const allArrows = arrowQuery.get() as TLArrowShape[] - - const bindings2Arrows: TLArrowBindingsIndex = {} - - for (const arrow of allArrows) { - const { start, end } = arrow.props - if (start.type === 'binding') { - const arrows = bindings2Arrows[start.boundShapeId] - if (arrows) arrows.push({ arrowId: arrow.id, handleId: 'start' }) - else bindings2Arrows[start.boundShapeId] = [{ arrowId: arrow.id, handleId: 'start' }] - } - - if (end.type === 'binding') { - const arrows = bindings2Arrows[end.boundShapeId] - if (arrows) arrows.push({ arrowId: arrow.id, handleId: 'end' }) - else bindings2Arrows[end.boundShapeId] = [{ arrowId: arrow.id, handleId: 'end' }] - } - } - - return bindings2Arrows - } - - return computed('arrowBindingsIndex', (_lastValue, lastComputedEpoch) => { - if (isUninitialized(_lastValue)) { - return fromScratch() - } - - const lastValue = _lastValue - - const diff = shapeHistory.getDiffSince(lastComputedEpoch) - - if (diff === RESET_VALUE) { - return fromScratch() - } - - let nextValue: TLArrowBindingsIndex | undefined = undefined - - function ensureNewArray(boundShapeId: TLShapeId) { - // this will never happen - if (!nextValue) { - nextValue = { ...lastValue } - } - if (!nextValue[boundShapeId]) { - nextValue[boundShapeId] = [] - } else if (nextValue[boundShapeId] === lastValue[boundShapeId]) { - nextValue[boundShapeId] = [...nextValue[boundShapeId]!] - } - } - - function removingBinding( - boundShapeId: TLShapeId, - arrowId: TLShapeId, - handleId: 'start' | 'end' - ) { - ensureNewArray(boundShapeId) - nextValue![boundShapeId] = nextValue![boundShapeId]!.filter( - (binding) => binding.arrowId !== arrowId || binding.handleId !== handleId - ) - if (nextValue![boundShapeId]!.length === 0) { - delete nextValue![boundShapeId] - } - } - - function addBinding(boundShapeId: TLShapeId, arrowId: TLShapeId, handleId: 'start' | 'end') { - ensureNewArray(boundShapeId) - nextValue![boundShapeId]!.push({ arrowId, handleId }) - } - - for (const changes of diff) { - for (const newShape of Object.values(changes.added)) { - if (editor.isShapeOfType(newShape, 'arrow')) { - const { start, end } = newShape.props - if (start.type === 'binding') { - addBinding(start.boundShapeId, newShape.id, 'start') - } - if (end.type === 'binding') { - addBinding(end.boundShapeId, newShape.id, 'end') - } - } - } - - for (const [prev, next] of Object.values(changes.updated) as [TLShape, TLShape][]) { - if ( - !editor.isShapeOfType(prev, 'arrow') || - !editor.isShapeOfType(next, 'arrow') - ) - continue - - for (const handle of ['start', 'end'] as const) { - const prevTerminal = prev.props[handle] - const nextTerminal = next.props[handle] - - if (prevTerminal.type === 'binding' && nextTerminal.type === 'point') { - // if the binding was removed - removingBinding(prevTerminal.boundShapeId, prev.id, handle) - } else if (prevTerminal.type === 'point' && nextTerminal.type === 'binding') { - // if the binding was added - addBinding(nextTerminal.boundShapeId, next.id, handle) - } else if ( - prevTerminal.type === 'binding' && - nextTerminal.type === 'binding' && - prevTerminal.boundShapeId !== nextTerminal.boundShapeId - ) { - // if the binding was changed - removingBinding(prevTerminal.boundShapeId, prev.id, handle) - addBinding(nextTerminal.boundShapeId, next.id, handle) - } - } - } - - for (const prev of Object.values(changes.removed)) { - if (editor.isShapeOfType(prev, 'arrow')) { - const { start, end } = prev.props - if (start.type === 'binding') { - removingBinding(start.boundShapeId, prev.id, 'start') - } - if (end.type === 'binding') { - removingBinding(end.boundShapeId, prev.id, 'end') - } - } - } - } - - // TODO: add diff entries if we need them - return nextValue ?? lastValue - }) -} diff --git a/packages/editor/src/lib/editor/shapes/ShapeUtil.ts b/packages/editor/src/lib/editor/shapes/ShapeUtil.ts index e596d90f5..838d8a941 100644 --- a/packages/editor/src/lib/editor/shapes/ShapeUtil.ts +++ b/packages/editor/src/lib/editor/shapes/ShapeUtil.ts @@ -1,11 +1,11 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ import { LegacyMigrations, MigrationSequence } from '@tldraw/store' import { - ShapeProps, + RecordProps, TLHandle, + TLPropsMigrations, TLShape, TLShapePartial, - TLShapePropsMigrations, TLUnknownShape, } from '@tldraw/tlschema' import { ReactElement } from 'react' @@ -25,8 +25,8 @@ export interface TLShapeUtilConstructor< > { new (editor: Editor): U type: T['type'] - props?: ShapeProps - migrations?: LegacyMigrations | TLShapePropsMigrations | MigrationSequence + props?: RecordProps + migrations?: LegacyMigrations | TLPropsMigrations | MigrationSequence } /** @public */ @@ -41,8 +41,8 @@ export interface TLShapeUtilCanvasSvgDef { /** @public */ export abstract class ShapeUtil { constructor(public editor: Editor) {} - static props?: ShapeProps - static migrations?: LegacyMigrations | TLShapePropsMigrations + static props?: RecordProps + static migrations?: LegacyMigrations | TLPropsMigrations | MigrationSequence /** * The type of the shape util, which should match the shape's type. @@ -132,6 +132,13 @@ export abstract class ShapeUtil { */ canCrop: TLShapeUtilFlag = () => false + /** + * Whether the shape participates in stacking, aligning, and distributing. + * + * @public + */ + canBeLaidOut: TLShapeUtilFlag = () => true + /** * Does this shape provide a background for its children? If this is true, * then any children with a `renderBackground` method will have their diff --git a/packages/editor/src/lib/editor/shapes/shared/arrow/arrow-types.ts b/packages/editor/src/lib/editor/shapes/shared/arrow/arrow-types.ts index dbcba62a8..e0412ed69 100644 --- a/packages/editor/src/lib/editor/shapes/shared/arrow/arrow-types.ts +++ b/packages/editor/src/lib/editor/shapes/shared/arrow/arrow-types.ts @@ -1,5 +1,6 @@ import { TLArrowShapeArrowheadStyle } from '@tldraw/tlschema' import { VecLike } from '../../../../primitives/Vec' +import { TLArrowBindings } from './shared' /** @public */ export type TLArrowPoint = { @@ -21,6 +22,7 @@ export interface TLArcInfo { /** @public */ export type TLArrowInfo = | { + bindings: TLArrowBindings isStraight: false start: TLArrowPoint end: TLArrowPoint @@ -30,6 +32,7 @@ export type TLArrowInfo = isValid: boolean } | { + bindings: TLArrowBindings isStraight: true start: TLArrowPoint end: TLArrowPoint diff --git a/packages/editor/src/lib/editor/shapes/shared/arrow/curved-arrow.ts b/packages/editor/src/lib/editor/shapes/shared/arrow/curved-arrow.ts index a8a755df5..1bff1f53f 100644 --- a/packages/editor/src/lib/editor/shapes/shared/arrow/curved-arrow.ts +++ b/packages/editor/src/lib/editor/shapes/shared/arrow/curved-arrow.ts @@ -15,6 +15,7 @@ import { BOUND_ARROW_OFFSET, MIN_ARROW_LENGTH, STROKE_SIZES, + TLArrowBindings, WAY_TOO_BIG_ARROW_BEND_FACTOR, getArrowTerminalsInArrowSpace, getBoundShapeInfoForTerminal, @@ -25,16 +26,16 @@ import { getStraightArrowInfo } from './straight-arrow' export function getCurvedArrowInfo( editor: Editor, shape: TLArrowShape, - extraBend = 0 + bindings: TLArrowBindings ): TLArrowInfo { const { arrowheadEnd, arrowheadStart } = shape.props - const bend = shape.props.bend + extraBend + const bend = shape.props.bend if (Math.abs(bend) > Math.abs(shape.props.bend * WAY_TOO_BIG_ARROW_BEND_FACTOR)) { - return getStraightArrowInfo(editor, shape) + return getStraightArrowInfo(editor, shape, bindings) } - const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(editor, shape) + const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(editor, shape, bindings) const med = Vec.Med(terminalsInArrowSpace.start, terminalsInArrowSpace.end) // point between start and end const distance = Vec.Sub(terminalsInArrowSpace.end, terminalsInArrowSpace.start) @@ -42,8 +43,8 @@ export function getCurvedArrowInfo( const u = Vec.Len(distance) ? distance.uni() : Vec.From(distance) // unit vector between start and end const middle = Vec.Add(med, u.per().mul(-bend)) // middle handle - const startShapeInfo = getBoundShapeInfoForTerminal(editor, shape.props.start) - const endShapeInfo = getBoundShapeInfoForTerminal(editor, shape.props.end) + const startShapeInfo = getBoundShapeInfoForTerminal(editor, shape, 'start') + const endShapeInfo = getBoundShapeInfoForTerminal(editor, shape, 'end') // The positions of the body of the arrow, which may be different // than the arrow's start / end points if the arrow is bound to shapes @@ -53,6 +54,7 @@ export function getCurvedArrowInfo( if (Vec.Equals(a, b)) { return { + bindings, isStraight: true, start: { handle: a, @@ -84,7 +86,7 @@ export function getCurvedArrowInfo( !isSafeFloat(handleArc.length) || !isSafeFloat(handleArc.size) ) { - return getStraightArrowInfo(editor, shape) + return getStraightArrowInfo(editor, shape, bindings) } const tempA = a.clone() @@ -341,6 +343,7 @@ export function getCurvedArrowInfo( const bodyArc = getArcInfo(a, b, c) return { + bindings, isStraight: false, start: { point: a, diff --git a/packages/editor/src/lib/editor/shapes/shared/arrow/shared.ts b/packages/editor/src/lib/editor/shapes/shared/arrow/shared.ts index 5d4bc94f8..7797e45a1 100644 --- a/packages/editor/src/lib/editor/shapes/shared/arrow/shared.ts +++ b/packages/editor/src/lib/editor/shapes/shared/arrow/shared.ts @@ -1,4 +1,10 @@ -import { TLArrowShape, TLArrowShapeTerminal, TLShape, TLShapeId } from '@tldraw/tlschema' +import { + TLArrowBinding, + TLArrowBindingProps, + TLArrowShape, + TLShape, + TLShapeId, +} from '@tldraw/tlschema' import { Mat } from '../../../../primitives/Mat' import { Vec } from '../../../../primitives/Vec' import { Group2d } from '../../../../primitives/geometry/Group2d' @@ -19,16 +25,18 @@ export type BoundShapeInfo = { export function getBoundShapeInfoForTerminal( editor: Editor, - terminal: TLArrowShapeTerminal + arrow: TLArrowShape, + terminalName: 'start' | 'end' ): BoundShapeInfo | undefined { - if (terminal.type === 'point') { - return - } + const binding = editor + .getBindingsFromShape(arrow, 'arrow') + .find((b) => b.props.terminal === terminalName) + if (!binding) return - const shape = editor.getShape(terminal.boundShapeId) - if (!shape) return - const transform = editor.getShapePageTransform(shape) - const geometry = editor.getShapeGeometry(shape) + const boundShape = editor.getShape(binding.toId)! + if (!boundShape) return + const transform = editor.getShapePageTransform(boundShape)! + const geometry = editor.getShapeGeometry(boundShape) // This is hacky: we're only looking at the first child in the group. Really the arrow should // consider all items in the group which are marked as snappable as separate polygons with which @@ -37,10 +45,10 @@ export function getBoundShapeInfoForTerminal( const outline = geometry instanceof Group2d ? geometry.children[0].vertices : geometry.vertices return { - shape, + shape: boundShape, transform, isClosed: geometry.isClosed, - isExact: terminal.isExact, + isExact: binding.props.isExact, didIntersect: false, outline, } @@ -49,14 +57,10 @@ export function getBoundShapeInfoForTerminal( function getArrowTerminalInArrowSpace( editor: Editor, arrowPageTransform: Mat, - terminal: TLArrowShapeTerminal, + binding: TLArrowBinding, forceImprecise: boolean ) { - if (terminal.type === 'point') { - return Vec.From(terminal) - } - - const boundShape = editor.getShape(terminal.boundShapeId) + const boundShape = editor.getShape(binding.toId) if (!boundShape) { // this can happen in multiplayer contexts where the shape is being deleted @@ -70,7 +74,9 @@ function getArrowTerminalInArrowSpace( point, Vec.MulV( // if the parent is the bound shape, then it's ALWAYS precise - terminal.isPrecise || forceImprecise ? terminal.normalizedAnchor : { x: 0.5, y: 0.5 }, + binding.props.isPrecise || forceImprecise + ? binding.props.normalizedAnchor + : { x: 0.5, y: 0.5 }, size ) ) @@ -81,40 +87,108 @@ function getArrowTerminalInArrowSpace( } /** @public */ -export function getArrowTerminalsInArrowSpace(editor: Editor, shape: TLArrowShape) { - const arrowPageTransform = editor.getShapePageTransform(shape)! +export interface TLArrowBindings { + start: TLArrowBinding | undefined + end: TLArrowBinding | undefined +} - let startBoundShapeId: TLShapeId | undefined - let endBoundShapeId: TLShapeId | undefined - - if (shape.props.start.type === 'binding' && shape.props.end.type === 'binding') { - startBoundShapeId = shape.props.start.boundShapeId - endBoundShapeId = shape.props.end.boundShapeId +/** @public */ +export function getArrowBindings(editor: Editor, shape: TLArrowShape): TLArrowBindings { + const bindings = editor.getBindingsFromShape(shape, 'arrow') + return { + start: bindings.find((b) => b.props.terminal === 'start'), + end: bindings.find((b) => b.props.terminal === 'end'), } +} + +/** @public */ +export function getArrowTerminalsInArrowSpace( + editor: Editor, + shape: TLArrowShape, + bindings: TLArrowBindings +) { + const arrowPageTransform = editor.getShapePageTransform(shape)! const boundShapeRelationships = getBoundShapeRelationships( editor, - startBoundShapeId, - endBoundShapeId + bindings.start?.toId, + bindings.end?.toId ) - const start = getArrowTerminalInArrowSpace( - editor, - arrowPageTransform, - shape.props.start, - boundShapeRelationships === 'double-bound' || boundShapeRelationships === 'start-contains-end' - ) + const start = bindings.start + ? getArrowTerminalInArrowSpace( + editor, + arrowPageTransform, + bindings.start, + boundShapeRelationships === 'double-bound' || + boundShapeRelationships === 'start-contains-end' + ) + : Vec.From(shape.props.start) - const end = getArrowTerminalInArrowSpace( - editor, - arrowPageTransform, - shape.props.end, - boundShapeRelationships === 'double-bound' || boundShapeRelationships === 'end-contains-start' - ) + const end = bindings.end + ? getArrowTerminalInArrowSpace( + editor, + arrowPageTransform, + bindings.end, + boundShapeRelationships === 'double-bound' || + boundShapeRelationships === 'end-contains-start' + ) + : Vec.From(shape.props.end) return { start, end } } +/** + * Create or update the arrow binding for a particular arrow terminal. Will clear up if needed. + * @internal + */ +export function createOrUpdateArrowBinding( + editor: Editor, + arrow: TLArrowShape | TLShapeId, + target: TLShape | TLShapeId, + props: TLArrowBindingProps +) { + const arrowId = typeof arrow === 'string' ? arrow : arrow.id + const targetId = typeof target === 'string' ? target : target.id + + const existingMany = editor + .getBindingsFromShape(arrowId, 'arrow') + .filter((b) => b.props.terminal === props.terminal) + + // if we've somehow ended up with too many bindings, delete the extras + if (existingMany.length > 1) { + editor.deleteBindings(existingMany.slice(1)) + } + + const existing = existingMany[0] + if (existing) { + editor.updateBinding({ + ...existing, + toId: targetId, + props, + }) + } else { + editor.createBinding({ + type: 'arrow', + fromId: arrowId, + toId: targetId, + props, + }) + } +} + +/** + * Remove any arrow bindings for a particular terminal. + * @internal + */ +export function removeArrowBinding(editor: Editor, arrow: TLArrowShape, terminal: 'start' | 'end') { + const existing = editor + .getBindingsFromShape(arrow, 'arrow') + .filter((b) => b.props.terminal === terminal) + + editor.deleteBindings(existing) +} + /** @internal */ export const MIN_ARROW_LENGTH = 10 /** @internal */ diff --git a/packages/editor/src/lib/editor/shapes/shared/arrow/straight-arrow.ts b/packages/editor/src/lib/editor/shapes/shared/arrow/straight-arrow.ts index db2373c56..4ecea943f 100644 --- a/packages/editor/src/lib/editor/shapes/shared/arrow/straight-arrow.ts +++ b/packages/editor/src/lib/editor/shapes/shared/arrow/straight-arrow.ts @@ -12,15 +12,20 @@ import { BoundShapeInfo, MIN_ARROW_LENGTH, STROKE_SIZES, + TLArrowBindings, getArrowTerminalsInArrowSpace, getBoundShapeInfoForTerminal, getBoundShapeRelationships, } from './shared' -export function getStraightArrowInfo(editor: Editor, shape: TLArrowShape): TLArrowInfo { - const { start, end, arrowheadStart, arrowheadEnd } = shape.props +export function getStraightArrowInfo( + editor: Editor, + shape: TLArrowShape, + bindings: TLArrowBindings +): TLArrowInfo { + const { arrowheadStart, arrowheadEnd } = shape.props - const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(editor, shape) + const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(editor, shape, bindings) const a = terminalsInArrowSpace.start.clone() const b = terminalsInArrowSpace.end.clone() @@ -28,6 +33,7 @@ export function getStraightArrowInfo(editor: Editor, shape: TLArrowShape): TLArr if (Vec.Equals(a, b)) { return { + bindings, isStraight: true, start: { handle: a, @@ -49,8 +55,8 @@ export function getStraightArrowInfo(editor: Editor, shape: TLArrowShape): TLArr // Update the arrowhead points using intersections with the bound shapes, if any. - const startShapeInfo = getBoundShapeInfoForTerminal(editor, start) - const endShapeInfo = getBoundShapeInfoForTerminal(editor, end) + const startShapeInfo = getBoundShapeInfoForTerminal(editor, shape, 'start') + const endShapeInfo = getBoundShapeInfoForTerminal(editor, shape, 'end') const arrowPageTransform = editor.getShapePageTransform(shape)! @@ -189,6 +195,7 @@ export function getStraightArrowInfo(editor: Editor, shape: TLArrowShape): TLArr const length = Vec.Dist(a, b) return { + bindings, isStraight: true, start: { handle: terminalsInArrowSpace.start, diff --git a/packages/editor/src/lib/editor/types/clipboard-types.ts b/packages/editor/src/lib/editor/types/clipboard-types.ts index 6b298a9a6..9c9f085b3 100644 --- a/packages/editor/src/lib/editor/types/clipboard-types.ts +++ b/packages/editor/src/lib/editor/types/clipboard-types.ts @@ -1,9 +1,10 @@ import { SerializedSchema } from '@tldraw/store' -import { TLAsset, TLShape, TLShapeId } from '@tldraw/tlschema' +import { TLAsset, TLBinding, TLShape, TLShapeId } from '@tldraw/tlschema' /** @public */ export interface TLContent { shapes: TLShape[] + bindings: TLBinding[] | undefined rootShapeIds: TLShapeId[] assets: TLAsset[] schema: SerializedSchema diff --git a/packages/editor/src/lib/test/currentToolIdMask.test.ts b/packages/editor/src/lib/test/currentToolIdMask.test.ts index ea60924ae..b76905c46 100644 --- a/packages/editor/src/lib/test/currentToolIdMask.test.ts +++ b/packages/editor/src/lib/test/currentToolIdMask.test.ts @@ -24,6 +24,7 @@ beforeEach(() => { editor = new Editor({ initialState: 'A', shapeUtils: [], + bindingUtils: [], tools: [A, B, C], store: createTLStore({ shapeUtils: [] }), getContainer: () => document.body, diff --git a/packages/editor/src/lib/test/user.test.ts b/packages/editor/src/lib/test/user.test.ts index 256ef3472..09c46d861 100644 --- a/packages/editor/src/lib/test/user.test.ts +++ b/packages/editor/src/lib/test/user.test.ts @@ -6,6 +6,7 @@ let editor: Editor beforeEach(() => { editor = new Editor({ shapeUtils: [], + bindingUtils: [], tools: [], store: createTLStore({ shapeUtils: [] }), getContainer: () => document.body, diff --git a/packages/store/api-report.md b/packages/store/api-report.md index cd2ba6928..e30dc7247 100644 --- a/packages/store/api-report.md +++ b/packages/store/api-report.md @@ -37,7 +37,7 @@ export type ComputedCache = { export function createEmptyRecordsDiff(): RecordsDiff; // @public -export function createMigrationIds>(sequenceId: ID, versions: Versions): { +export function createMigrationIds>(sequenceId: ID, versions: Versions): { [K in keyof Versions]: `${ID}/${Versions[K]}`; }; @@ -265,6 +265,11 @@ export function squashRecordDiffs(diffs: RecordsDiff // @internal export function squashRecordDiffsMutable(target: RecordsDiff, diffs: RecordsDiff[]): void; +// @public (undocumented) +export type StandaloneDependsOn = { + readonly dependsOn: readonly MigrationId[]; +}; + // @public export class Store { constructor(config: { diff --git a/packages/store/src/index.ts b/packages/store/src/index.ts index 2216ec4d7..fe45f3f04 100644 --- a/packages/store/src/index.ts +++ b/packages/store/src/index.ts @@ -43,5 +43,6 @@ export { type MigrationId, type MigrationResult, type MigrationSequence, + type StandaloneDependsOn, } from './lib/migrate' export type { AllRecords } from './lib/type-utils' diff --git a/packages/store/src/lib/migrate.ts b/packages/store/src/lib/migrate.ts index 95dc30c7e..7782eb9cf 100644 --- a/packages/store/src/lib/migrate.ts +++ b/packages/store/src/lib/migrate.ts @@ -91,10 +91,10 @@ export function createMigrationSequence({ * @public * @public */ -export function createMigrationIds>( - sequenceId: ID, - versions: Versions -): { [K in keyof Versions]: `${ID}/${Versions[K]}` } { +export function createMigrationIds< + const ID extends string, + const Versions extends Record, +>(sequenceId: ID, versions: Versions): { [K in keyof Versions]: `${ID}/${Versions[K]}` } { return Object.fromEntries( objectMapEntries(versions).map(([key, version]) => [key, `${sequenceId}/${version}`] as const) ) as any @@ -136,6 +136,7 @@ export type LegacyMigration = { /** @public */ export type MigrationId = `${string}/${number}` +/** @public */ export type StandaloneDependsOn = { readonly dependsOn: readonly MigrationId[] } diff --git a/packages/tldraw/api-report.md b/packages/tldraw/api-report.md index b35bccf42..9d43253e3 100644 --- a/packages/tldraw/api-report.md +++ b/packages/tldraw/api-report.md @@ -33,7 +33,6 @@ import { MemoExoticComponent } from 'react'; import { MigrationFailureReason } from '@tldraw/editor'; import { MigrationSequence } from '@tldraw/editor'; import { NamedExoticComponent } from 'react'; -import { ObjectValidator } from '@tldraw/editor'; import { Polygon2d } from '@tldraw/editor'; import { Polyline2d } from '@tldraw/editor'; import { default as React_2 } from 'react'; @@ -54,6 +53,7 @@ import { StoreSnapshot } from '@tldraw/editor'; import { StyleProp } from '@tldraw/editor'; import { SvgExportContext } from '@tldraw/editor'; import { T } from '@tldraw/editor'; +import { TLAnyBindingUtilConstructor } from '@tldraw/editor'; import { TLAnyShapeUtilConstructor } from '@tldraw/editor'; import { TLArrowShape } from '@tldraw/editor'; import { TLAssetId } from '@tldraw/editor'; @@ -95,6 +95,7 @@ import { TLOnDoubleClickHandler } from '@tldraw/editor'; import { TLOnEditEndHandler } from '@tldraw/editor'; import { TLOnHandleDragHandler } from '@tldraw/editor'; import { TLOnResizeHandler } from '@tldraw/editor'; +import { TLOnResizeStartHandler } from '@tldraw/editor'; import { TLOnTranslateHandler } from '@tldraw/editor'; import { TLOnTranslateStartHandler } from '@tldraw/editor'; import { TLPageId } from '@tldraw/editor'; @@ -102,6 +103,7 @@ import { TLParentId } from '@tldraw/editor'; import { TLPointerEvent } from '@tldraw/editor'; import { TLPointerEventInfo } from '@tldraw/editor'; import { TLPointerEventName } from '@tldraw/editor'; +import { TLPropsMigrations } from '@tldraw/editor'; import { TLRecord } from '@tldraw/editor'; import { TLRotationSnapshot } from '@tldraw/editor'; import { TLSchema } from '@tldraw/editor'; @@ -112,7 +114,6 @@ import { TLSelectionHandle } from '@tldraw/editor'; import { TLShape } from '@tldraw/editor'; import { TLShapeId } from '@tldraw/editor'; import { TLShapePartial } from '@tldraw/editor'; -import { TLShapePropsMigrations } from '@tldraw/editor'; import { TLShapeUtilCanvasSvgDef } from '@tldraw/editor'; import { TLShapeUtilFlag } from '@tldraw/editor'; import { TLStore } from '@tldraw/editor'; @@ -121,7 +122,6 @@ import { TLSvgOptions } from '@tldraw/editor'; import { TLTextShape } from '@tldraw/editor'; import { TLUnknownShape } from '@tldraw/editor'; import { TLVideoShape } from '@tldraw/editor'; -import { UnionValidator } from '@tldraw/editor'; import { UnknownRecord } from '@tldraw/editor'; import { Validator } from '@tldraw/editor'; import { Vec } from '@tldraw/editor'; @@ -165,6 +165,8 @@ export class ArrowShapeTool extends StateNode { // @public (undocumented) export class ArrowShapeUtil extends ShapeUtil { + // (undocumented) + canBeLaidOut: TLShapeUtilFlag; // (undocumented) canBind: () => boolean; // (undocumented) @@ -192,7 +194,7 @@ export class ArrowShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLArrowShape): JSX_2.Element | null; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: MigrationSequence; // (undocumented) onDoubleClickHandle: (shape: TLArrowShape, handle: TLHandle) => TLShapePartial | void; // (undocumented) @@ -202,6 +204,8 @@ export class ArrowShapeUtil extends ShapeUtil { // (undocumented) onResize: TLOnResizeHandler; // (undocumented) + onResizeStart?: TLOnResizeStartHandler; + // (undocumented) onTranslate?: TLOnTranslateHandler; // (undocumented) onTranslateStart: TLOnTranslateStartHandler; @@ -212,39 +216,13 @@ export class ArrowShapeUtil extends ShapeUtil { bend: Validator; color: EnumStyleProp<"black" | "blue" | "green" | "grey" | "light-blue" | "light-green" | "light-red" | "light-violet" | "orange" | "red" | "violet" | "white" | "yellow">; dash: EnumStyleProp<"dashed" | "dotted" | "draw" | "solid">; - end: UnionValidator<"type", { - binding: ObjectValidator< { - boundShapeId: TLShapeId; - isExact: boolean; - isPrecise: boolean; - normalizedAnchor: VecModel; - type: "binding"; - }>; - point: ObjectValidator< { - type: "point"; - x: number; - y: number; - }>; - }, never>; + end: Validator; fill: EnumStyleProp<"none" | "pattern" | "semi" | "solid">; font: EnumStyleProp<"draw" | "mono" | "sans" | "serif">; labelColor: EnumStyleProp<"black" | "blue" | "green" | "grey" | "light-blue" | "light-green" | "light-red" | "light-violet" | "orange" | "red" | "violet" | "white" | "yellow">; labelPosition: Validator; size: EnumStyleProp<"l" | "m" | "s" | "xl">; - start: UnionValidator<"type", { - binding: ObjectValidator< { - boundShapeId: TLShapeId; - isExact: boolean; - isPrecise: boolean; - normalizedAnchor: VecModel; - type: "binding"; - }>; - point: ObjectValidator< { - type: "point"; - x: number; - y: number; - }>; - }, never>; + start: Validator; text: Validator; }; // (undocumented) @@ -281,7 +259,7 @@ export class BookmarkShapeUtil extends BaseBoxShapeUtil { // (undocumented) indicator(shape: TLBookmarkShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onBeforeCreate?: TLOnBeforeCreateHandler; // (undocumented) @@ -360,6 +338,9 @@ export const DefaultActionsMenu: NamedExoticComponent; // @public (undocumented) export function DefaultActionsMenuContent(): JSX_2.Element; +// @public (undocumented) +export const defaultBindingUtils: TLAnyBindingUtilConstructor[]; + // @public (undocumented) const DefaultContextMenu: NamedExoticComponent; export { DefaultContextMenu as ContextMenu } @@ -492,7 +473,7 @@ export class DrawShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLDrawShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onResize: TLOnResizeHandler; // (undocumented) @@ -549,7 +530,7 @@ export class EmbedShapeUtil extends BaseBoxShapeUtil { // (undocumented) isAspectRatioLocked: TLShapeUtilFlag; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onResize: TLOnResizeHandler; // (undocumented) @@ -656,7 +637,7 @@ export class FrameShapeUtil extends BaseBoxShapeUtil { // (undocumented) indicator(shape: TLFrameShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onDragShapesOut: (_shape: TLFrameShape, shapes: TLShape[]) => void; // (undocumented) @@ -709,7 +690,7 @@ export class GeoShapeUtil extends BaseBoxShapeUtil { // (undocumented) indicator(shape: TLGeoShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onBeforeCreate: (shape: TLGeoShape) => { id: TLShapeId; @@ -923,7 +904,7 @@ export class HighlightShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLHighlightShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onResize: TLOnResizeHandler; // (undocumented) @@ -961,7 +942,7 @@ export class ImageShapeUtil extends BaseBoxShapeUtil { // (undocumented) isAspectRatioLocked: () => boolean; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onDoubleClick: (shape: TLImageShape) => void; // (undocumented) @@ -1065,7 +1046,7 @@ export class LineShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLLineShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onHandleDrag: TLOnHandleDragHandler; // (undocumented) @@ -1129,7 +1110,7 @@ export class NoteShapeUtil extends ShapeUtil { // (undocumented) indicator(shape: TLNoteShape): JSX_2.Element; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onBeforeCreate: (next: TLNoteShape) => { id: TLShapeId; @@ -1376,7 +1357,7 @@ export class TextShapeUtil extends ShapeUtil { // (undocumented) isAspectRatioLocked: TLShapeUtilFlag; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) onBeforeCreate: (shape: TLTextShape) => { id: TLShapeId; @@ -1496,6 +1477,7 @@ export function TldrawHandles({ children }: TLHandlesProps): JSX_2.Element | nul // @public export const TldrawImage: NamedExoticComponent< { background?: boolean | undefined; +bindingUtils?: readonly TLAnyBindingUtilConstructor[] | undefined; bounds?: Box | undefined; darkMode?: boolean | undefined; format?: "png" | "svg" | undefined; @@ -1509,6 +1491,7 @@ snapshot: StoreSnapshot; // @public export type TldrawImageProps = Expand<{ + bindingUtils?: readonly TLAnyBindingUtilConstructor[]; shapeUtils?: readonly TLAnyShapeUtilConstructor[]; format?: 'png' | 'svg'; pageId?: TLPageId; @@ -2669,7 +2652,7 @@ export class VideoShapeUtil extends BaseBoxShapeUtil { // (undocumented) isAspectRatioLocked: () => boolean; // (undocumented) - static migrations: TLShapePropsMigrations; + static migrations: TLPropsMigrations; // (undocumented) static props: { assetId: Validator; diff --git a/packages/tldraw/src/index.ts b/packages/tldraw/src/index.ts index 03e5643ab..821f3d40d 100644 --- a/packages/tldraw/src/index.ts +++ b/packages/tldraw/src/index.ts @@ -12,6 +12,7 @@ export { TldrawHandles } from './lib/canvas/TldrawHandles' export { TldrawScribble } from './lib/canvas/TldrawScribble' export { TldrawSelectionBackground } from './lib/canvas/TldrawSelectionBackground' export { TldrawSelectionForeground } from './lib/canvas/TldrawSelectionForeground' +export { defaultBindingUtils } from './lib/defaultBindingUtils' export { defaultShapeTools } from './lib/defaultShapeTools' export { defaultShapeUtils } from './lib/defaultShapeUtils' export { defaultTools } from './lib/defaultTools' diff --git a/packages/tldraw/src/lib/Tldraw.tsx b/packages/tldraw/src/lib/Tldraw.tsx index 159cf2023..86c2d3f0c 100644 --- a/packages/tldraw/src/lib/Tldraw.tsx +++ b/packages/tldraw/src/lib/Tldraw.tsx @@ -23,6 +23,7 @@ import { TldrawHandles } from './canvas/TldrawHandles' import { TldrawScribble } from './canvas/TldrawScribble' import { TldrawSelectionBackground } from './canvas/TldrawSelectionBackground' import { TldrawSelectionForeground } from './canvas/TldrawSelectionForeground' +import { defaultBindingUtils } from './defaultBindingUtils' import { TLExternalContentProps, registerDefaultExternalContentHandlers, @@ -79,6 +80,7 @@ export function Tldraw(props: TldrawProps) { onMount, components = {}, shapeUtils = [], + bindingUtils = [], tools = [], ...rest } = props @@ -102,6 +104,12 @@ export function Tldraw(props: TldrawProps) { [_shapeUtils] ) + const _bindingUtils = useShallowArrayIdentity(bindingUtils) + const bindingUtilsWithDefaults = useMemo( + () => [...defaultBindingUtils, ..._bindingUtils], + [_bindingUtils] + ) + const _tools = useShallowArrayIdentity(tools) const toolsWithDefaults = useMemo( () => [...defaultTools, ...defaultShapeTools, ..._tools], @@ -123,6 +131,7 @@ export function Tldraw(props: TldrawProps) { {...rest} components={componentsWithDefault} shapeUtils={shapeUtilsWithDefaults} + bindingUtils={bindingUtilsWithDefaults} tools={toolsWithDefaults} > diff --git a/packages/tldraw/src/lib/TldrawImage.tsx b/packages/tldraw/src/lib/TldrawImage.tsx index f3ce3cd95..8c345ee89 100644 --- a/packages/tldraw/src/lib/TldrawImage.tsx +++ b/packages/tldraw/src/lib/TldrawImage.tsx @@ -4,6 +4,7 @@ import { Expand, LoadingScreen, StoreSnapshot, + TLAnyBindingUtilConstructor, TLAnyShapeUtilConstructor, TLPageId, TLRecord, @@ -12,6 +13,7 @@ import { useTLStore, } from '@tldraw/editor' import { memo, useLayoutEffect, useMemo, useState } from 'react' +import { defaultBindingUtils } from './defaultBindingUtils' import { defaultShapeUtils } from './defaultShapeUtils' import { usePreloadAssets } from './ui/hooks/usePreloadAssets' import { getSvgAsImage } from './utils/export/export' @@ -43,6 +45,10 @@ export type TldrawImageProps = Expand< * Additional shape utils to use. */ shapeUtils?: readonly TLAnyShapeUtilConstructor[] + /** + * Additional binding utils to use. + */ + bindingUtils?: readonly TLAnyBindingUtilConstructor[] } & Partial > @@ -69,6 +75,11 @@ export const TldrawImage = memo(function TldrawImage(props: TldrawImageProps) { const shapeUtils = useShallowArrayIdentity(props.shapeUtils ?? []) const shapeUtilsWithDefaults = useMemo(() => [...defaultShapeUtils, ...shapeUtils], [shapeUtils]) + const bindingUtils = useShallowArrayIdentity(props.bindingUtils ?? []) + const bindingUtilsWithDefaults = useMemo( + () => [...defaultBindingUtils, ...bindingUtils], + [bindingUtils] + ) const store = useTLStore({ snapshot: props.snapshot, shapeUtils: shapeUtilsWithDefaults }) const assets = useDefaultEditorAssetsWithOverrides() @@ -98,7 +109,8 @@ export const TldrawImage = memo(function TldrawImage(props: TldrawImageProps) { const editor = new Editor({ store, - shapeUtils: shapeUtilsWithDefaults ?? [], + shapeUtils: shapeUtilsWithDefaults, + bindingUtils: bindingUtilsWithDefaults, tools: [], getContainer: () => tempElm, }) @@ -152,6 +164,7 @@ export const TldrawImage = memo(function TldrawImage(props: TldrawImageProps) { container, store, shapeUtilsWithDefaults, + bindingUtilsWithDefaults, pageId, bounds, scale, diff --git a/packages/tldraw/src/lib/bindings/arrow/ArrowBindingUtil.ts b/packages/tldraw/src/lib/bindings/arrow/ArrowBindingUtil.ts new file mode 100644 index 000000000..5c5a58178 --- /dev/null +++ b/packages/tldraw/src/lib/bindings/arrow/ArrowBindingUtil.ts @@ -0,0 +1,224 @@ +import { + BindingOnChangeOptions, + BindingOnCreateOptions, + BindingOnShapeChangeOptions, + BindingOnShapeDeleteOptions, + BindingUtil, + Editor, + IndexKey, + TLArrowBinding, + TLArrowBindingProps, + TLArrowShape, + TLParentId, + TLShape, + TLShapeId, + TLShapePartial, + Vec, + arrowBindingMigrations, + arrowBindingProps, + assert, + getArrowBindings, + getIndexAbove, + getIndexBetween, + intersectLineSegmentCircle, + removeArrowBinding, +} from '@tldraw/editor' + +export class ArrowBindingUtil extends BindingUtil { + static override type = 'arrow' + + static override props = arrowBindingProps + static override migrations = arrowBindingMigrations + + override getDefaultProps(): Partial { + return { + isPrecise: false, + isExact: false, + normalizedAnchor: { x: 0.5, y: 0.5 }, + } + } + + // when the binding itself changes + override onAfterCreate({ binding }: BindingOnCreateOptions): void { + arrowDidUpdate(this.editor, this.editor.getShape(binding.fromId) as TLArrowShape) + } + + // when the binding itself changes + override onAfterChange({ bindingAfter }: BindingOnChangeOptions): void { + arrowDidUpdate(this.editor, this.editor.getShape(bindingAfter.fromId) as TLArrowShape) + } + + // when the arrow itself changes + override onAfterChangeFromShape({ + shapeAfter, + }: BindingOnShapeChangeOptions): void { + arrowDidUpdate(this.editor, shapeAfter as TLArrowShape) + } + + // when the shape an arrow is bound to changes + override onAfterChangeToShape({ binding }: BindingOnShapeChangeOptions): void { + reparentArrow(this.editor, binding.fromId) + } + + // when the shape the arrow is pointing to is deleted + override onBeforeDeleteToShape({ binding }: BindingOnShapeDeleteOptions): void { + const arrow = this.editor.getShape(binding.fromId) + if (!arrow) return + unbindArrowTerminal(this.editor, arrow, binding.props.terminal) + } +} + +function reparentArrow(editor: Editor, arrowId: TLShapeId) { + const arrow = editor.getShape(arrowId) + if (!arrow) return + const bindings = getArrowBindings(editor, arrow) + const { start, end } = bindings + const startShape = start ? editor.getShape(start.toId) : undefined + const endShape = end ? editor.getShape(end.toId) : undefined + + const parentPageId = editor.getAncestorPageId(arrow) + if (!parentPageId) return + + let nextParentId: TLParentId + if (startShape && endShape) { + // if arrow has two bindings, always parent arrow to closest common ancestor of the bindings + nextParentId = editor.findCommonAncestor([startShape, endShape]) ?? parentPageId + } else if (startShape || endShape) { + const bindingParentId = (startShape || endShape)?.parentId + // If the arrow and the shape that it is bound to have the same parent, then keep that parent + if (bindingParentId && bindingParentId === arrow.parentId) { + nextParentId = arrow.parentId + } else { + // if arrow has one binding, keep arrow on its own page + nextParentId = parentPageId + } + } else { + return + } + + if (nextParentId && nextParentId !== arrow.parentId) { + editor.reparentShapes([arrowId], nextParentId) + } + + const reparentedArrow = editor.getShape(arrowId) + if (!reparentedArrow) throw Error('no reparented arrow') + + const startSibling = editor.getShapeNearestSibling(reparentedArrow, startShape) + const endSibling = editor.getShapeNearestSibling(reparentedArrow, endShape) + + let highestSibling: TLShape | undefined + + if (startSibling && endSibling) { + highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling + } else if (startSibling && !endSibling) { + highestSibling = startSibling + } else if (endSibling && !startSibling) { + highestSibling = endSibling + } else { + return + } + + let finalIndex: IndexKey + + const higherSiblings = editor + .getSortedChildIdsForParent(highestSibling.parentId) + .map((id) => editor.getShape(id)!) + .filter((sibling) => sibling.index > highestSibling!.index) + + if (higherSiblings.length) { + // there are siblings above the highest bound sibling, we need to + // insert between them. + + // if the next sibling is also a bound arrow though, we can end up + // all fighting for the same indexes. so lets find the next + // non-arrow sibling... + const nextHighestNonArrowSibling = higherSiblings.find((sibling) => sibling.type !== 'arrow') + + if ( + // ...then, if we're above the last shape we want to be above... + reparentedArrow.index > highestSibling.index && + // ...but below the next non-arrow sibling... + (!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index) + ) { + // ...then we're already in the right place. no need to update! + return + } + + // otherwise, we need to find the index between the highest sibling + // we want to be above, and the next highest sibling we want to be + // below: + finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index) + } else { + // if there are no siblings above us, we can just get the next index: + finalIndex = getIndexAbove(highestSibling.index) + } + + if (finalIndex !== reparentedArrow.index) { + editor.updateShapes([{ id: arrowId, type: 'arrow', index: finalIndex }]) + } +} + +function arrowDidUpdate(editor: Editor, arrow: TLArrowShape) { + const bindings = getArrowBindings(editor, arrow) + // if the shape is an arrow and its bound shape is on another page + // or was deleted, unbind it + for (const handle of ['start', 'end'] as const) { + const binding = bindings[handle] + if (!binding) continue + const boundShape = editor.getShape(binding.toId) + const isShapeInSamePageAsArrow = + editor.getAncestorPageId(arrow) === editor.getAncestorPageId(boundShape) + if (!boundShape || !isShapeInSamePageAsArrow) { + unbindArrowTerminal(editor, arrow, handle) + } + } + + // always check the arrow parents + reparentArrow(editor, arrow.id) +} + +function unbindArrowTerminal(editor: Editor, arrow: TLArrowShape, terminal: 'start' | 'end') { + const info = editor.getArrowInfo(arrow)! + if (!info) { + throw new Error('expected arrow info') + } + + const update = { + id: arrow.id, + type: 'arrow', + props: { + [terminal]: { x: info[terminal].point.x, y: info[terminal].point.y }, + bend: arrow.props.bend, + }, + } satisfies TLShapePartial + + // fix up the bend: + if (!info.isStraight) { + // find the new start/end points of the resulting arrow + const newStart = terminal === 'start' ? info.start.point : info.start.handle + const newEnd = terminal === 'end' ? info.end.point : info.end.handle + const newMidPoint = Vec.Med(newStart, newEnd) + + // intersect a line segment perpendicular to the new arrow with the old arrow arc to + // find the new mid-point + const lineSegment = Vec.Sub(newStart, newEnd) + .per() + .uni() + .mul(info.handleArc.radius * 2 * Math.sign(arrow.props.bend)) + + // find the intersections with the old arrow arc: + const intersections = intersectLineSegmentCircle( + info.handleArc.center, + Vec.Add(newMidPoint, lineSegment), + info.handleArc.center, + info.handleArc.radius + ) + + assert(intersections?.length === 1) + const bend = Vec.Dist(newMidPoint, intersections[0]) * Math.sign(arrow.props.bend) + update.props.bend = bend + } + + editor.updateShape(update) + removeArrowBinding(editor, arrow, terminal) +} diff --git a/packages/tldraw/src/lib/defaultBindingUtils.ts b/packages/tldraw/src/lib/defaultBindingUtils.ts new file mode 100644 index 000000000..42e00c424 --- /dev/null +++ b/packages/tldraw/src/lib/defaultBindingUtils.ts @@ -0,0 +1,5 @@ +import { TLAnyBindingUtilConstructor } from '@tldraw/editor' +import { ArrowBindingUtil } from './bindings/arrow/ArrowBindingUtil' + +/** @public */ +export const defaultBindingUtils: TLAnyBindingUtilConstructor[] = [ArrowBindingUtil] diff --git a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts index 50e698626..5e4c9eb65 100644 --- a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts +++ b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts @@ -1,4 +1,11 @@ -import { IndexKey, TLArrowShape, Vec, createShapeId } from '@tldraw/editor' +import { + IndexKey, + TLArrowShape, + TLShapeId, + Vec, + createShapeId, + getArrowBindings, +} from '@tldraw/editor' import { TestEditor } from '../../../test/TestEditor' let editor: TestEditor @@ -19,6 +26,10 @@ const ids = { box3: createShapeId('box3'), } +function bindings(id: TLShapeId) { + return getArrowBindings(editor, editor.getShape(id) as TLArrowShape) +} + beforeEach(() => { editor = new TestEditor() editor @@ -89,10 +100,11 @@ describe('When dragging the arrow', () => { x: 0, y: 0, props: { - start: { type: 'point', x: 0, y: 0 }, - end: { type: 'point', x: 10, y: 10 }, + start: { x: 0, y: 0 }, + end: { x: 10, y: 10 }, }, }) + expect(bindings(arrow.id)).toMatchObject({ start: undefined, end: undefined }) editor.expectToBeIn('select.dragging_handle') }) @@ -146,15 +158,20 @@ describe('When pointing a start shape', () => { x: 375, y: 375, props: { - start: { - type: 'binding', + start: { x: 0, y: 0 }, + end: { x: 0, y: 125 }, + }, + }) + expect(bindings(arrow.id)).toMatchObject({ + start: { + toId: ids.box3, + props: { isExact: false, normalizedAnchor: { x: 0.5, y: 0.5 }, // center! isPrecise: false, - boundShapeId: ids.box3, }, - end: { type: 'point', x: 0, y: 125 }, }, + end: undefined, }) editor.pointerUp() @@ -187,13 +204,17 @@ describe('When pointing an end shape', () => { x: 0, y: 0, props: { - start: { type: 'point', x: 0, y: 0 }, - end: { - type: 'binding', + start: { x: 0, y: 0 }, + }, + }) + expect(bindings(arrow.id)).toMatchObject({ + start: undefined, + end: { + toId: ids.box3, + props: { isExact: false, normalizedAnchor: { x: 0.5, y: 0.5 }, // center! isPrecise: false, - boundShapeId: ids.box3, }, }, }) @@ -214,19 +235,14 @@ describe('When pointing an end shape', () => { expect(editor.getHintingShapeIds().length).toBe(1) - editor.expectShapeToMatch(arrow, { - id: arrow.id, - type: 'arrow', - x: 0, - y: 0, - props: { - start: { type: 'point', x: 0, y: 0 }, - end: { - type: 'binding', + expect(bindings(arrow.id)).toMatchObject({ + start: undefined, + end: { + toId: ids.box3, + props: { isExact: false, normalizedAnchor: { x: 0.5, y: 0.5 }, isPrecise: false, - boundShapeId: ids.box3, }, }, }) @@ -235,19 +251,14 @@ describe('When pointing an end shape', () => { arrow = editor.getCurrentPageShapes()[editor.getCurrentPageShapes().length - 1] - editor.expectShapeToMatch(arrow, { - id: arrow.id, - type: 'arrow', - x: 0, - y: 0, - props: { - start: { type: 'point', x: 0, y: 0 }, - end: { - type: 'binding', + expect(bindings(arrow.id)).toMatchObject({ + start: undefined, + end: { + toId: ids.box3, + props: { isExact: false, normalizedAnchor: { x: 0.5, y: 0.5 }, isPrecise: true, - boundShapeId: ids.box3, }, }, }) @@ -262,10 +273,14 @@ describe('When pointing an end shape', () => { x: 0, y: 0, props: { - start: { type: 'point', x: 0, y: 0 }, - end: { type: 'point', x: 375, y: 0 }, + start: { x: 0, y: 0 }, + end: { x: 375, y: 0 }, }, }) + expect(bindings(arrow.id)).toMatchObject({ + start: undefined, + end: undefined, + }) // Build up some velocity editor.inputs.pointerVelocity = new Vec(1, 1) @@ -280,13 +295,17 @@ describe('When pointing an end shape', () => { x: 0, y: 0, props: { - start: { type: 'point', x: 0, y: 0 }, - end: { - type: 'binding', + start: { x: 0, y: 0 }, + }, + }) + expect(bindings(arrow.id)).toMatchObject({ + start: undefined, + end: { + toId: ids.box2, + props: { isExact: false, normalizedAnchor: { x: 0.25, y: 0.25 }, // center! isPrecise: false, - boundShapeId: ids.box2, }, }, }) @@ -296,18 +315,14 @@ describe('When pointing an end shape', () => { arrow = editor.getCurrentPageShapes()[editor.getCurrentPageShapes().length - 1] - editor.expectShapeToMatch(arrow, { - id: arrow.id, - type: 'arrow', - x: 0, - y: 0, - props: { - start: { type: 'point', x: 0, y: 0 }, - end: { - type: 'binding', + expect(bindings(arrow.id)).toMatchObject({ + start: undefined, + end: { + toId: ids.box2, + props: { isExact: false, normalizedAnchor: { x: 0.25, y: 0.25 }, // precise! - boundShapeId: ids.box2, + isPrecise: true, }, }, }) @@ -325,19 +340,14 @@ describe('When pointing an end shape', () => { expect(editor.getHintingShapeIds().length).toBe(1) - editor.expectShapeToMatch(arrow, { - id: arrow.id, - type: 'arrow', - x: 0, - y: 0, - props: { - start: { type: 'point', x: 0, y: 0 }, - end: { - type: 'binding', + expect(bindings(arrow.id)).toMatchObject({ + start: undefined, + end: { + toId: ids.box3, + props: { isExact: false, normalizedAnchor: { x: 0.4, y: 0.4 }, isPrecise: false, - boundShapeId: ids.box3, }, }, }) @@ -348,15 +358,9 @@ describe('When pointing an end shape', () => { let arrow = editor.getCurrentPageShapes()[editor.getCurrentPageShapes().length - 1] - editor.expectShapeToMatch(arrow, { - id: arrow.id, - type: 'arrow', - x: 0, - y: 0, - props: { - start: { type: 'point', x: 0, y: 0 }, - end: { type: 'point', x: 2, y: 0 }, - }, + expect(bindings(arrow.id)).toMatchObject({ + start: undefined, + end: undefined, }) expect(editor.getHintingShapeIds().length).toBe(0) @@ -373,14 +377,15 @@ describe('When pointing an end shape', () => { type: 'arrow', x: 0, y: 0, - props: { - start: { type: 'point', x: 0, y: 0 }, - end: { - type: 'binding', + }) + expect(bindings(arrow.id)).toMatchObject({ + start: undefined, + end: { + toId: ids.box3, + props: { isExact: false, normalizedAnchor: { x: 0.5, y: 0.5 }, isPrecise: true, - boundShapeId: ids.box3, }, }, }) @@ -423,8 +428,8 @@ describe('reparenting issue', () => { editor.expectShapeToMatch({ id: arrowId, index: 'a3V' as IndexKey, - props: { end: { boundShapeId: ids.box2 } }, }) // between box 2 (a3) and 3 (a4) + expect(bindings(arrowId)).toMatchObject({ end: { toId: ids.box2 } }) expect(editor.getShapeAtPoint({ x: 350, y: 350 }, { hitInside: true })).toMatchObject({ id: ids.box3, @@ -434,8 +439,8 @@ describe('reparenting issue', () => { editor.expectShapeToMatch({ id: arrowId, index: 'a5' as IndexKey, - props: { end: { boundShapeId: ids.box3 } }, }) // above box 3 (a4) + expect(bindings(arrowId)).toMatchObject({ end: { toId: ids.box3 } }) editor.pointerMove(150, 150) // over box 1 editor.expectShapeToMatch({ id: arrowId, index: 'a2V' as IndexKey }) // between box 1 (a2) and box 3 (a3) @@ -465,14 +470,14 @@ describe('reparenting issue', () => { type: 'arrow', x: 0, y: 0, - props: { start: { type: 'point', x: 0, y: 0 }, end: { type: 'point', x: 100, y: 100 } }, + props: { start: { x: 0, y: 0 }, end: { x: 100, y: 100 } }, }, { id: arrow2Id, type: 'arrow', x: 0, y: 0, - props: { start: { type: 'point', x: 0, y: 0 }, end: { type: 'point', x: 100, y: 100 } }, + props: { start: { x: 0, y: 0 }, end: { x: 100, y: 100 } }, }, ]) @@ -530,8 +535,8 @@ describe('line bug', () => { .keyUp('Shift') expect(editor.getCurrentPageShapes().length).toBe(2) - const arrow = editor.getCurrentPageShapes()[1] as TLArrowShape - expect(arrow.props.end.type).toBe('binding') + const bindings = getArrowBindings(editor, editor.getCurrentPageShapes()[1] as TLArrowShape) + expect(bindings.end).toBeDefined() }) it('works as expected when binding to a straight horizontal line', () => { @@ -552,7 +557,7 @@ describe('line bug', () => { .pointerUp() expect(editor.getCurrentPageShapes().length).toBe(2) - const arrow = editor.getCurrentPageShapes()[1] as TLArrowShape - expect(arrow.props.end.type).toBe('binding') + const bindings = getArrowBindings(editor, editor.getCurrentPageShapes()[1] as TLArrowShape) + expect(bindings.end).toBeDefined() }) }) diff --git a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts index 2105d6d4a..d8e0b7b61 100644 --- a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts +++ b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts @@ -1,10 +1,10 @@ import { - assert, - createShapeId, HALF_PI, TLArrowShape, - TLArrowShapeTerminal, TLShapeId, + createOrUpdateArrowBinding, + createShapeId, + getArrowBindings, } from '@tldraw/editor' import { TestEditor } from '../../../test/TestEditor' @@ -28,6 +28,13 @@ window.cancelAnimationFrame = function cancelAnimationFrame(id) { clearTimeout(id) } +function arrow(id = ids.arrow1) { + return editor.getShape(id) as TLArrowShape +} +function bindings(id = ids.arrow1) { + return getArrowBindings(editor, arrow(id)) +} + beforeEach(() => { editor = new TestEditor() editor @@ -42,23 +49,25 @@ beforeEach(() => { x: 150, y: 150, props: { - start: { - type: 'binding', - isExact: false, - boundShapeId: ids.box1, - normalizedAnchor: { x: 0.5, y: 0.5 }, - isPrecise: false, - }, - end: { - type: 'binding', - isExact: false, - boundShapeId: ids.box2, - normalizedAnchor: { x: 0.5, y: 0.5 }, - isPrecise: false, - }, + start: { x: 0, y: 0 }, + end: { x: 0, y: 0 }, }, }, ]) + + createOrUpdateArrowBinding(editor, ids.arrow1, ids.box1, { + terminal: 'start', + isExact: false, + isPrecise: false, + normalizedAnchor: { x: 0.5, y: 0.5 }, + }) + + createOrUpdateArrowBinding(editor, ids.arrow1, ids.box2, { + terminal: 'end', + isExact: false, + isPrecise: false, + normalizedAnchor: { x: 0.5, y: 0.5 }, + }) }) describe('When translating a bound shape', () => { @@ -77,17 +86,23 @@ describe('When translating a bound shape', () => { x: 150, y: 150, props: { - start: { - type: 'binding', + start: { x: 0, y: 0 }, + end: { x: 0, y: 0 }, + }, + }) + expect(bindings()).toMatchObject({ + start: { + toId: ids.box1, + props: { isExact: false, - boundShapeId: ids.box1, normalizedAnchor: { x: 0.5, y: 0.5 }, isPrecise: false, }, - end: { - type: 'binding', + }, + end: { + toId: ids.box2, + props: { isExact: false, - boundShapeId: ids.box2, normalizedAnchor: { x: 0.5, y: 0.5 }, isPrecise: false, }, @@ -111,17 +126,24 @@ describe('When translating a bound shape', () => { x: 150, y: 150, props: { - start: { - type: 'binding', + start: { x: 0, y: 0 }, + end: { x: 0, y: 0 }, + bend: 20, + }, + }) + expect(bindings()).toMatchObject({ + start: { + toId: ids.box1, + props: { isExact: false, - boundShapeId: ids.box1, normalizedAnchor: { x: 0.5, y: 0.5 }, isPrecise: false, }, - end: { - type: 'binding', + }, + end: { + toId: ids.box2, + props: { isExact: false, - boundShapeId: ids.box2, normalizedAnchor: { x: 0.5, y: 0.5 }, isPrecise: false, }, @@ -147,17 +169,23 @@ describe('When translating the arrow', () => { x: 150, y: 100, props: { - start: { - type: 'binding', + start: { x: 0, y: 0 }, + end: { x: 0, y: 0 }, + }, + }) + expect(bindings()).toMatchObject({ + start: { + toId: ids.box1, + props: { isExact: false, - boundShapeId: ids.box1, normalizedAnchor: { x: 0.5, y: 0.5 }, isPrecise: false, }, - end: { - type: 'binding', + }, + end: { + toId: ids.box2, + props: { isExact: false, - boundShapeId: ids.box2, normalizedAnchor: { x: 0.5, y: 0.5 }, isPrecise: false, }, @@ -172,23 +200,18 @@ describe('Other cases when arrow are moved', () => { // When box one is not selected, unbinds box1 and keeps binding to box2 editor.nudgeShapes(editor.getSelectedShapeIds(), { x: 0, y: -1 }) - - expect(editor.getShape(ids.arrow1)).toMatchObject({ - props: { - start: { type: 'binding', boundShapeId: ids.box1 }, - end: { type: 'binding', boundShapeId: ids.box2 }, - }, + expect(bindings()).toMatchObject({ + start: { toId: ids.box1, props: { isPrecise: false } }, + end: { toId: ids.box2, props: { isPrecise: false } }, }) // when only the arrow is selected, we keep the binding but make it precise: editor.select(ids.arrow1) editor.nudgeShapes(editor.getSelectedShapeIds(), { x: 0, y: -1 }) - expect(editor.getShape(ids.arrow1)).toMatchObject({ - props: { - start: { type: 'binding', boundShapeId: ids.box1, isPrecise: true }, - end: { type: 'binding', boundShapeId: ids.box2, isPrecise: true }, - }, + expect(bindings()).toMatchObject({ + start: { toId: ids.box1, props: { isPrecise: true } }, + end: { toId: ids.box2, props: { isPrecise: true } }, }) }) @@ -200,11 +223,9 @@ describe('Other cases when arrow are moved', () => { editor.alignShapes(editor.getSelectedShapeIds(), 'right') jest.advanceTimersByTime(1000) - expect(editor.getShape(ids.arrow1)).toMatchObject({ - props: { - start: { type: 'binding', boundShapeId: ids.box1 }, - end: { type: 'binding', boundShapeId: ids.box2 }, - }, + expect(bindings()).toMatchObject({ + start: { toId: ids.box1, props: { isPrecise: false } }, + end: { toId: ids.box2, props: { isPrecise: false } }, }) // maintains bindings if they would still be over the same shape (but makes them precise), but unbinds others @@ -212,16 +233,9 @@ describe('Other cases when arrow are moved', () => { editor.alignShapes(editor.getSelectedShapeIds(), 'top') jest.advanceTimersByTime(1000) - expect(editor.getShape(ids.arrow1)).toMatchObject({ - props: { - start: { - type: 'binding', - isPrecise: true, - }, - end: { - type: 'point', - }, - }, + expect(bindings()).toMatchObject({ + start: { toId: ids.box1, props: { isPrecise: true } }, + end: undefined, }) }) @@ -236,17 +250,9 @@ describe('Other cases when arrow are moved', () => { editor.distributeShapes(editor.getSelectedShapeIds(), 'horizontal') jest.advanceTimersByTime(1000) - expect(editor.getShape(ids.arrow1)).toMatchObject({ - props: { - start: { - type: 'binding', - boundShapeId: ids.box1, - }, - end: { - type: 'binding', - boundShapeId: ids.box2, - }, - }, + expect(bindings()).toMatchObject({ + start: { toId: ids.box1, props: { isPrecise: false } }, + end: { toId: ids.box2, props: { isPrecise: false } }, }) // unbinds when only the arrow is selected (not its bound shapes) if the arrow itself has moved @@ -255,17 +261,9 @@ describe('Other cases when arrow are moved', () => { jest.advanceTimersByTime(1000) // The arrow didn't actually move - expect(editor.getShape(ids.arrow1)).toMatchObject({ - props: { - start: { - type: 'binding', - boundShapeId: ids.box1, - }, - end: { - type: 'binding', - boundShapeId: ids.box2, - }, - }, + expect(bindings()).toMatchObject({ + start: { toId: ids.box1, props: { isPrecise: false } }, + end: { toId: ids.box2, props: { isPrecise: false } }, }) // The arrow will move this time, so it should unbind @@ -273,15 +271,9 @@ describe('Other cases when arrow are moved', () => { editor.distributeShapes(editor.getSelectedShapeIds(), 'vertical') jest.advanceTimersByTime(1000) - expect(editor.getShape(ids.arrow1)).toMatchObject({ - props: { - start: { - type: 'point', - }, - end: { - type: 'point', - }, - }, + expect(bindings()).toMatchObject({ + start: undefined, + end: undefined, }) }) @@ -298,57 +290,44 @@ describe('Other cases when arrow are moved', () => { .groupShapes(editor.getSelectedShapeIds()) editor.setCurrentTool('arrow').pointerDown(1000, 1000).pointerMove(50, 350).pointerUp(50, 350) - let arrow = editor.getCurrentPageShapes()[editor.getCurrentPageShapes().length - 1] - assert(editor.isShapeOfType(arrow, 'arrow')) - assert(arrow.props.end.type === 'binding') - expect(arrow.props.end.boundShapeId).toBe(ids.box3) + const arrowId = editor.getOnlySelectedShape()!.id + expect(bindings(arrowId).end?.toId).toBe(ids.box3) // translate: editor.selectAll().nudgeShapes(editor.getSelectedShapeIds(), { x: 0, y: 1 }) // arrow should still be bound to box3 - arrow = editor.getShape(arrow.id)! - assert(editor.isShapeOfType(arrow, 'arrow')) - assert(arrow.props.end.type === 'binding') - expect(arrow.props.end.boundShapeId).toBe(ids.box3) + expect(bindings(arrowId).end?.toId).toBe(ids.box3) }) }) describe('When a shape is rotated', () => { it('binds correctly', () => { editor.setCurrentTool('arrow').pointerDown(0, 0).pointerMove(375, 375) + const arrowId = editor.getCurrentPageShapes()[editor.getCurrentPageShapes().length - 1].id - const arrow = editor.getCurrentPageShapes()[editor.getCurrentPageShapes().length - 1] - - expect(editor.getShape(arrow.id)).toMatchObject({ - props: { - start: { type: 'point' }, - end: { - type: 'binding', - boundShapeId: ids.box2, + expect(bindings(arrowId)).toMatchObject({ + start: undefined, + end: { + toId: ids.box2, + props: { normalizedAnchor: { x: 0.75, y: 0.75 }, // moving slowly }, }, }) editor.updateShapes([{ id: ids.box2, type: 'geo', rotation: HALF_PI }]) - editor.pointerMove(225, 350) - expect(editor.getShape(arrow.id)).toMatchObject({ - props: { - start: { type: 'point' }, - end: { type: 'binding', boundShapeId: ids.box2 }, + expect(bindings(arrowId)).toCloselyMatchObject({ + start: undefined, + end: { + toId: ids.box2, + props: { + normalizedAnchor: { x: 0.5, y: 0.75 }, // moving slowly + }, }, }) - - const anchor = ( - editor.getShape(arrow.id)!.props.end as TLArrowShapeTerminal & { - type: 'binding' - } - ).normalizedAnchor - expect(anchor.x).toBeCloseTo(0.5) - expect(anchor.y).toBeCloseTo(0.75) }) }) @@ -362,8 +341,7 @@ describe('Arrow labels', () => { it('should create an arrow with a label', () => { const arrowId = editor.getOnlySelectedShape()!.id - const arrow = editor.getShape(arrowId) - expect(arrow).toMatchObject({ + expect(arrow(arrowId)).toMatchObject({ props: { text: 'Test Label', }, @@ -373,8 +351,7 @@ describe('Arrow labels', () => { it('should update the label of an arrow', () => { const arrowId = editor.getOnlySelectedShape()!.id editor.updateShapes([{ id: arrowId, type: 'arrow', props: { text: 'New Label' } }]) - const arrow = editor.getShape(arrowId) - expect(arrow).toMatchObject({ + expect(arrow(arrowId)).toMatchObject({ props: { text: 'New Label', }, @@ -533,32 +510,22 @@ describe("an arrow's parents", () => { editor.pointerDown(15, 15).pointerMove(50, 50) const arrowId = editor.getOnlySelectedShape()!.id - expect(editor.getShape(arrowId)).toMatchObject({ - props: { - start: { type: 'binding', boundShapeId: boxAid }, - end: { type: 'binding', boundShapeId: frameId }, - }, - }) - expect(editor.getShape(arrowId)?.parentId).toBe(editor.getCurrentPageId()) + expect(arrow(arrowId).parentId).toBe(editor.getCurrentPageId()) // move arrow to b editor.pointerMove(15, 85) - expect(editor.getShape(arrowId)?.parentId).toBe(frameId) - expect(editor.getShape(arrowId)).toMatchObject({ - props: { - start: { type: 'binding', boundShapeId: boxAid }, - end: { type: 'binding', boundShapeId: boxBid }, - }, + expect(arrow(arrowId).parentId).toBe(frameId) + expect(bindings(arrowId)).toMatchObject({ + start: { toId: boxAid }, + end: { toId: boxBid }, }) // move back to empty space editor.pointerMove(50, 50) - expect(editor.getShape(arrowId)?.parentId).toBe(editor.getCurrentPageId()) - expect(editor.getShape(arrowId)).toMatchObject({ - props: { - start: { type: 'binding', boundShapeId: boxAid }, - end: { type: 'binding', boundShapeId: frameId }, - }, + expect(arrow(arrowId).parentId).toBe(editor.getCurrentPageId()) + expect(bindings(arrowId)).toMatchObject({ + start: { toId: boxAid }, + end: { toId: frameId }, }) }) @@ -568,21 +535,21 @@ describe("an arrow's parents", () => { editor.pointerDown(15, 15).pointerMove(15, 85).pointerUp() const arrowId = editor.getOnlySelectedShape()!.id - expect(editor.getShape(arrowId)).toMatchObject({ + expect(arrow(arrowId)).toMatchObject({ parentId: frameId, - props: { - start: { type: 'binding', boundShapeId: boxAid }, - end: { type: 'binding', boundShapeId: boxBid }, - }, + }) + expect(bindings(arrowId)).toMatchObject({ + start: { toId: boxAid }, + end: { toId: boxBid }, }) // move b outside of frame editor.select(boxBid).translateSelection(200, 0) - expect(editor.getShape(arrowId)).toMatchObject({ + expect(arrow(arrowId)).toMatchObject({ parentId: editor.getCurrentPageId(), - props: { - start: { type: 'binding', boundShapeId: boxAid }, - end: { type: 'binding', boundShapeId: boxBid }, - }, + }) + expect(bindings(arrowId)).toMatchObject({ + start: { toId: boxAid }, + end: { toId: boxBid }, }) }) @@ -591,12 +558,12 @@ describe("an arrow's parents", () => { editor.setCurrentTool('arrow') editor.pointerDown(15, 15).pointerMove(115, 15).pointerUp() const arrowId = editor.getOnlySelectedShape()!.id - expect(editor.getShape(arrowId)).toMatchObject({ + expect(arrow(arrowId)).toMatchObject({ parentId: editor.getCurrentPageId(), - props: { - start: { type: 'binding', boundShapeId: boxAid }, - end: { type: 'binding', boundShapeId: boxCid }, - }, + }) + expect(bindings(arrowId)).toMatchObject({ + start: { toId: boxAid }, + end: { toId: boxCid }, }) // move c inside of frame @@ -604,10 +571,10 @@ describe("an arrow's parents", () => { expect(editor.getShape(arrowId)).toMatchObject({ parentId: frameId, - props: { - start: { type: 'binding', boundShapeId: boxAid }, - end: { type: 'binding', boundShapeId: boxCid }, - }, + }) + expect(bindings(arrowId)).toMatchObject({ + start: { toId: boxAid }, + end: { toId: boxCid }, }) }) }) diff --git a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.tsx b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.tsx index 7036ee1b6..a5378b47b 100644 --- a/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.tsx +++ b/packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.tsx @@ -9,12 +9,14 @@ import { SVGContainer, ShapeUtil, SvgExportContext, + TLArrowBinding, + TLArrowBindings, TLArrowShape, - TLArrowShapeProps, TLHandle, TLOnEditEndHandler, TLOnHandleDragHandler, TLOnResizeHandler, + TLOnResizeStartHandler, TLOnTranslateHandler, TLOnTranslateStartHandler, TLShapePartial, @@ -23,10 +25,12 @@ import { Vec, arrowShapeMigrations, arrowShapeProps, + createOrUpdateArrowBinding, + getArrowBindings, getArrowTerminalsInArrowSpace, getDefaultColorTheme, mapObjectMapValues, - objectMapEntries, + removeArrowBinding, structuredClone, toDomPrecision, track, @@ -75,6 +79,11 @@ export class ArrowShapeUtil extends ShapeUtil { override hideSelectionBoundsBg: TLShapeUtilFlag = () => true override hideSelectionBoundsFg: TLShapeUtilFlag = () => true + override canBeLaidOut: TLShapeUtilFlag = (shape) => { + const bindings = getArrowBindings(this.editor, shape) + return !bindings.start && !bindings.end + } + override getDefaultProps(): TLArrowShape['props'] { return { dash: 'draw', @@ -83,8 +92,8 @@ export class ArrowShapeUtil extends ShapeUtil { color: 'black', labelColor: 'black', bend: 0, - start: { type: 'point', x: 0, y: 0 }, - end: { type: 'point', x: 2, y: 0 }, + start: { x: 0, y: 0 }, + end: { x: 2, y: 0 }, arrowheadStart: 'none', arrowheadEnd: 'arrow', text: '', @@ -164,10 +173,11 @@ export class ArrowShapeUtil extends ShapeUtil { override onHandleDrag: TLOnHandleDragHandler = (shape, { handle, isPrecise }) => { const handleId = handle.id as ARROW_HANDLES + const bindings = getArrowBindings(this.editor, shape) if (handleId === ARROW_HANDLES.MIDDLE) { // Bending the arrow... - const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape) + const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape, bindings) const delta = Vec.Sub(end, start) const v = Vec.Per(delta) @@ -184,17 +194,23 @@ export class ArrowShapeUtil extends ShapeUtil { // Start or end, pointing the arrow... - const next = structuredClone(shape) as TLArrowShape + const update: TLShapePartial = { id: shape.id, type: 'arrow', props: {} } + + const currentBinding = bindings[handleId] + + const otherHandleId = handleId === ARROW_HANDLES.START ? ARROW_HANDLES.END : ARROW_HANDLES.START + const otherBinding = bindings[otherHandleId] if (this.editor.inputs.ctrlKey) { // todo: maybe double check that this isn't equal to the other handle too? // Skip binding - next.props[handleId] = { - type: 'point', + removeArrowBinding(this.editor, shape, handleId) + + update.props![handleId] = { x: handle.x, y: handle.y, } - return next + return update } const point = this.editor.getShapePageTransform(shape.id)!.applyToPoint(handle) @@ -210,19 +226,20 @@ export class ArrowShapeUtil extends ShapeUtil { if (!target) { // todo: maybe double check that this isn't equal to the other handle too? - next.props[handleId] = { - type: 'point', + removeArrowBinding(this.editor, shape, handleId) + + update.props![handleId] = { x: handle.x, y: handle.y, } - return next + return update } // we've got a target! the handle is being dragged over a shape, bind to it const targetGeometry = this.editor.getShapeGeometry(target) const targetBounds = Box.ZeroFix(targetGeometry.bounds) - const pageTransform = this.editor.getShapePageTransform(next.id)! + const pageTransform = this.editor.getShapePageTransform(update.id)! const pointInPageSpace = pageTransform.applyToPoint(handle) const pointInTargetSpace = this.editor.getPointInShapeSpace(target, pointInPageSpace) @@ -230,11 +247,7 @@ export class ArrowShapeUtil extends ShapeUtil { if (!precise) { // If we're switching to a new bound shape, then precise only if moving slowly - const prevHandle = next.props[handleId] - if ( - prevHandle.type === 'point' || - (prevHandle.type === 'binding' && target.id !== prevHandle.boundShapeId) - ) { + if (!currentBinding || (currentBinding && target.id !== currentBinding.toId)) { precise = this.editor.inputs.pointerVelocity.len() < 0.5 } } @@ -246,13 +259,7 @@ export class ArrowShapeUtil extends ShapeUtil { // Double check that we're not going to be doing an imprecise snap on // the same shape twice, as this would result in a zero length line - const otherHandle = - next.props[handleId === ARROW_HANDLES.START ? ARROW_HANDLES.END : ARROW_HANDLES.START] - if ( - otherHandle.type === 'binding' && - target.id === otherHandle.boundShapeId && - otherHandle.isPrecise - ) { + if (otherBinding && target.id === otherBinding.toId && otherBinding.props.isPrecise) { precise = true } } @@ -276,64 +283,66 @@ export class ArrowShapeUtil extends ShapeUtil { } } - next.props[handleId] = { - type: 'binding', - boundShapeId: target.id, - normalizedAnchor: normalizedAnchor, + const b = { + terminal: handleId, + normalizedAnchor, isPrecise: precise, isExact: this.editor.inputs.altKey, } - if (next.props.start.type === 'binding' && next.props.end.type === 'binding') { - if (next.props.start.boundShapeId === next.props.end.boundShapeId) { - if (Vec.Equals(next.props.start.normalizedAnchor, next.props.end.normalizedAnchor)) { - next.props.end.normalizedAnchor.x += 0.05 - } + createOrUpdateArrowBinding(this.editor, shape, target.id, b) + + this.editor.setHintingShapes([target.id]) + + const newBindings = getArrowBindings(this.editor, shape) + if (newBindings.start && newBindings.end && newBindings.start.toId === newBindings.end.toId) { + if ( + Vec.Equals(newBindings.start.props.normalizedAnchor, newBindings.end.props.normalizedAnchor) + ) { + createOrUpdateArrowBinding(this.editor, shape, newBindings.end.toId, { + ...newBindings.end.props, + normalizedAnchor: { + x: newBindings.end.props.normalizedAnchor.x + 0.05, + y: newBindings.end.props.normalizedAnchor.y, + }, + }) } } - return next + return update } override onTranslateStart: TLOnTranslateStartHandler = (shape) => { - const startBindingId = - shape.props.start.type === 'binding' ? shape.props.start.boundShapeId : null - const endBindingId = shape.props.end.type === 'binding' ? shape.props.end.boundShapeId : null + const bindings = getArrowBindings(this.editor, shape) - const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(this.editor, shape) + const terminalsInArrowSpace = getArrowTerminalsInArrowSpace(this.editor, shape, bindings) const shapePageTransform = this.editor.getShapePageTransform(shape.id)! // If at least one bound shape is in the selection, do nothing; // If no bound shapes are in the selection, unbind any bound shapes const selectedShapeIds = this.editor.getSelectedShapeIds() - const shapesToCheck = new Set() - if (startBindingId) { - // Add shape and all ancestors to set - shapesToCheck.add(startBindingId) - this.editor.getShapeAncestors(startBindingId).forEach((a) => shapesToCheck.add(a.id)) - } - if (endBindingId) { - // Add shape and all ancestors to set - shapesToCheck.add(endBindingId) - this.editor.getShapeAncestors(endBindingId).forEach((a) => shapesToCheck.add(a.id)) - } - // If any of the shapes are selected, return - for (const id of selectedShapeIds) { - if (shapesToCheck.has(id)) return - } - let result = shape + if ( + (bindings.start && + (selectedShapeIds.includes(bindings.start.toId) || + this.editor.isAncestorSelected(bindings.start.toId))) || + (bindings.end && + (selectedShapeIds.includes(bindings.end.toId) || + this.editor.isAncestorSelected(bindings.end.toId))) + ) { + return + } // When we start translating shapes, record where their bindings were in page space so we // can maintain them as we translate the arrow shapeAtTranslationStart.set(shape, { pagePosition: shapePageTransform.applyToPoint(shape), terminalBindings: mapObjectMapValues(terminalsInArrowSpace, (terminalName, point) => { - const terminal = shape.props[terminalName] - if (terminal.type !== 'binding') return null + const binding = bindings[terminalName] + if (!binding) return null return { - binding: terminal, + binding, shapePosition: point, pagePosition: shapePageTransform.applyToPoint(point), } @@ -341,15 +350,16 @@ export class ArrowShapeUtil extends ShapeUtil { }) for (const handleName of [ARROW_HANDLES.START, ARROW_HANDLES.END] as const) { - const terminal = shape.props[handleName] - if (terminal.type !== 'binding') continue - result = { - ...shape, - props: { ...shape.props, [handleName]: { ...terminal, isPrecise: true } }, - } + const binding = bindings[handleName] + if (!binding) continue + + this.editor.updateBinding({ + ...binding, + props: { ...binding.props, isPrecise: true }, + }) } - return result + return } override onTranslate?: TLOnTranslateHandler = (initialShape, shape) => { @@ -362,10 +372,7 @@ export class ArrowShapeUtil extends ShapeUtil { atTranslationStart.pagePosition ) - let result = shape - for (const [terminalName, terminalBinding] of objectMapEntries( - atTranslationStart.terminalBindings - )) { + for (const terminalBinding of Object.values(atTranslationStart.terminalBindings)) { if (!terminalBinding) continue const newPagePoint = Vec.Add(terminalBinding.pagePosition, Vec.Mul(pageDelta, 0.5)) @@ -378,54 +385,46 @@ export class ArrowShapeUtil extends ShapeUtil { }, }) - if (newTarget?.id === terminalBinding.binding.boundShapeId) { + if (newTarget?.id === terminalBinding.binding.toId) { const targetBounds = Box.ZeroFix(this.editor.getShapeGeometry(newTarget).bounds) const pointInTargetSpace = this.editor.getPointInShapeSpace(newTarget, newPagePoint) const normalizedAnchor = { x: (pointInTargetSpace.x - targetBounds.minX) / targetBounds.width, y: (pointInTargetSpace.y - targetBounds.minY) / targetBounds.height, } - result = { - ...result, - props: { - ...result.props, - [terminalName]: { ...terminalBinding.binding, isPrecise: true, normalizedAnchor }, - }, - } + createOrUpdateArrowBinding(this.editor, shape, newTarget.id, { + ...terminalBinding.binding.props, + normalizedAnchor, + isPrecise: true, + }) } else { - result = { - ...result, - props: { - ...result.props, - [terminalName]: { - type: 'point', - x: terminalBinding.shapePosition.x, - y: terminalBinding.shapePosition.y, - }, - }, - } + removeArrowBinding(this.editor, shape, terminalBinding.binding.props.terminal) } } - - return result } + // replace this with memo bag? + private _resizeInitialBindings: TLArrowBindings = { start: undefined, end: undefined } + override onResizeStart?: TLOnResizeStartHandler = (shape) => { + this._resizeInitialBindings = getArrowBindings(this.editor, shape) + } override onResize: TLOnResizeHandler = (shape, info) => { const { scaleX, scaleY } = info - const terminals = getArrowTerminalsInArrowSpace(this.editor, shape) + const bindings = this._resizeInitialBindings + const terminals = getArrowTerminalsInArrowSpace(this.editor, shape, bindings) const { start, end } = structuredClone(shape.props) let { bend } = shape.props // Rescale start handle if it's not bound to a shape - if (start.type === 'point') { + if (!bindings.start) { start.x = terminals.start.x * scaleX start.y = terminals.start.y * scaleY } // Rescale end handle if it's not bound to a shape - if (end.type === 'point') { + if (!bindings.end) { end.x = terminals.end.x * scaleX end.y = terminals.end.y * scaleY } @@ -436,18 +435,23 @@ export class ArrowShapeUtil extends ShapeUtil { const mx = Math.abs(scaleX) const my = Math.abs(scaleY) + const startNormalizedAnchor = bindings?.start + ? Vec.From(bindings.start.props.normalizedAnchor) + : null + const endNormalizedAnchor = bindings?.end ? Vec.From(bindings.end.props.normalizedAnchor) : null + if (scaleX < 0 && scaleY >= 0) { if (bend !== 0) { bend *= -1 bend *= Math.max(mx, my) } - if (start.type === 'binding') { - start.normalizedAnchor.x = 1 - start.normalizedAnchor.x + if (startNormalizedAnchor) { + startNormalizedAnchor.x = 1 - startNormalizedAnchor.x } - if (end.type === 'binding') { - end.normalizedAnchor.x = 1 - end.normalizedAnchor.x + if (endNormalizedAnchor) { + endNormalizedAnchor.x = 1 - endNormalizedAnchor.x } } else if (scaleX >= 0 && scaleY < 0) { if (bend !== 0) { @@ -455,12 +459,12 @@ export class ArrowShapeUtil extends ShapeUtil { bend *= Math.max(mx, my) } - if (start.type === 'binding') { - start.normalizedAnchor.y = 1 - start.normalizedAnchor.y + if (startNormalizedAnchor) { + startNormalizedAnchor.y = 1 - startNormalizedAnchor.y } - if (end.type === 'binding') { - end.normalizedAnchor.y = 1 - end.normalizedAnchor.y + if (endNormalizedAnchor) { + endNormalizedAnchor.y = 1 - endNormalizedAnchor.y } } else if (scaleX >= 0 && scaleY >= 0) { if (bend !== 0) { @@ -471,17 +475,30 @@ export class ArrowShapeUtil extends ShapeUtil { bend *= Math.max(mx, my) } - if (start.type === 'binding') { - start.normalizedAnchor.x = 1 - start.normalizedAnchor.x - start.normalizedAnchor.y = 1 - start.normalizedAnchor.y + if (startNormalizedAnchor) { + startNormalizedAnchor.x = 1 - startNormalizedAnchor.x + startNormalizedAnchor.y = 1 - startNormalizedAnchor.y } - if (end.type === 'binding') { - end.normalizedAnchor.x = 1 - end.normalizedAnchor.x - end.normalizedAnchor.y = 1 - end.normalizedAnchor.y + if (endNormalizedAnchor) { + endNormalizedAnchor.x = 1 - endNormalizedAnchor.x + endNormalizedAnchor.y = 1 - endNormalizedAnchor.y } } + if (bindings.start && startNormalizedAnchor) { + createOrUpdateArrowBinding(this.editor, shape, bindings.start.toId, { + ...bindings.start.props, + normalizedAnchor: startNormalizedAnchor.toJson(), + }) + } + if (bindings.end && endNormalizedAnchor) { + createOrUpdateArrowBinding(this.editor, shape, bindings.end.toId, { + ...bindings.end.props, + normalizedAnchor: endNormalizedAnchor.toJson(), + }) + } + const next = { props: { start, @@ -565,18 +582,18 @@ export class ArrowShapeUtil extends ShapeUtil { } indicator(shape: TLArrowShape) { - const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape) + // eslint-disable-next-line react-hooks/rules-of-hooks + const isEditing = useIsEditing(shape.id) const info = this.editor.getArrowInfo(shape) + if (!info) return null + + const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape, info?.bindings) const geometry = this.editor.getShapeGeometry(shape) const bounds = geometry.bounds const labelGeometry = shape.props.text.trim() ? (geometry.children[1] as Rectangle2d) : null - // eslint-disable-next-line react-hooks/rules-of-hooks - const isEditing = useIsEditing(shape.id) - - if (!info) return null if (Vec.Equals(start, end)) return null const strokeWidth = STROKE_SIZES[shape.props.size] @@ -753,6 +770,7 @@ const ArrowSvg = track(function ArrowSvg({ const theme = useDefaultColorTheme() const info = editor.getArrowInfo(shape) const bounds = Box.ZeroFix(editor.getShapeGeometry(shape).bounds) + const bindings = getArrowBindings(editor, shape) const changeIndex = React.useMemo(() => { return editor.environment.isSafari ? (globalRenderIndex += 1) : 0 @@ -783,7 +801,7 @@ const ArrowSvg = track(function ArrowSvg({ ) handlePath = - shape.props.start.type === 'binding' || shape.props.end.type === 'binding' ? ( + bindings.start || bindings.end ? ( + binding: TLArrowBinding } | null > } diff --git a/packages/tldraw/src/lib/shapes/arrow/arrowLabel.ts b/packages/tldraw/src/lib/shapes/arrow/arrowLabel.ts index edef80479..a43d9847b 100644 --- a/packages/tldraw/src/lib/shapes/arrow/arrowLabel.ts +++ b/packages/tldraw/src/lib/shapes/arrow/arrowLabel.ts @@ -268,8 +268,8 @@ export function getArrowLabelPosition(editor: Editor, shape: TLArrowShape) { const debugGeom: Geometry2d[] = [] const info = editor.getArrowInfo(shape)! - const hasStartBinding = shape.props.start.type === 'binding' - const hasEndBinding = shape.props.end.type === 'binding' + const hasStartBinding = !!info.bindings.start + const hasEndBinding = !!info.bindings.end const hasStartArrowhead = info.start.arrowhead !== 'none' const hasEndArrowhead = info.end.arrowhead !== 'none' if (info.isStraight) { diff --git a/packages/tldraw/src/lib/shapes/arrow/toolStates/Pointing.ts b/packages/tldraw/src/lib/shapes/arrow/toolStates/Pointing.ts index 9e0b9414d..074dc1f63 100644 --- a/packages/tldraw/src/lib/shapes/arrow/toolStates/Pointing.ts +++ b/packages/tldraw/src/lib/shapes/arrow/toolStates/Pointing.ts @@ -42,11 +42,12 @@ export class Pointing extends StateNode { if (!this.shape) throw Error(`expected shape`) + // const initialEndHandle = this.editor.getShapeHandles(this.shape)!.find((h) => h.id === 'end')! this.updateArrowShapeEndHandle() this.editor.setCurrentTool('select.dragging_handle', { shape: this.shape, - handle: this.editor.getShapeHandles(this.shape)!.find((h) => h.id === 'end')!, + handle: { id: 'end', type: 'vertex', index: 'a3', x: 0, y: 0, canBind: true }, isCreating: true, onInteractionEnd: 'arrow', }) @@ -111,10 +112,6 @@ export class Pointing extends StateNode { }) if (change) { - const startTerminal = change.props?.start - if (startTerminal?.type === 'binding') { - this.editor.setHintingShapes([startTerminal.boundShapeId]) - } this.editor.updateShapes([change]) } @@ -130,9 +127,20 @@ export class Pointing extends StateNode { const handles = this.editor.getShapeHandles(shape) if (!handles) throw Error(`expected handles for arrow`) - const shapeWithOutEndOffset = { - ...shape, - props: { ...shape.props, end: { ...shape.props.end, x: 0, y: 0 } }, + // start update + { + const util = this.editor.getShapeUtil('arrow') + const initial = this.shape + const startHandle = handles.find((h) => h.id === 'start')! + const change = util.onHandleDrag?.(shape, { + handle: { ...startHandle, x: 0, y: 0 }, + isPrecise: this.didTimeout, // sure about that? + initial: initial, + }) + + if (change) { + this.editor.updateShapes([change]) + } } // end update @@ -141,32 +149,12 @@ export class Pointing extends StateNode { const initial = this.shape const point = this.editor.getPointInShapeSpace(shape, this.editor.inputs.currentPagePoint) const endHandle = handles.find((h) => h.id === 'end')! - const change = util.onHandleDrag?.(shapeWithOutEndOffset, { + const change = util.onHandleDrag?.(this.editor.getShape(shape)!, { handle: { ...endHandle, x: point.x, y: point.y }, isPrecise: false, // sure about that? initial: initial, }) - if (change) { - const endTerminal = change.props?.end - if (endTerminal?.type === 'binding') { - this.editor.setHintingShapes([endTerminal.boundShapeId]) - } - this.editor.updateShapes([change]) - } - } - - // start update - { - const util = this.editor.getShapeUtil('arrow') - const initial = this.shape - const startHandle = handles.find((h) => h.id === 'start')! - const change = util.onHandleDrag?.(shapeWithOutEndOffset, { - handle: { ...startHandle, x: 0, y: 0 }, - isPrecise: this.didTimeout, // sure about that? - initial: initial, - }) - if (change) { this.editor.updateShapes([change]) } diff --git a/packages/tldraw/src/lib/tools/SelectTool/childStates/DraggingHandle.tsx b/packages/tldraw/src/lib/tools/SelectTool/childStates/DraggingHandle.tsx index aab15b79d..0eebe1b29 100644 --- a/packages/tldraw/src/lib/tools/SelectTool/childStates/DraggingHandle.tsx +++ b/packages/tldraw/src/lib/tools/SelectTool/childStates/DraggingHandle.tsx @@ -1,7 +1,6 @@ import { StateNode, TLArrowShape, - TLArrowShapeTerminal, TLCancelEvent, TLEnterEventHandler, TLEventHandlers, @@ -12,6 +11,7 @@ import { TLShapeId, TLShapePartial, Vec, + getArrowBindings, snapAngle, sortByIndex, structuredClone, @@ -112,16 +112,16 @@ export class DraggingHandle extends StateNode { //