From c6ba621c114ffcd3bededf642895a9c68fbb420f Mon Sep 17 00:00:00 2001 From: David Sheldrick Date: Wed, 8 May 2024 15:18:24 +0100 Subject: [PATCH] Incremental bindings index (#3685) --- packages/editor/api-report.md | 6 +- packages/editor/src/lib/editor/Editor.ts | 134 ++++----- .../lib/editor/derivations/bindingsIndex.ts | 106 +++++++ .../tldraw/src/test/bindingsIndex.test.tsx | 260 ++++++++++++++++++ 4 files changed, 438 insertions(+), 68 deletions(-) create mode 100644 packages/editor/src/lib/editor/derivations/bindingsIndex.ts create mode 100644 packages/tldraw/src/test/bindingsIndex.test.tsx diff --git a/packages/editor/api-report.md b/packages/editor/api-report.md index e4262d7ab..bff3af908 100644 --- a/packages/editor/api-report.md +++ b/packages/editor/api-report.md @@ -778,10 +778,6 @@ export class Editor extends EventEmitter { findCommonAncestor(shapes: TLShape[] | TLShapeId[], predicate?: (shape: TLShape) => boolean): TLShapeId | undefined; findShapeAncestor(shape: TLShape | TLShapeId, predicate: (parent: TLShape) => boolean): TLShape | undefined; flipShapes(shapes: TLShape[] | TLShapeId[], operation: 'horizontal' | 'vertical'): this; - // (undocumented) - getAllBindingsFromShape(shape: TLShape | TLShapeId): TLBinding[]; - // (undocumented) - getAllBindingsToShape(shape: TLShape | TLShapeId): TLBinding[]; getAncestorPageId(shape?: TLShape | TLShapeId): TLPageId | undefined; getArrowInfo(shape: TLArrowShape | TLShapeId): TLArrowInfo | undefined; getArrowsBoundTo(shapeId: TLShapeId): TLArrowShape[]; @@ -794,6 +790,8 @@ export class Editor extends EventEmitter { // (undocumented) getBindingsFromShape(shape: TLShape | TLShapeId, type: Binding['type']): Binding[]; // (undocumented) + getBindingsInvolvingShape(shape: TLShape | TLShapeId, type?: Binding['type']): Binding[]; + // (undocumented) getBindingsToShape(shape: TLShape | TLShapeId, type: Binding['type']): Binding[]; getBindingUtil(binding: S | TLBindingPartial): BindingUtil; // (undocumented) diff --git a/packages/editor/src/lib/editor/Editor.ts b/packages/editor/src/lib/editor/Editor.ts index 2ff5e8925..73ab85fdd 100644 --- a/packages/editor/src/lib/editor/Editor.ts +++ b/packages/editor/src/lib/editor/Editor.ts @@ -120,6 +120,7 @@ import { getReorderingShapesChanges } from '../utils/reorderShapes' import { applyRotationToSnapshotShapes, getRotationSnapshot } from '../utils/rotation' import { uniqueId } from '../utils/uniqueId' import { BindingUtil, TLBindingUtilConstructor } from './bindings/BindingUtil' +import { bindingsIndex } from './derivations/bindingsIndex' import { notVisibleShapes } from './derivations/notVisibleShapes' import { parentsToChildren } from './derivations/parentsToChildren' import { deriveShapeIdsInCurrentPage } from './derivations/shapeIdsInCurrentPage' @@ -379,19 +380,21 @@ export class Editor extends EventEmitter { this.sideEffects.register({ shape: { afterChange: (shapeBefore, shapeAfter) => { - for (const binding of this.getAllBindingsFromShape(shapeAfter)) { - this.getBindingUtil(binding).onAfterChangeFromShape?.({ - binding, - shapeBefore, - shapeAfter, - }) - } - for (const binding of this.getAllBindingsToShape(shapeAfter)) { - this.getBindingUtil(binding).onAfterChangeToShape?.({ - binding, - shapeBefore, - shapeAfter, - }) + for (const binding of this.getBindingsInvolvingShape(shapeAfter)) { + if (binding.fromId === shapeAfter.id) { + this.getBindingUtil(binding).onAfterChangeFromShape?.({ + binding, + shapeBefore, + shapeAfter, + }) + } + if (binding.toId === shapeAfter.id) { + this.getBindingUtil(binding).onAfterChangeToShape?.({ + binding, + shapeBefore, + shapeAfter, + }) + } } // if the shape's parent changed and it has a binding, update the binding @@ -400,19 +403,21 @@ export class Editor extends EventEmitter { const descendantShape = this.getShape(id) if (!descendantShape) return - for (const binding of this.getAllBindingsFromShape(descendantShape)) { - this.getBindingUtil(binding).onAfterChangeFromShape?.({ - binding, - shapeBefore: descendantShape, - shapeAfter: descendantShape, - }) - } - for (const binding of this.getAllBindingsToShape(descendantShape)) { - this.getBindingUtil(binding).onAfterChangeToShape?.({ - binding, - shapeBefore: descendantShape, - shapeAfter: descendantShape, - }) + for (const binding of this.getBindingsInvolvingShape(descendantShape)) { + if (binding.fromId === descendantShape.id) { + this.getBindingUtil(binding).onAfterChangeFromShape?.({ + binding, + shapeBefore: descendantShape, + shapeAfter: descendantShape, + }) + } + if (binding.toId === descendantShape.id) { + this.getBindingUtil(binding).onAfterChangeToShape?.({ + binding, + shapeBefore: descendantShape, + shapeAfter: descendantShape, + }) + } } } notifyBindingAncestryChange(shapeAfter.id) @@ -451,13 +456,15 @@ export class Editor extends EventEmitter { } const deleteBindingIds: TLBindingId[] = [] - for (const binding of this.getAllBindingsFromShape(shape)) { - this.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape }) - deleteBindingIds.push(binding.id) - } - for (const binding of this.getAllBindingsToShape(shape)) { - this.getBindingUtil(binding).onBeforeDeleteToShape?.({ binding, shape }) - deleteBindingIds.push(binding.id) + for (const binding of this.getBindingsInvolvingShape(shape)) { + if (binding.fromId === shape.id) { + this.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape }) + deleteBindingIds.push(binding.id) + } + if (binding.toId === shape.id) { + this.getBindingUtil(binding).onBeforeDeleteToShape?.({ binding, shape }) + deleteBindingIds.push(binding.id) + } } this.deleteBindings(deleteBindingIds) @@ -5032,42 +5039,44 @@ export class Editor extends EventEmitter { /* -------------------- Bindings -------------------- */ + @computed + private _getBindingsIndexCache() { + const index = bindingsIndex(this) + return this.store.createComputedCache('bindingsIndex', (shape) => { + return index.get().get(shape.id) + }) + } + getBinding(id: TLBindingId): TLBinding | undefined { return this.store.get(id) as TLBinding | undefined } - // TODO(alex) #bindings - cache `allBindings` getters and derive type-specific ones from them getBindingsFromShape( shape: TLShape | TLShapeId, type: Binding['type'] ): Binding[] { const id = typeof shape === 'string' ? shape : shape.id - return this.store.query.exec('binding', { - fromId: { eq: id }, - type: { eq: type }, - }) as Binding[] + return this.getBindingsInvolvingShape(id).filter( + (b) => b.fromId === id && b.type === type + ) as Binding[] } getBindingsToShape( shape: TLShape | TLShapeId, type: Binding['type'] ): Binding[] { const id = typeof shape === 'string' ? shape : shape.id - return this.store.query.exec('binding', { - toId: { eq: id }, - type: { eq: type }, - }) as Binding[] + return this.getBindingsInvolvingShape(id).filter( + (b) => b.toId === id && b.type === type + ) as Binding[] } - getAllBindingsFromShape(shape: TLShape | TLShapeId): TLBinding[] { + getBindingsInvolvingShape( + shape: TLShape | TLShapeId, + type?: Binding['type'] + ): Binding[] { const id = typeof shape === 'string' ? shape : shape.id - return this.store.query.exec('binding', { - fromId: { eq: id }, - }) - } - getAllBindingsToShape(shape: TLShape | TLShapeId): TLBinding[] { - const id = typeof shape === 'string' ? shape : shape.id - return this.store.query.exec('binding', { - toId: { eq: id }, - }) + const result = this._getBindingsIndexCache().get(id) ?? EMPTY_ARRAY + if (!type) return result as Binding[] + return result.filter((b) => b.type === type) as Binding[] } createBindings(partials: RequiredKeys[]) { @@ -8744,22 +8753,19 @@ function withoutBindingsToUnrelatedShapes( const shape = editor.getShape(shapeId) if (!shape) continue - for (const binding of editor.getAllBindingsFromShape(shapeId)) { - if (shapeIds.has(binding.toId)) { - // if we have both sides of the binding, we want to recreate it + for (const binding of editor.getBindingsInvolvingShape(shapeId)) { + const hasFrom = shapeIds.has(binding.fromId) + const hasTo = shapeIds.has(binding.toId) + if (hasFrom && hasTo) { bindingsWithBoth.add(binding.id) - } else { - // otherwise, if we only have one side, we need to record that and duplicate - // the shape as if the one it's bound to has been deleted - bindingsWithoutTo.add(binding.id) + continue } - } - for (const binding of editor.getAllBindingsToShape(shapeId)) { - if (shapeIds.has(binding.fromId)) { - bindingsWithBoth.add(binding.id) - } else { + if (!hasFrom) { bindingsWithoutFrom.add(binding.id) } + if (!hasTo) { + bindingsWithoutTo.add(binding.id) + } } } diff --git a/packages/editor/src/lib/editor/derivations/bindingsIndex.ts b/packages/editor/src/lib/editor/derivations/bindingsIndex.ts new file mode 100644 index 000000000..59bccd6d0 --- /dev/null +++ b/packages/editor/src/lib/editor/derivations/bindingsIndex.ts @@ -0,0 +1,106 @@ +import { Computed, RESET_VALUE, computed, isUninitialized } from '@tldraw/state' +import { TLBinding, TLShapeId } from '@tldraw/tlschema' +import { objectMapValues } from '@tldraw/utils' +import { Editor } from '../Editor' + +type TLBindingsIndex = Map + +export const bindingsIndex = (editor: Editor): Computed => { + const { store } = editor + const bindingsHistory = store.query.filterHistory('binding') + const bindingsQuery = store.query.records('binding') + function fromScratch() { + const allBindings = bindingsQuery.get() as TLBinding[] + + const shape2Binding: TLBindingsIndex = new Map() + + for (const binding of allBindings) { + const { fromId, toId } = binding + const bindingsForFromShape = shape2Binding.get(fromId) + if (!bindingsForFromShape) { + shape2Binding.set(fromId, [binding]) + } else { + bindingsForFromShape.push(binding) + } + const bindingsForToShape = shape2Binding.get(toId) + if (!bindingsForToShape) { + shape2Binding.set(toId, [binding]) + } else { + bindingsForToShape.push(binding) + } + } + + return shape2Binding + } + + return computed('arrowBindingsIndex', (_lastValue, lastComputedEpoch) => { + if (isUninitialized(_lastValue)) { + return fromScratch() + } + + const lastValue = _lastValue + + const diff = bindingsHistory.getDiffSince(lastComputedEpoch) + + if (diff === RESET_VALUE) { + return fromScratch() + } + + let nextValue: TLBindingsIndex | undefined = undefined + + function removingBinding(binding: TLBinding) { + nextValue ??= new Map(lastValue) + const prevFrom = lastValue.get(binding.fromId) + const nextFrom = prevFrom?.filter((b) => b.id !== binding.id) + if (!nextFrom?.length) { + nextValue.delete(binding.fromId) + } else { + nextValue.set(binding.fromId, nextFrom) + } + const prevTo = lastValue.get(binding.toId) + const nextTo = prevTo?.filter((b) => b.id !== binding.id) + if (!nextTo?.length) { + nextValue.delete(binding.toId) + } else { + nextValue.set(binding.toId, nextTo) + } + } + + function ensureNewArray(shapeId: TLShapeId) { + nextValue ??= new Map(lastValue) + + let result = nextValue.get(shapeId) + if (!result) { + result = [] + nextValue.set(shapeId, result) + } else if (result === lastValue.get(shapeId)) { + result = result.slice(0) + nextValue.set(shapeId, result) + } + return result + } + + function addBinding(binding: TLBinding) { + ensureNewArray(binding.fromId).push(binding) + ensureNewArray(binding.toId).push(binding) + } + + for (const changes of diff) { + for (const newBinding of objectMapValues(changes.added)) { + addBinding(newBinding) + } + + for (const [prev, next] of objectMapValues(changes.updated)) { + removingBinding(prev) + addBinding(next) + } + + for (const prev of objectMapValues(changes.removed)) { + removingBinding(prev) + } + } + + // TODO: add diff entries if we need them + return nextValue ?? lastValue + }) +} diff --git a/packages/tldraw/src/test/bindingsIndex.test.tsx b/packages/tldraw/src/test/bindingsIndex.test.tsx new file mode 100644 index 000000000..4b6fbfea2 --- /dev/null +++ b/packages/tldraw/src/test/bindingsIndex.test.tsx @@ -0,0 +1,260 @@ +import { TLArrowBinding, TLGeoShape, TLShapeId, createShapeId } from '@tldraw/editor' +import { TestEditor } from './TestEditor' +import { TL } from './test-jsx' + +let editor: TestEditor + +beforeEach(() => { + editor = new TestEditor() +}) + +describe('bindingsIndex', () => { + it('keeps a mapping from bound shapes to their bindings', () => { + const ids = editor.createShapesFromJsx([ + , + , + ]) + + editor.selectNone() + editor.setCurrentTool('arrow') + editor.pointerDown(50, 50) + expect(editor.getOnlySelectedShape()).toBe(null) + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([]) + + editor.pointerMove(50, 55) + expect(editor.getOnlySelectedShape()).not.toBe(null) + const arrow = editor.getOnlySelectedShape()! + expect(arrow.type).toBe('arrow') + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow]) + + editor.pointerMove(250, 50) + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([editor.getShape(arrow.id)]) + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([editor.getShape(arrow.id)]) + }) + + it('works if there are many arrows', () => { + const ids = { + box1: createShapeId('box1'), + box2: createShapeId('box2'), + } + + editor.createShapes([ + { type: 'geo', id: ids.box1, x: 0, y: 0, props: { w: 100, h: 100 } }, + { type: 'geo', id: ids.box2, x: 200, y: 0, props: { w: 100, h: 100 } }, + ]) + + editor.setCurrentTool('arrow') + // start at box 1 and end on box 2 + editor.pointerDown(50, 50) + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([]) + + editor.pointerMove(250, 50) + const arrow1 = editor.getOnlySelectedShape()! + expect(arrow1.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1]) + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1]) + + editor.pointerUp() + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1]) + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1]) + + // start at box 1 and end on the page + editor.setCurrentTool('arrow') + editor.pointerMove(50, 50).pointerDown().pointerMove(50, -50).pointerUp() + const arrow2 = editor.getOnlySelectedShape()! + expect(arrow2.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1, arrow2]) + + // start outside box 1 and end in box 1 + editor.setCurrentTool('arrow') + editor.pointerDown(0, -50).pointerMove(50, 50).pointerUp(50, 50) + const arrow3 = editor.getOnlySelectedShape()! + expect(arrow3.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1, arrow2, arrow3]) + + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1]) + + // start at box 2 and end on the page + editor.selectNone() + editor.setCurrentTool('arrow') + editor.pointerDown(250, 50) + editor.expectToBeIn('arrow.pointing') + editor.pointerMove(250, -50) + editor.expectToBeIn('select.dragging_handle') + const arrow4 = editor.getOnlySelectedShape()! + + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1, arrow4]) + + editor.pointerUp(250, -50) + editor.expectToBeIn('select.idle') + expect(arrow4.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1, arrow4]) + + // start outside box 2 and enter in box 2 + editor.setCurrentTool('arrow') + editor.pointerDown(250, -50).pointerMove(250, 50).pointerUp(250, 50) + const arrow5 = editor.getOnlySelectedShape()! + expect(arrow5.type).toBe('arrow') + + expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1, arrow2, arrow3]) + + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1, arrow4, arrow5]) + }) + + describe('updating shapes', () => { + // ▲ │ │ ▲ + // │ │ │ │ + // b c e d + // ┌───┼─┴─┐ ┌──┴──┼─┐ + // │ │ ▼ │ │ ▼ │ │ + // │ └───┼─────a───┼───► │ │ + // │ 1 │ │ 2 │ + // └───────┘ └───────┘ + let arrowAId: TLShapeId + let arrowBId: TLShapeId + let arrowCId: TLShapeId + let arrowDId: TLShapeId + let arrowEId: TLShapeId + let ids: Record + beforeEach(() => { + ids = editor.createShapesFromJsx([ + , + , + ]) + + // span both boxes + editor.setCurrentTool('arrow') + editor.pointerDown(50, 50).pointerMove(250, 50).pointerUp(250, 50) + arrowAId = editor.getOnlySelectedShape()!.id + // start at box 1 and leave + editor.setCurrentTool('arrow') + editor.pointerDown(50, 50).pointerMove(50, -50).pointerUp(50, -50) + arrowBId = editor.getOnlySelectedShape()!.id + // start outside box 1 and enter + editor.setCurrentTool('arrow') + editor.pointerDown(50, -50).pointerMove(50, 50).pointerUp(50, 50) + arrowCId = editor.getOnlySelectedShape()!.id + // start at box 2 and leave + editor.setCurrentTool('arrow') + editor.pointerDown(250, 50).pointerMove(250, -50).pointerUp(250, -50) + arrowDId = editor.getOnlySelectedShape()!.id + // start outside box 2 and enter + editor.setCurrentTool('arrow') + editor.pointerDown(250, -50).pointerMove(250, 50).pointerUp(250, 50) + arrowEId = editor.getOnlySelectedShape()!.id + }) + it('deletes the entry if you delete the bound shapes', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + editor.deleteShapes([ids.box2]) + expect(editor.getArrowsBoundTo(ids.box2)).toEqual([]) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + }) + it('deletes the entry if you delete an arrow', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + editor.deleteShapes([arrowEId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(2) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + editor.deleteShapes([arrowDId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(1) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + editor.deleteShapes([arrowCId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(1) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(2) + + editor.deleteShapes([arrowBId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(1) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(1) + + editor.deleteShapes([arrowAId]) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(0) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(0) + }) + + it('deletes the entries in a batch too', () => { + editor.deleteShapes([arrowAId, arrowBId, arrowCId, arrowDId, arrowEId]) + + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(0) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(0) + }) + + it('adds new entries after initial creation', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + // draw from box 2 to box 1 + editor.setCurrentTool('arrow') + editor.pointerDown(250, 50).pointerMove(50, 50).pointerUp(50, 50) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(4) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(4) + + // create a new box + + const { box3 } = editor.createShapesFromJsx( + + ) + + // draw from box 2 to box 3 + + editor.setCurrentTool('arrow') + editor.pointerDown(250, 50).pointerMove(450, 50).pointerUp(450, 50) + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(5) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(4) + expect(editor.getArrowsBoundTo(box3)).toHaveLength(1) + }) + + it('works when copy pasting', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + editor.selectAll() + editor.duplicateShapes(editor.getSelectedShapeIds()) + + const [box1Clone, box2Clone] = editor + .getSelectedShapes() + .filter((shape) => editor.isShapeOfType(shape, 'geo')) + .sort((a, b) => a.x - b.x) + + expect(editor.getArrowsBoundTo(box2Clone.id)).toHaveLength(3) + expect(editor.getArrowsBoundTo(box1Clone.id)).toHaveLength(3) + }) + + it('allows bound shapes to be moved', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + editor.nudgeShapes([ids.box2], { x: 0, y: -1 }) + + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + }) + + it('allows the arrows bound shape to change', () => { + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + + // create another box + + const { box3 } = editor.createShapesFromJsx( + + ) + + // move arrowA end from box2 to box3 + const binding = editor + .getBindingsInvolvingShape(ids.box2, 'arrow') + .find((b) => b.props.terminal === 'end')! + editor.updateBinding({ ...binding, toId: box3 } satisfies TLArrowBinding) + + expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(2) + expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3) + expect(editor.getArrowsBoundTo(box3)).toHaveLength(1) + }) + }) +})