From 507bba82fd8830ad1f6e7f7ae2a2d9a5b5625033 Mon Sep 17 00:00:00 2001 From: Steve Ruiz Date: Wed, 2 Aug 2023 12:05:14 +0100 Subject: [PATCH] SideEffectManager (#1785) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR extracts the side effect manager from #1778. ### Change Type - [x] `major` — Breaking change --- packages/editor/api-report.md | 1 + packages/editor/src/lib/editor/Editor.ts | 670 +++++++++--------- .../editor/managers/HistoryManager.test.ts | 10 +- .../src/lib/editor/managers/HistoryManager.ts | 3 +- .../editor/managers/SIdeEffectManager.test.ts | 20 + .../lib/editor/managers/SideEffectManager.ts | 245 +++++++ packages/store/api-report.md | 12 +- packages/store/src/lib/Store.ts | 81 ++- 8 files changed, 664 insertions(+), 378 deletions(-) create mode 100644 packages/editor/src/lib/editor/managers/SIdeEffectManager.test.ts create mode 100644 packages/editor/src/lib/editor/managers/SideEffectManager.ts diff --git a/packages/editor/api-report.md b/packages/editor/api-report.md index 414ef1826..7e5667fc4 100644 --- a/packages/editor/api-report.md +++ b/packages/editor/api-report.md @@ -944,6 +944,7 @@ export class Editor extends EventEmitter { }; get sharedOpacity(): SharedStyle; get sharedStyles(): ReadonlySharedStyleMap; + readonly sideEffects: SideEffectManager; slideCamera(opts?: { speed: number; direction: VecLike; diff --git a/packages/editor/src/lib/editor/Editor.ts b/packages/editor/src/lib/editor/Editor.ts index 79d5974b9..69f823fdd 100644 --- a/packages/editor/src/lib/editor/Editor.ts +++ b/packages/editor/src/lib/editor/Editor.ts @@ -105,6 +105,7 @@ import { deriveShapeIdsInCurrentPage } from './derivations/shapeIdsInCurrentPage import { ClickManager } from './managers/ClickManager' import { EnvironmentManager } from './managers/EnvironmentManager' import { HistoryManager } from './managers/HistoryManager' +import { SideEffectManager } from './managers/SideEffectManager' import { SnapManager } from './managers/SnapManager' import { TextManager } from './managers/TextManager' import { TickManager } from './managers/TickManager' @@ -242,47 +243,328 @@ export class Editor extends EventEmitter { this.environment = new EnvironmentManager(this) - this.store.onBeforeDelete = (record) => { - if (record.typeName === 'shape') { - this._shapeWillBeDeleted(record) - } else if (record.typeName === 'page') { - this._pageWillBeDeleted(record) - } - } + // Cleanup - this.store.onAfterChange = (prev, next) => { - this._updateDepth++ - if (this._updateDepth > 1000) { - console.error('[onAfterChange] Maximum update depth exceeded, bailing out.') - } - if (prev.typeName === 'shape' && next.typeName === 'shape') { - this._shapeDidChange(prev, next) - } else if ( - prev.typeName === 'instance_page_state' && - next.typeName === 'instance_page_state' - ) { - this._pageStateDidChange(prev, next) + const invalidParents = new Set() + + const reparentArrow = (arrowId: TLArrowShape['id']) => { + const arrow = this.getShape(arrowId) + if (!arrow) return + const { start, end } = arrow.props + const startShape = start.type === 'binding' ? this.getShape(start.boundShapeId) : undefined + const endShape = end.type === 'binding' ? this.getShape(end.boundShapeId) : undefined + + const parentPageId = this.getAncestorPageId(arrow) + if (!parentPageId) return + + let nextParentId: TLParentId + if (startShape && endShape) { + // if arrow has two bindings, always parent arrow to closest common ancestor of the bindings + nextParentId = this.findCommonAncestor([startShape, endShape]) ?? parentPageId + } else if (startShape || endShape) { + // if arrow has one binding, keep arrow on its own page + nextParentId = parentPageId + } else { + return } - this._updateDepth-- - } - this.store.onAfterCreate = (record) => { - if (record.typeName === 'shape' && this.isShapeOfType(record, 'arrow')) { - this._arrowDidUpdate(record) + if (nextParentId && nextParentId !== arrow.parentId) { + this.reparentShapes([arrowId], nextParentId) } - if (record.typeName === 'page') { - const cameraId = CameraRecordType.createId(record.id) - const _pageStateId = InstancePageStateRecordType.createId(record.id) - if (!this.store.has(cameraId)) { - this.store.put([CameraRecordType.create({ id: cameraId })]) + + const reparentedArrow = this.getShape(arrowId) + if (!reparentedArrow) throw Error('no reparented arrow') + + const startSibling = this.getShapeNearestSibling(reparentedArrow, startShape) + const endSibling = this.getShapeNearestSibling(reparentedArrow, endShape) + + let highestSibling: TLShape | undefined + + if (startSibling && endSibling) { + highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling + } else if (startSibling && !endSibling) { + highestSibling = startSibling + } else if (endSibling && !startSibling) { + highestSibling = endSibling + } else { + return + } + + let finalIndex: string + + const higherSiblings = this.getSortedChildIdsForParent(highestSibling.parentId) + .map((id) => this.getShape(id)!) + .filter((sibling) => sibling.index > highestSibling!.index) + + if (higherSiblings.length) { + // there are siblings above the highest bound sibling, we need to + // insert between them. + + // if the next sibling is also a bound arrow though, we can end up + // all fighting for the same indexes. so lets find the next + // non-arrow sibling... + const nextHighestNonArrowSibling = higherSiblings.find( + (sibling) => sibling.type !== 'arrow' + ) + + if ( + // ...then, if we're above the last shape we want to be above... + reparentedArrow.index > highestSibling.index && + // ...but below the next non-arrow sibling... + (!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index) + ) { + // ...then we're already in the right place. no need to update! + return } - if (!this.store.has(_pageStateId)) { + + // otherwise, we need to find the index between the highest sibling + // we want to be above, and the next highest sibling we want to be + // below: + finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index) + } else { + // if there are no siblings above us, we can just get the next index: + finalIndex = getIndexAbove(highestSibling.index) + } + + if (finalIndex !== reparentedArrow.index) { + this.updateShapes([{ id: arrowId, type: 'arrow', index: finalIndex }]) + } + } + + const unbindArrowTerminal = (arrow: TLArrowShape, handleId: 'start' | 'end') => { + const { x, y } = getArrowTerminalsInArrowSpace(this, arrow)[handleId] + this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }]) + } + + const arrowDidUpdate = (arrow: TLArrowShape) => { + // if the shape is an arrow and its bound shape is on another page + // or was deleted, unbind it + for (const handle of ['start', 'end'] as const) { + const terminal = arrow.props[handle] + if (terminal.type !== 'binding') continue + const boundShape = this.getShape(terminal.boundShapeId) + const isShapeInSamePageAsArrow = + this.getAncestorPageId(arrow) === this.getAncestorPageId(boundShape) + if (!boundShape || !isShapeInSamePageAsArrow) { + unbindArrowTerminal(arrow, handle) + } + } + + // always check the arrow parents + reparentArrow(arrow.id) + } + + const cleanupInstancePageState = ( + prevPageState: TLInstancePageState, + shapesNoLongerInPage: Set + ) => { + let nextPageState = null as null | TLInstancePageState + + const selectedShapeIds = prevPageState.selectedShapeIds.filter( + (id) => !shapesNoLongerInPage.has(id) + ) + if (selectedShapeIds.length !== prevPageState.selectedShapeIds.length) { + if (!nextPageState) nextPageState = { ...prevPageState } + nextPageState.selectedShapeIds = selectedShapeIds + } + + const erasingShapeIds = prevPageState.erasingShapeIds.filter( + (id) => !shapesNoLongerInPage.has(id) + ) + if (erasingShapeIds.length !== prevPageState.erasingShapeIds.length) { + if (!nextPageState) nextPageState = { ...prevPageState } + nextPageState.erasingShapeIds = erasingShapeIds + } + + if (prevPageState.hoveredShapeId && shapesNoLongerInPage.has(prevPageState.hoveredShapeId)) { + if (!nextPageState) nextPageState = { ...prevPageState } + nextPageState.hoveredShapeId = null + } + + if (prevPageState.editingShapeId && shapesNoLongerInPage.has(prevPageState.editingShapeId)) { + if (!nextPageState) nextPageState = { ...prevPageState } + nextPageState.editingShapeId = null + } + + const hintingShapeIds = prevPageState.hintingShapeIds.filter( + (id) => !shapesNoLongerInPage.has(id) + ) + if (hintingShapeIds.length !== prevPageState.hintingShapeIds.length) { + if (!nextPageState) nextPageState = { ...prevPageState } + nextPageState.hintingShapeIds = hintingShapeIds + } + + if (prevPageState.focusedGroupId && shapesNoLongerInPage.has(prevPageState.focusedGroupId)) { + if (!nextPageState) nextPageState = { ...prevPageState } + nextPageState.focusedGroupId = null + } + return nextPageState + } + + this.sideEffects = new SideEffectManager(this) + + this.sideEffects.registerBatchCompleteHandler(() => { + for (const parentId of invalidParents) { + invalidParents.delete(parentId) + const parent = this.getShape(parentId) + if (!parent) continue + + const util = this.getShapeUtil(parent) + const changes = util.onChildrenChange?.(parent) + + if (changes?.length) { + this.updateShapes(changes, true) + } + } + + this.emit('update') + }) + + this.sideEffects.registerBeforeDeleteHandler('shape', (record) => { + // if the deleted shape has a parent shape make sure we call it's onChildrenChange callback + if (record.parentId && isShapeId(record.parentId)) { + invalidParents.add(record.parentId) + } + // clean up any arrows bound to this shape + const bindings = this._arrowBindingsIndex.value[record.id] + if (bindings?.length) { + for (const { arrowId, handleId } of bindings) { + const arrow = this.getShape(arrowId) + if (!arrow) continue + unbindArrowTerminal(arrow, handleId) + } + } + const deletedIds = new Set([record.id]) + const updates = compact( + this.pageStates.map((pageState) => { + return cleanupInstancePageState(pageState, deletedIds) + }) + ) + + if (updates.length) { + this.store.put(updates) + } + }) + + this.sideEffects.registerBeforeDeleteHandler('page', (record) => { + // page was deleted, need to check whether it's the current page and select another one if so + if (this.instanceState.currentPageId !== record.id) return + + const backupPageId = this.pages.find((p) => p.id !== record.id)?.id + if (!backupPageId) return + this.store.put([{ ...this.instanceState, currentPageId: backupPageId }]) + + // delete the camera and state for the page if necessary + const cameraId = CameraRecordType.createId(record.id) + const instance_PageStateId = InstancePageStateRecordType.createId(record.id) + this.store.remove([cameraId, instance_PageStateId]) + }) + + this.sideEffects.registerAfterChangeHandler('shape', (prev, next) => { + if (this.isShapeOfType(next, 'arrow')) { + arrowDidUpdate(next) + } + + // if the shape's parent changed and it is bound to an arrow, update the arrow's parent + if (prev.parentId !== next.parentId) { + const reparentBoundArrows = (id: TLShapeId) => { + const boundArrows = this._arrowBindingsIndex.value[id] + if (boundArrows?.length) { + for (const arrow of boundArrows) { + reparentArrow(arrow.arrowId) + } + } + } + reparentBoundArrows(next.id) + this.visitDescendants(next.id, reparentBoundArrows) + } + + // if this shape moved to a new page, clean up any previous page's instance state + if (prev.parentId !== next.parentId && isPageId(next.parentId)) { + const allMovingIds = new Set([prev.id]) + this.visitDescendants(prev.id, (id) => { + allMovingIds.add(id) + }) + + for (const instancePageState of this.pageStates) { + if (instancePageState.pageId === next.parentId) continue + const nextPageState = cleanupInstancePageState(instancePageState, allMovingIds) + + if (nextPageState) { + this.store.put([nextPageState]) + } + } + } + + if (prev.parentId && isShapeId(prev.parentId)) { + invalidParents.add(prev.parentId) + } + + if (next.parentId !== prev.parentId && isShapeId(next.parentId)) { + invalidParents.add(next.parentId) + } + }) + + this.sideEffects.registerAfterChangeHandler('instance_page_state', (prev, next) => { + if (prev?.selectedShapeIds !== next?.selectedShapeIds) { + // ensure that descendants and ancestors are not selected at the same time + const filtered = next.selectedShapeIds.filter((id) => { + let parentId = this.getShape(id)?.parentId + while (isShapeId(parentId)) { + if (next.selectedShapeIds.includes(parentId)) { + return false + } + parentId = this.getShape(parentId)?.parentId + } + return true + }) + + let nextFocusedGroupId: null | TLShapeId = null + + if (filtered.length > 0) { + const commonGroupAncestor = this.findCommonAncestor( + compact(filtered.map((id) => this.getShape(id))), + (shape) => this.isShapeOfType(shape, 'group') + ) + + if (commonGroupAncestor) { + nextFocusedGroupId = commonGroupAncestor + } + } else { + if (next?.focusedGroupId) { + nextFocusedGroupId = next.focusedGroupId + } + } + + if ( + filtered.length !== next.selectedShapeIds.length || + nextFocusedGroupId !== next.focusedGroupId + ) { this.store.put([ - InstancePageStateRecordType.create({ id: _pageStateId, pageId: record.id }), + { ...next, selectedShapeIds: filtered, focusedGroupId: nextFocusedGroupId ?? null }, ]) } } - } + }) + + this.sideEffects.registerAfterCreateHandler('shape', (record) => { + if (this.isShapeOfType(record, 'arrow')) { + arrowDidUpdate(record) + } + }) + + this.sideEffects.registerAfterCreateHandler('page', (record) => { + const cameraId = CameraRecordType.createId(record.id) + const _pageStateId = InstancePageStateRecordType.createId(record.id) + if (!this.store.has(cameraId)) { + this.store.put([CameraRecordType.create({ id: cameraId })]) + } + if (!this.store.has(_pageStateId)) { + this.store.put([ + InstancePageStateRecordType.create({ id: _pageStateId, pageId: record.id }), + ]) + } + }) this._shapeIdsOnCurrentPage = deriveShapeIdsInCurrentPage(this.store, () => this.currentPageId) this._parentIdsToChildIds = parentsToChildren(this.store) @@ -376,9 +658,6 @@ export class Editor extends EventEmitter { /** @internal */ private _tickManager = new TickManager(this) - /** @internal */ - private _updateDepth = 0 - /** * A manager for the app's snapping feature. * @@ -419,6 +698,13 @@ export class Editor extends EventEmitter { */ getContainer: () => HTMLElement + /** + * A manager for side effects and correct state enforcement. + * + * @public + */ + readonly sideEffects: SideEffectManager + /** * Dispose the editor. * @@ -474,7 +760,7 @@ export class Editor extends EventEmitter { */ readonly history = new HistoryManager( this, - () => this._complete(), + // () => this._complete(), (error) => { this.annotateError(error, { origin: 'history.batch', willCrashApp: true }) this.crash(error) @@ -620,95 +906,6 @@ export class Editor extends EventEmitter { return this._arrowBindingsIndex.value[shapeId] || EMPTY_ARRAY } - /** @internal */ - private _reparentArrow(arrowId: TLShapeId) { - const arrow = this.getShape(arrowId) - if (!arrow) return - const { start, end } = arrow.props - const startShape = start.type === 'binding' ? this.getShape(start.boundShapeId) : undefined - const endShape = end.type === 'binding' ? this.getShape(end.boundShapeId) : undefined - - const parentPageId = this.getAncestorPageId(arrow) - if (!parentPageId) return - - let nextParentId: TLParentId - if (startShape && endShape) { - // if arrow has two bindings, always parent arrow to closest common ancestor of the bindings - nextParentId = this.findCommonAncestor([startShape, endShape]) ?? parentPageId - } else if (startShape || endShape) { - // if arrow has one binding, keep arrow on its own page - nextParentId = parentPageId - } else { - return - } - - if (nextParentId && nextParentId !== arrow.parentId) { - this.reparentShapes([arrowId], nextParentId) - } - - const reparentedArrow = this.getShape(arrowId) - if (!reparentedArrow) throw Error('no reparented arrow') - - const startSibling = this.getShapeNearestSibling(reparentedArrow, startShape) - const endSibling = this.getShapeNearestSibling(reparentedArrow, endShape) - - let highestSibling: TLShape | undefined - - if (startSibling && endSibling) { - highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling - } else if (startSibling && !endSibling) { - highestSibling = startSibling - } else if (endSibling && !startSibling) { - highestSibling = endSibling - } else { - return - } - - let finalIndex: string - - const higherSiblings = this.getSortedChildIdsForParent(highestSibling.parentId) - .map((id) => this.getShape(id)!) - .filter((sibling) => sibling.index > highestSibling!.index) - - if (higherSiblings.length) { - // there are siblings above the highest bound sibling, we need to - // insert between them. - - // if the next sibling is also a bound arrow though, we can end up - // all fighting for the same indexes. so lets find the next - // non-arrow sibling... - const nextHighestNonArrowSibling = higherSiblings.find((sibling) => sibling.type !== 'arrow') - - if ( - // ...then, if we're above the last shape we want to be above... - reparentedArrow.index > highestSibling.index && - // ...but below the next non-arrow sibling... - (!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index) - ) { - // ...then we're already in the right place. no need to update! - return - } - - // otherwise, we need to find the index between the highest sibling - // we want to be above, and the next highest sibling we want to be - // below: - finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index) - } else { - // if there are no siblings above us, we can just get the next index: - finalIndex = getIndexAbove(highestSibling.index) - } - - if (finalIndex !== reparentedArrow.index) { - this.updateShapes([{ id: arrowId, type: 'arrow', index: finalIndex }]) - } - } - - /** @internal */ - private _unbindArrowTerminal(arrow: TLArrowShape, handleId: 'start' | 'end') { - const { x, y } = getArrowTerminalsInArrowSpace(this, arrow)[handleId] - this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }]) - } - @computed private get arrowInfoCache() { return this.store.createComputedCache('arrow infoCache', (shape) => { @@ -727,181 +924,6 @@ export class Editor extends EventEmitter { // return update ?? next // } - /** @internal */ - private _shapeWillBeDeleted(deletedShape: TLShape) { - // if the deleted shape has a parent shape make sure we call it's onChildrenChange callback - if (deletedShape.parentId && isShapeId(deletedShape.parentId)) { - this._invalidParents.add(deletedShape.parentId) - } - // clean up any arrows bound to this shape - const bindings = this._arrowBindingsIndex.value[deletedShape.id] - if (bindings?.length) { - for (const { arrowId, handleId } of bindings) { - const arrow = this.getShape(arrowId) - if (!arrow) continue - this._unbindArrowTerminal(arrow, handleId) - } - } - const deletedIds = new Set([deletedShape.id]) - const updates = compact( - this.pageStates.map((pageState) => { - return this._cleanupInstancePageState(pageState, deletedIds) - }) - ) - - if (updates.length) { - this.store.put(updates) - } - } - - /** @internal */ - private _arrowDidUpdate(arrow: TLArrowShape) { - // if the shape is an arrow and its bound shape is on another page - // or was deleted, unbind it - for (const handle of ['start', 'end'] as const) { - const terminal = arrow.props[handle] - if (terminal.type !== 'binding') continue - const boundShape = this.getShape(terminal.boundShapeId) - const isShapeInSamePageAsArrow = - this.getAncestorPageId(arrow) === this.getAncestorPageId(boundShape) - if (!boundShape || !isShapeInSamePageAsArrow) { - this._unbindArrowTerminal(arrow, handle) - } - } - - // always check the arrow parents - this._reparentArrow(arrow.id) - } - - /** - * _invalidParents is used to trigger the 'onChildrenChange' callback that shapes can have. - * - * @internal - */ - private readonly _invalidParents = new Set() - - /** @internal */ - private _complete() { - for (const parentId of this._invalidParents) { - this._invalidParents.delete(parentId) - const parent = this.getShape(parentId) - if (!parent) continue - - const util = this.getShapeUtil(parent) - const changes = util.onChildrenChange?.(parent) - - if (changes?.length) { - this.updateShapes(changes, true) - } - } - - this.emit('update') - } - - /** @internal */ - private _shapeDidChange(prev: TLShape, next: TLShape) { - if (this.isShapeOfType(next, 'arrow')) { - this._arrowDidUpdate(next) - } - - // if the shape's parent changed and it is bound to an arrow, update the arrow's parent - if (prev.parentId !== next.parentId) { - const reparentBoundArrows = (id: TLShapeId) => { - const boundArrows = this._arrowBindingsIndex.value[id] - if (boundArrows?.length) { - for (const arrow of boundArrows) { - this._reparentArrow(arrow.arrowId) - } - } - } - reparentBoundArrows(next.id) - this.visitDescendants(next.id, reparentBoundArrows) - } - - // if this shape moved to a new page, clean up any previous page's instance state - if (prev.parentId !== next.parentId && isPageId(next.parentId)) { - const allMovingIds = new Set([prev.id]) - this.visitDescendants(prev.id, (id) => { - allMovingIds.add(id) - }) - - for (const instancePageState of this.pageStates) { - if (instancePageState.pageId === next.parentId) continue - const nextPageState = this._cleanupInstancePageState(instancePageState, allMovingIds) - - if (nextPageState) { - this.store.put([nextPageState]) - } - } - } - - if (prev.parentId && isShapeId(prev.parentId)) { - this._invalidParents.add(prev.parentId) - } - - if (next.parentId !== prev.parentId && isShapeId(next.parentId)) { - this._invalidParents.add(next.parentId) - } - } - - /** @internal */ - private _pageStateDidChange(prev: TLInstancePageState, next: TLInstancePageState) { - if (prev?.selectedShapeIds !== next?.selectedShapeIds) { - // ensure that descendants and ancestors are not selected at the same time - const filtered = next.selectedShapeIds.filter((id) => { - let parentId = this.getShape(id)?.parentId - while (isShapeId(parentId)) { - if (next.selectedShapeIds.includes(parentId)) { - return false - } - parentId = this.getShape(parentId)?.parentId - } - return true - }) - - let nextFocusedGroupId: null | TLShapeId = null - - if (filtered.length > 0) { - const commonGroupAncestor = this.findCommonAncestor( - compact(filtered.map((id) => this.getShape(id))), - (shape) => this.isShapeOfType(shape, 'group') - ) - - if (commonGroupAncestor) { - nextFocusedGroupId = commonGroupAncestor - } - } else { - if (next?.focusedGroupId) { - nextFocusedGroupId = next.focusedGroupId - } - } - - if ( - filtered.length !== next.selectedShapeIds.length || - nextFocusedGroupId !== next.focusedGroupId - ) { - this.store.put([ - { ...next, selectedShapeIds: filtered, focusedGroupId: nextFocusedGroupId ?? null }, - ]) - } - } - } - - /** @internal */ - private _pageWillBeDeleted(page: TLPage) { - // page was deleted, need to check whether it's the current page and select another one if so - if (this.instanceState.currentPageId !== page.id) return - - const backupPageId = this.pages.find((p) => p.id !== page.id)?.id - if (!backupPageId) return - this.store.put([{ ...this.instanceState, currentPageId: backupPageId }]) - - // delete the camera and state for the page if necessary - const cameraId = CameraRecordType.createId(page.id) - const instance_PageStateId = InstancePageStateRecordType.createId(page.id) - this.store.remove([cameraId, instance_PageStateId]) - } - /* --------------------- Errors --------------------- */ /** @internal */ @@ -1802,54 +1824,6 @@ export class Editor extends EventEmitter { return this } - /** @internal */ - private _cleanupInstancePageState( - prevPageState: TLInstancePageState, - shapesNoLongerInPage: Set - ) { - let nextPageState = null as null | TLInstancePageState - - const selectedShapeIds = prevPageState.selectedShapeIds.filter( - (id) => !shapesNoLongerInPage.has(id) - ) - if (selectedShapeIds.length !== prevPageState.selectedShapeIds.length) { - if (!nextPageState) nextPageState = { ...prevPageState } - nextPageState.selectedShapeIds = selectedShapeIds - } - - const erasingShapeIds = prevPageState.erasingShapeIds.filter( - (id) => !shapesNoLongerInPage.has(id) - ) - if (erasingShapeIds.length !== prevPageState.erasingShapeIds.length) { - if (!nextPageState) nextPageState = { ...prevPageState } - nextPageState.erasingShapeIds = erasingShapeIds - } - - if (prevPageState.hoveredShapeId && shapesNoLongerInPage.has(prevPageState.hoveredShapeId)) { - if (!nextPageState) nextPageState = { ...prevPageState } - nextPageState.hoveredShapeId = null - } - - if (prevPageState.editingShapeId && shapesNoLongerInPage.has(prevPageState.editingShapeId)) { - if (!nextPageState) nextPageState = { ...prevPageState } - nextPageState.editingShapeId = null - } - - const hintingShapeIds = prevPageState.hintingShapeIds.filter( - (id) => !shapesNoLongerInPage.has(id) - ) - if (hintingShapeIds.length !== prevPageState.hintingShapeIds.length) { - if (!nextPageState) nextPageState = { ...prevPageState } - nextPageState.hintingShapeIds = hintingShapeIds - } - - if (prevPageState.focusedGroupId && shapesNoLongerInPage.has(prevPageState.focusedGroupId)) { - if (!nextPageState) nextPageState = { ...prevPageState } - nextPageState.focusedGroupId = null - } - return nextPageState - } - /* --------------------- Camera --------------------- */ /** @internal */ diff --git a/packages/editor/src/lib/editor/managers/HistoryManager.test.ts b/packages/editor/src/lib/editor/managers/HistoryManager.test.ts index 4a73a43bd..ef784d346 100644 --- a/packages/editor/src/lib/editor/managers/HistoryManager.test.ts +++ b/packages/editor/src/lib/editor/managers/HistoryManager.test.ts @@ -2,13 +2,9 @@ import { HistoryManager } from './HistoryManager' import { stack } from './Stack' function createCounterHistoryManager() { - const manager = new HistoryManager( - { emit: () => void null }, - () => null, - () => { - return - } - ) + const manager = new HistoryManager({ emit: () => void null }, () => { + return + }) const state = { count: 0, name: 'David', diff --git a/packages/editor/src/lib/editor/managers/HistoryManager.ts b/packages/editor/src/lib/editor/managers/HistoryManager.ts index 45d906c63..993435b7a 100644 --- a/packages/editor/src/lib/editor/managers/HistoryManager.ts +++ b/packages/editor/src/lib/editor/managers/HistoryManager.ts @@ -29,10 +29,11 @@ export class HistoryManager< constructor( private readonly ctx: CTX, - private readonly onBatchComplete: () => void, private readonly annotateError: (error: unknown) => void ) {} + onBatchComplete: () => void = () => void null + private _commands: Record> = {} get numUndos() { diff --git a/packages/editor/src/lib/editor/managers/SIdeEffectManager.test.ts b/packages/editor/src/lib/editor/managers/SIdeEffectManager.test.ts new file mode 100644 index 000000000..e8ef393e2 --- /dev/null +++ b/packages/editor/src/lib/editor/managers/SIdeEffectManager.test.ts @@ -0,0 +1,20 @@ +// let editor: Editor +// beforeEach(() => { +// editor = new Editor({ +// shapeUtils: [], +// tools: [], +// store: createTLStore({ shapeUtils: [] }), +// getContainer: () => document.body, +// }) +// }) + +describe('Side effect manager', () => { + it.todo('Registers an onBeforeCreate handler') + it.todo('Registers an onAfterCreate handler') + it.todo('Registers an onBeforeChange handler') + it.todo('Registers an onAfterChange handler') + it.todo('Registers an onBeforeDelete handler') + it.todo('Registers an onAfterDelete handler') + it.todo('Registers a batch start handler') + it.todo('Registers a batch complete handler') +}) diff --git a/packages/editor/src/lib/editor/managers/SideEffectManager.ts b/packages/editor/src/lib/editor/managers/SideEffectManager.ts new file mode 100644 index 000000000..0a3b6d0ac --- /dev/null +++ b/packages/editor/src/lib/editor/managers/SideEffectManager.ts @@ -0,0 +1,245 @@ +import { TLRecord, TLStore } from '@tldraw/tlschema' + +/** @public */ +export type TLBeforeCreateHandler = (record: R, source: 'remote' | 'user') => R +/** @public */ +export type TLAfterCreateHandler = ( + record: R, + source: 'remote' | 'user' +) => void +/** @public */ +export type TLBeforeChangeHandler = ( + prev: R, + next: R, + source: 'remote' | 'user' +) => R +/** @public */ +export type TLAfterChangeHandler = ( + prev: R, + next: R, + source: 'remote' | 'user' +) => void +/** @public */ +export type TLBeforeDeleteHandler = ( + record: R, + source: 'remote' | 'user' +) => void | false +/** @public */ +export type TLAfterDeleteHandler = ( + record: R, + source: 'remote' | 'user' +) => void +/** @public */ +export type TLBatchCompleteHandler = () => void + +/** + * The side effect manager (aka a "correct state enforcer") is responsible + * for making sure that the editor's state is always correct. This includes + * things like: deleting a shape if its parent is deleted; unbinding + * arrows when their binding target is deleted; etc. + * + * @public + */ +export class SideEffectManager< + CTX extends { + store: TLStore + history: { onBatchComplete: () => void } + } +> { + constructor(public editor: CTX) { + editor.store.onBeforeCreate = (record, source) => { + const handlers = this._beforeCreateHandlers[ + record.typeName + ] as TLBeforeCreateHandler[] + if (handlers) { + let r = record + for (const handler of handlers) { + r = handler(r, source) + } + return r + } + + return record + } + + editor.store.onAfterCreate = (record, source) => { + const handlers = this._afterCreateHandlers[ + record.typeName + ] as TLAfterCreateHandler[] + if (handlers) { + for (const handler of handlers) { + handler(record, source) + } + } + } + + editor.store.onBeforeChange = (prev, next, source) => { + const handlers = this._beforeChangeHandlers[ + next.typeName + ] as TLBeforeChangeHandler[] + if (handlers) { + let r = next + for (const handler of handlers) { + r = handler(prev, r, source) + } + return r + } + + return next + } + + let updateDepth = 0 + + editor.store.onAfterChange = (prev, next, source) => { + updateDepth++ + + if (updateDepth > 1000) { + console.error('[CleanupManager.onAfterChange] Maximum update depth exceeded, bailing out.') + } else { + const handlers = this._afterChangeHandlers[ + next.typeName + ] as TLAfterChangeHandler[] + if (handlers) { + for (const handler of handlers) { + handler(prev, next, source) + } + } + } + + updateDepth-- + } + + editor.store.onBeforeDelete = (record, source) => { + const handlers = this._beforeDeleteHandlers[ + record.typeName + ] as TLBeforeDeleteHandler[] + if (handlers) { + for (const handler of handlers) { + if (handler(record, source) === false) { + return false + } + } + } + } + + editor.store.onAfterDelete = (record, source) => { + const handlers = this._afterDeleteHandlers[ + record.typeName + ] as TLAfterDeleteHandler[] + if (handlers) { + for (const handler of handlers) { + handler(record, source) + } + } + } + + editor.history.onBatchComplete = () => { + this._batchCompleteHandlers.forEach((fn) => fn()) + } + } + + private _beforeCreateHandlers: Partial<{ + [K in TLRecord['typeName']]: TLBeforeCreateHandler[] + }> = {} + private _afterCreateHandlers: Partial<{ + [K in TLRecord['typeName']]: TLAfterCreateHandler[] + }> = {} + private _beforeChangeHandlers: Partial<{ + [K in TLRecord['typeName']]: TLBeforeChangeHandler[] + }> = {} + private _afterChangeHandlers: Partial<{ + [K in TLRecord['typeName']]: TLAfterChangeHandler[] + }> = {} + + private _beforeDeleteHandlers: Partial<{ + [K in TLRecord['typeName']]: TLBeforeDeleteHandler[] + }> = {} + + private _afterDeleteHandlers: Partial<{ + [K in TLRecord['typeName']]: TLAfterDeleteHandler[] + }> = {} + + private _batchCompleteHandlers: TLBatchCompleteHandler[] = [] + + registerBeforeCreateHandler( + typeName: T, + handler: TLBeforeCreateHandler + ) { + const handlers = this._beforeCreateHandlers[typeName] as TLBeforeCreateHandler[] + if (!handlers) this._beforeCreateHandlers[typeName] = [] + this._beforeCreateHandlers[typeName]!.push(handler) + } + + registerAfterCreateHandler( + typeName: T, + handler: TLAfterCreateHandler + ) { + const handlers = this._afterCreateHandlers[typeName] as TLAfterCreateHandler[] + if (!handlers) this._afterCreateHandlers[typeName] = [] + this._afterCreateHandlers[typeName]!.push(handler) + } + + registerBeforeChangeHandler( + typeName: T, + handler: TLBeforeChangeHandler + ) { + const handlers = this._beforeChangeHandlers[typeName] as TLBeforeChangeHandler[] + if (!handlers) this._beforeChangeHandlers[typeName] = [] + this._beforeChangeHandlers[typeName]!.push(handler) + } + + registerAfterChangeHandler( + typeName: T, + handler: TLAfterChangeHandler + ) { + const handlers = this._afterChangeHandlers[typeName] as TLAfterChangeHandler[] + if (!handlers) this._afterChangeHandlers[typeName] = [] + this._afterChangeHandlers[typeName]!.push(handler as TLAfterChangeHandler) + } + + registerBeforeDeleteHandler( + typeName: T, + handler: TLBeforeDeleteHandler + ) { + const handlers = this._beforeDeleteHandlers[typeName] as TLBeforeDeleteHandler[] + if (!handlers) this._beforeDeleteHandlers[typeName] = [] + this._beforeDeleteHandlers[typeName]!.push(handler as TLBeforeDeleteHandler) + } + + registerAfterDeleteHandler( + typeName: T, + handler: TLAfterDeleteHandler + ) { + const handlers = this._afterDeleteHandlers[typeName] as TLAfterDeleteHandler[] + if (!handlers) this._afterDeleteHandlers[typeName] = [] + this._afterDeleteHandlers[typeName]!.push(handler as TLAfterDeleteHandler) + } + + /** + * Register a handler to be called when a store completes a batch. + * + * @example + * ```ts + * let count = 0 + * + * editor.cleanup.registerBatchCompleteHandler(() => count++) + * + * editor.selectAll() + * expect(count).toBe(1) + * + * editor.batch(() => { + * editor.selectNone() + * editor.selectAll() + * }) + * + * expect(count).toBe(2) + * ``` + * + * @param handler - The handler to call + * + * @public + */ + registerBatchCompleteHandler(handler: TLBatchCompleteHandler) { + this._batchCompleteHandlers.push(handler) + } +} diff --git a/packages/store/api-report.md b/packages/store/api-report.md index efcbfb7b3..1537a867b 100644 --- a/packages/store/api-report.md +++ b/packages/store/api-report.md @@ -244,6 +244,8 @@ export class Store { // (undocumented) _flushHistory(): void; get: >(id: K) => RecFromId | undefined; + // (undocumented) + getRecordType: (record: R) => T; getSnapshot(scope?: 'all' | RecordScope): StoreSnapshot; has: >(id: K) => boolean; readonly history: Atom>; @@ -255,10 +257,12 @@ export class Store { // @internal (undocumented) markAsPossiblyCorrupted(): void; mergeRemoteChanges: (fn: () => void) => void; - onAfterChange?: (prev: R, next: R) => void; - onAfterCreate?: (record: R) => void; - onAfterDelete?: (prev: R) => void; - onBeforeDelete?: (prev: R) => void; + onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void; + onAfterCreate?: (record: R, source: 'remote' | 'user') => void; + onAfterDelete?: (prev: R, source: 'remote' | 'user') => void; + onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R; + onBeforeCreate?: (next: R, source: 'remote' | 'user') => R; + onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void; // (undocumented) readonly props: Props; put: (records: R[], phaseOverride?: 'initialize') => void; diff --git a/packages/store/src/lib/Store.ts b/packages/store/src/lib/Store.ts index ec3467b86..56c958447 100644 --- a/packages/store/src/lib/Store.ts +++ b/packages/store/src/lib/Store.ts @@ -297,13 +297,29 @@ export class Store { this.allRecords().forEach((record) => this.schema.validateRecord(this, record, phase, null)) } + /** + * A callback fired after each record's change. + * + * @param prev - The previous value, if any. + * @param next - The next value. + */ + onBeforeCreate?: (next: R, source: 'remote' | 'user') => R + /** * A callback fired after a record is created. Use this to perform related updates to other * records in the store. * * @param record - The record to be created */ - onAfterCreate?: (record: R) => void + onAfterCreate?: (record: R, source: 'remote' | 'user') => void + + /** + * A callback before after each record's change. + * + * @param prev - The previous value, if any. + * @param next - The next value. + */ + onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R /** * A callback fired after each record's change. @@ -311,21 +327,21 @@ export class Store { * @param prev - The previous value, if any. * @param next - The next value. */ - onAfterChange?: (prev: R, next: R) => void + onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void /** * A callback fired before a record is deleted. * * @param prev - The record that will be deleted. */ - onBeforeDelete?: (prev: R) => void + onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void /** * A callback fired after a record is deleted. * * @param prev - The record that will be deleted. */ - onAfterDelete?: (prev: R) => void + onAfterDelete?: (prev: R, source: 'remote' | 'user') => void // used to avoid running callbacks when rolling back changes in sync client private _runCallbacks = true @@ -353,12 +369,18 @@ export class Store { // changes (e.g. additions, deletions, or updates that produce a new value). let didChange = false + const beforeCreate = this.onBeforeCreate && this._runCallbacks ? this.onBeforeCreate : null + const beforeUpdate = this.onBeforeChange && this._runCallbacks ? this.onBeforeChange : null + const source = this.isMergingRemoteChanges ? 'remote' : 'user' + for (let i = 0, n = records.length; i < n; i++) { record = records[i] const recordAtom = (map ?? currentMap)[record.id as IdOf] if (recordAtom) { + if (beforeUpdate) record = beforeUpdate(recordAtom.value, record, source) + // If we already have an atom for this record, update its value. const initialValue = recordAtom.__unsafe__getWithoutCapture() @@ -382,6 +404,8 @@ export class Store { updates[record.id] = [initialValue, finalValue] } } else { + if (beforeCreate) record = beforeCreate(record, source) + didChange = true // If we don't have an atom, create one. @@ -418,20 +442,22 @@ export class Store { removed: {} as Record, R>, }) - const { onAfterCreate, onAfterChange } = this + if (this._runCallbacks) { + const { onAfterCreate, onAfterChange } = this - if (onAfterCreate && this._runCallbacks) { - // Run the onAfterChange callback for addition. - Object.values(additions).forEach((record) => { - onAfterCreate(record) - }) - } + if (onAfterCreate) { + // Run the onAfterChange callback for addition. + Object.values(additions).forEach((record) => { + onAfterCreate(record, source) + }) + } - if (onAfterChange && this._runCallbacks) { - // Run the onAfterChange callback for update. - Object.values(updates).forEach(([from, to]) => { - onAfterChange(from, to) - }) + if (onAfterChange) { + // Run the onAfterChange callback for update. + Object.values(updates).forEach(([from, to]) => { + onAfterChange(from, to, source) + }) + } } }) } @@ -444,12 +470,17 @@ export class Store { */ remove = (ids: IdOf[]): void => { transact(() => { + const cancelled = [] as IdOf[] + const source = this.isMergingRemoteChanges ? 'remote' : 'user' + if (this.onBeforeDelete && this._runCallbacks) { for (const id of ids) { const atom = this.atoms.__unsafe__getWithoutCapture()[id] if (!atom) continue - this.onBeforeDelete(atom.value) + if (this.onBeforeDelete(atom.value, source) === false) { + cancelled.push(id) + } } } @@ -460,6 +491,7 @@ export class Store { let result: typeof atoms | undefined = undefined for (const id of ids) { + if (cancelled.includes(id)) continue if (!(id in atoms)) continue if (!result) result = { ...atoms } if (!removed) removed = {} as Record, R> @@ -476,8 +508,12 @@ export class Store { // If we have an onAfterChange, run it for each removed record. if (this.onAfterDelete && this._runCallbacks) { + let record: R for (let i = 0, n = ids.length; i < n; i++) { - this.onAfterDelete(removed[ids[i]]) + record = removed[ids[i]] + if (record) { + this.onAfterDelete(record, source) + } } } }) @@ -596,6 +632,7 @@ export class Store { console.error(`Record ${id} not found. This is probably an error`) return } + this.put([updater(atom.__unsafe__getWithoutCapture() as any as RecFromId) as any]) } @@ -752,6 +789,14 @@ export class Store { } } + getRecordType = (record: R): T => { + const type = this.schema.types[record.typeName as R['typeName']] + if (!type) { + throw new Error(`Record type ${record.typeName} not found`) + } + return type as unknown as T + } + private _integrityChecker?: () => void | undefined /** @internal */