SideEffectManager (#1785)

This PR extracts the side effect manager from #1778.

### Change Type

- [x] `major` — Breaking change
This commit is contained in:
Steve Ruiz 2023-08-02 12:05:14 +01:00 committed by GitHub
parent c478d75117
commit 507bba82fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 664 additions and 378 deletions

View file

@ -944,6 +944,7 @@ export class Editor extends EventEmitter<TLEventMap> {
};
get sharedOpacity(): SharedStyle<number>;
get sharedStyles(): ReadonlySharedStyleMap;
readonly sideEffects: SideEffectManager<this>;
slideCamera(opts?: {
speed: number;
direction: VecLike;

View file

@ -105,6 +105,7 @@ import { deriveShapeIdsInCurrentPage } from './derivations/shapeIdsInCurrentPage
import { ClickManager } from './managers/ClickManager'
import { EnvironmentManager } from './managers/EnvironmentManager'
import { HistoryManager } from './managers/HistoryManager'
import { SideEffectManager } from './managers/SideEffectManager'
import { SnapManager } from './managers/SnapManager'
import { TextManager } from './managers/TextManager'
import { TickManager } from './managers/TickManager'
@ -242,47 +243,328 @@ export class Editor extends EventEmitter<TLEventMap> {
this.environment = new EnvironmentManager(this)
this.store.onBeforeDelete = (record) => {
if (record.typeName === 'shape') {
this._shapeWillBeDeleted(record)
} else if (record.typeName === 'page') {
this._pageWillBeDeleted(record)
}
}
// Cleanup
this.store.onAfterChange = (prev, next) => {
this._updateDepth++
if (this._updateDepth > 1000) {
console.error('[onAfterChange] Maximum update depth exceeded, bailing out.')
}
if (prev.typeName === 'shape' && next.typeName === 'shape') {
this._shapeDidChange(prev, next)
} else if (
prev.typeName === 'instance_page_state' &&
next.typeName === 'instance_page_state'
) {
this._pageStateDidChange(prev, next)
const invalidParents = new Set<TLShapeId>()
const reparentArrow = (arrowId: TLArrowShape['id']) => {
const arrow = this.getShape<TLArrowShape>(arrowId)
if (!arrow) return
const { start, end } = arrow.props
const startShape = start.type === 'binding' ? this.getShape(start.boundShapeId) : undefined
const endShape = end.type === 'binding' ? this.getShape(end.boundShapeId) : undefined
const parentPageId = this.getAncestorPageId(arrow)
if (!parentPageId) return
let nextParentId: TLParentId
if (startShape && endShape) {
// if arrow has two bindings, always parent arrow to closest common ancestor of the bindings
nextParentId = this.findCommonAncestor([startShape, endShape]) ?? parentPageId
} else if (startShape || endShape) {
// if arrow has one binding, keep arrow on its own page
nextParentId = parentPageId
} else {
return
}
this._updateDepth--
}
this.store.onAfterCreate = (record) => {
if (record.typeName === 'shape' && this.isShapeOfType<TLArrowShape>(record, 'arrow')) {
this._arrowDidUpdate(record)
if (nextParentId && nextParentId !== arrow.parentId) {
this.reparentShapes([arrowId], nextParentId)
}
if (record.typeName === 'page') {
const cameraId = CameraRecordType.createId(record.id)
const _pageStateId = InstancePageStateRecordType.createId(record.id)
if (!this.store.has(cameraId)) {
this.store.put([CameraRecordType.create({ id: cameraId })])
const reparentedArrow = this.getShape<TLArrowShape>(arrowId)
if (!reparentedArrow) throw Error('no reparented arrow')
const startSibling = this.getShapeNearestSibling(reparentedArrow, startShape)
const endSibling = this.getShapeNearestSibling(reparentedArrow, endShape)
let highestSibling: TLShape | undefined
if (startSibling && endSibling) {
highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling
} else if (startSibling && !endSibling) {
highestSibling = startSibling
} else if (endSibling && !startSibling) {
highestSibling = endSibling
} else {
return
}
let finalIndex: string
const higherSiblings = this.getSortedChildIdsForParent(highestSibling.parentId)
.map((id) => this.getShape(id)!)
.filter((sibling) => sibling.index > highestSibling!.index)
if (higherSiblings.length) {
// there are siblings above the highest bound sibling, we need to
// insert between them.
// if the next sibling is also a bound arrow though, we can end up
// all fighting for the same indexes. so lets find the next
// non-arrow sibling...
const nextHighestNonArrowSibling = higherSiblings.find(
(sibling) => sibling.type !== 'arrow'
)
if (
// ...then, if we're above the last shape we want to be above...
reparentedArrow.index > highestSibling.index &&
// ...but below the next non-arrow sibling...
(!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index)
) {
// ...then we're already in the right place. no need to update!
return
}
if (!this.store.has(_pageStateId)) {
// otherwise, we need to find the index between the highest sibling
// we want to be above, and the next highest sibling we want to be
// below:
finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index)
} else {
// if there are no siblings above us, we can just get the next index:
finalIndex = getIndexAbove(highestSibling.index)
}
if (finalIndex !== reparentedArrow.index) {
this.updateShapes<TLArrowShape>([{ id: arrowId, type: 'arrow', index: finalIndex }])
}
}
const unbindArrowTerminal = (arrow: TLArrowShape, handleId: 'start' | 'end') => {
const { x, y } = getArrowTerminalsInArrowSpace(this, arrow)[handleId]
this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }])
}
const arrowDidUpdate = (arrow: TLArrowShape) => {
// if the shape is an arrow and its bound shape is on another page
// or was deleted, unbind it
for (const handle of ['start', 'end'] as const) {
const terminal = arrow.props[handle]
if (terminal.type !== 'binding') continue
const boundShape = this.getShape(terminal.boundShapeId)
const isShapeInSamePageAsArrow =
this.getAncestorPageId(arrow) === this.getAncestorPageId(boundShape)
if (!boundShape || !isShapeInSamePageAsArrow) {
unbindArrowTerminal(arrow, handle)
}
}
// always check the arrow parents
reparentArrow(arrow.id)
}
const cleanupInstancePageState = (
prevPageState: TLInstancePageState,
shapesNoLongerInPage: Set<TLShapeId>
) => {
let nextPageState = null as null | TLInstancePageState
const selectedShapeIds = prevPageState.selectedShapeIds.filter(
(id) => !shapesNoLongerInPage.has(id)
)
if (selectedShapeIds.length !== prevPageState.selectedShapeIds.length) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.selectedShapeIds = selectedShapeIds
}
const erasingShapeIds = prevPageState.erasingShapeIds.filter(
(id) => !shapesNoLongerInPage.has(id)
)
if (erasingShapeIds.length !== prevPageState.erasingShapeIds.length) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.erasingShapeIds = erasingShapeIds
}
if (prevPageState.hoveredShapeId && shapesNoLongerInPage.has(prevPageState.hoveredShapeId)) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.hoveredShapeId = null
}
if (prevPageState.editingShapeId && shapesNoLongerInPage.has(prevPageState.editingShapeId)) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.editingShapeId = null
}
const hintingShapeIds = prevPageState.hintingShapeIds.filter(
(id) => !shapesNoLongerInPage.has(id)
)
if (hintingShapeIds.length !== prevPageState.hintingShapeIds.length) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.hintingShapeIds = hintingShapeIds
}
if (prevPageState.focusedGroupId && shapesNoLongerInPage.has(prevPageState.focusedGroupId)) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.focusedGroupId = null
}
return nextPageState
}
this.sideEffects = new SideEffectManager(this)
this.sideEffects.registerBatchCompleteHandler(() => {
for (const parentId of invalidParents) {
invalidParents.delete(parentId)
const parent = this.getShape(parentId)
if (!parent) continue
const util = this.getShapeUtil(parent)
const changes = util.onChildrenChange?.(parent)
if (changes?.length) {
this.updateShapes(changes, true)
}
}
this.emit('update')
})
this.sideEffects.registerBeforeDeleteHandler('shape', (record) => {
// if the deleted shape has a parent shape make sure we call it's onChildrenChange callback
if (record.parentId && isShapeId(record.parentId)) {
invalidParents.add(record.parentId)
}
// clean up any arrows bound to this shape
const bindings = this._arrowBindingsIndex.value[record.id]
if (bindings?.length) {
for (const { arrowId, handleId } of bindings) {
const arrow = this.getShape<TLArrowShape>(arrowId)
if (!arrow) continue
unbindArrowTerminal(arrow, handleId)
}
}
const deletedIds = new Set([record.id])
const updates = compact(
this.pageStates.map((pageState) => {
return cleanupInstancePageState(pageState, deletedIds)
})
)
if (updates.length) {
this.store.put(updates)
}
})
this.sideEffects.registerBeforeDeleteHandler('page', (record) => {
// page was deleted, need to check whether it's the current page and select another one if so
if (this.instanceState.currentPageId !== record.id) return
const backupPageId = this.pages.find((p) => p.id !== record.id)?.id
if (!backupPageId) return
this.store.put([{ ...this.instanceState, currentPageId: backupPageId }])
// delete the camera and state for the page if necessary
const cameraId = CameraRecordType.createId(record.id)
const instance_PageStateId = InstancePageStateRecordType.createId(record.id)
this.store.remove([cameraId, instance_PageStateId])
})
this.sideEffects.registerAfterChangeHandler('shape', (prev, next) => {
if (this.isShapeOfType<TLArrowShape>(next, 'arrow')) {
arrowDidUpdate(next)
}
// if the shape's parent changed and it is bound to an arrow, update the arrow's parent
if (prev.parentId !== next.parentId) {
const reparentBoundArrows = (id: TLShapeId) => {
const boundArrows = this._arrowBindingsIndex.value[id]
if (boundArrows?.length) {
for (const arrow of boundArrows) {
reparentArrow(arrow.arrowId)
}
}
}
reparentBoundArrows(next.id)
this.visitDescendants(next.id, reparentBoundArrows)
}
// if this shape moved to a new page, clean up any previous page's instance state
if (prev.parentId !== next.parentId && isPageId(next.parentId)) {
const allMovingIds = new Set([prev.id])
this.visitDescendants(prev.id, (id) => {
allMovingIds.add(id)
})
for (const instancePageState of this.pageStates) {
if (instancePageState.pageId === next.parentId) continue
const nextPageState = cleanupInstancePageState(instancePageState, allMovingIds)
if (nextPageState) {
this.store.put([nextPageState])
}
}
}
if (prev.parentId && isShapeId(prev.parentId)) {
invalidParents.add(prev.parentId)
}
if (next.parentId !== prev.parentId && isShapeId(next.parentId)) {
invalidParents.add(next.parentId)
}
})
this.sideEffects.registerAfterChangeHandler('instance_page_state', (prev, next) => {
if (prev?.selectedShapeIds !== next?.selectedShapeIds) {
// ensure that descendants and ancestors are not selected at the same time
const filtered = next.selectedShapeIds.filter((id) => {
let parentId = this.getShape(id)?.parentId
while (isShapeId(parentId)) {
if (next.selectedShapeIds.includes(parentId)) {
return false
}
parentId = this.getShape(parentId)?.parentId
}
return true
})
let nextFocusedGroupId: null | TLShapeId = null
if (filtered.length > 0) {
const commonGroupAncestor = this.findCommonAncestor(
compact(filtered.map((id) => this.getShape(id))),
(shape) => this.isShapeOfType<TLGroupShape>(shape, 'group')
)
if (commonGroupAncestor) {
nextFocusedGroupId = commonGroupAncestor
}
} else {
if (next?.focusedGroupId) {
nextFocusedGroupId = next.focusedGroupId
}
}
if (
filtered.length !== next.selectedShapeIds.length ||
nextFocusedGroupId !== next.focusedGroupId
) {
this.store.put([
InstancePageStateRecordType.create({ id: _pageStateId, pageId: record.id }),
{ ...next, selectedShapeIds: filtered, focusedGroupId: nextFocusedGroupId ?? null },
])
}
}
}
})
this.sideEffects.registerAfterCreateHandler('shape', (record) => {
if (this.isShapeOfType<TLArrowShape>(record, 'arrow')) {
arrowDidUpdate(record)
}
})
this.sideEffects.registerAfterCreateHandler('page', (record) => {
const cameraId = CameraRecordType.createId(record.id)
const _pageStateId = InstancePageStateRecordType.createId(record.id)
if (!this.store.has(cameraId)) {
this.store.put([CameraRecordType.create({ id: cameraId })])
}
if (!this.store.has(_pageStateId)) {
this.store.put([
InstancePageStateRecordType.create({ id: _pageStateId, pageId: record.id }),
])
}
})
this._shapeIdsOnCurrentPage = deriveShapeIdsInCurrentPage(this.store, () => this.currentPageId)
this._parentIdsToChildIds = parentsToChildren(this.store)
@ -376,9 +658,6 @@ export class Editor extends EventEmitter<TLEventMap> {
/** @internal */
private _tickManager = new TickManager(this)
/** @internal */
private _updateDepth = 0
/**
* A manager for the app's snapping feature.
*
@ -419,6 +698,13 @@ export class Editor extends EventEmitter<TLEventMap> {
*/
getContainer: () => HTMLElement
/**
* A manager for side effects and correct state enforcement.
*
* @public
*/
readonly sideEffects: SideEffectManager<this>
/**
* Dispose the editor.
*
@ -474,7 +760,7 @@ export class Editor extends EventEmitter<TLEventMap> {
*/
readonly history = new HistoryManager(
this,
() => this._complete(),
// () => this._complete(),
(error) => {
this.annotateError(error, { origin: 'history.batch', willCrashApp: true })
this.crash(error)
@ -620,95 +906,6 @@ export class Editor extends EventEmitter<TLEventMap> {
return this._arrowBindingsIndex.value[shapeId] || EMPTY_ARRAY
}
/** @internal */
private _reparentArrow(arrowId: TLShapeId) {
const arrow = this.getShape<TLArrowShape>(arrowId)
if (!arrow) return
const { start, end } = arrow.props
const startShape = start.type === 'binding' ? this.getShape(start.boundShapeId) : undefined
const endShape = end.type === 'binding' ? this.getShape(end.boundShapeId) : undefined
const parentPageId = this.getAncestorPageId(arrow)
if (!parentPageId) return
let nextParentId: TLParentId
if (startShape && endShape) {
// if arrow has two bindings, always parent arrow to closest common ancestor of the bindings
nextParentId = this.findCommonAncestor([startShape, endShape]) ?? parentPageId
} else if (startShape || endShape) {
// if arrow has one binding, keep arrow on its own page
nextParentId = parentPageId
} else {
return
}
if (nextParentId && nextParentId !== arrow.parentId) {
this.reparentShapes([arrowId], nextParentId)
}
const reparentedArrow = this.getShape<TLArrowShape>(arrowId)
if (!reparentedArrow) throw Error('no reparented arrow')
const startSibling = this.getShapeNearestSibling(reparentedArrow, startShape)
const endSibling = this.getShapeNearestSibling(reparentedArrow, endShape)
let highestSibling: TLShape | undefined
if (startSibling && endSibling) {
highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling
} else if (startSibling && !endSibling) {
highestSibling = startSibling
} else if (endSibling && !startSibling) {
highestSibling = endSibling
} else {
return
}
let finalIndex: string
const higherSiblings = this.getSortedChildIdsForParent(highestSibling.parentId)
.map((id) => this.getShape(id)!)
.filter((sibling) => sibling.index > highestSibling!.index)
if (higherSiblings.length) {
// there are siblings above the highest bound sibling, we need to
// insert between them.
// if the next sibling is also a bound arrow though, we can end up
// all fighting for the same indexes. so lets find the next
// non-arrow sibling...
const nextHighestNonArrowSibling = higherSiblings.find((sibling) => sibling.type !== 'arrow')
if (
// ...then, if we're above the last shape we want to be above...
reparentedArrow.index > highestSibling.index &&
// ...but below the next non-arrow sibling...
(!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index)
) {
// ...then we're already in the right place. no need to update!
return
}
// otherwise, we need to find the index between the highest sibling
// we want to be above, and the next highest sibling we want to be
// below:
finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index)
} else {
// if there are no siblings above us, we can just get the next index:
finalIndex = getIndexAbove(highestSibling.index)
}
if (finalIndex !== reparentedArrow.index) {
this.updateShapes<TLArrowShape>([{ id: arrowId, type: 'arrow', index: finalIndex }])
}
}
/** @internal */
private _unbindArrowTerminal(arrow: TLArrowShape, handleId: 'start' | 'end') {
const { x, y } = getArrowTerminalsInArrowSpace(this, arrow)[handleId]
this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }])
}
@computed
private get arrowInfoCache() {
return this.store.createComputedCache<ArrowInfo, TLArrowShape>('arrow infoCache', (shape) => {
@ -727,181 +924,6 @@ export class Editor extends EventEmitter<TLEventMap> {
// return update ?? next
// }
/** @internal */
private _shapeWillBeDeleted(deletedShape: TLShape) {
// if the deleted shape has a parent shape make sure we call it's onChildrenChange callback
if (deletedShape.parentId && isShapeId(deletedShape.parentId)) {
this._invalidParents.add(deletedShape.parentId)
}
// clean up any arrows bound to this shape
const bindings = this._arrowBindingsIndex.value[deletedShape.id]
if (bindings?.length) {
for (const { arrowId, handleId } of bindings) {
const arrow = this.getShape<TLArrowShape>(arrowId)
if (!arrow) continue
this._unbindArrowTerminal(arrow, handleId)
}
}
const deletedIds = new Set([deletedShape.id])
const updates = compact(
this.pageStates.map((pageState) => {
return this._cleanupInstancePageState(pageState, deletedIds)
})
)
if (updates.length) {
this.store.put(updates)
}
}
/** @internal */
private _arrowDidUpdate(arrow: TLArrowShape) {
// if the shape is an arrow and its bound shape is on another page
// or was deleted, unbind it
for (const handle of ['start', 'end'] as const) {
const terminal = arrow.props[handle]
if (terminal.type !== 'binding') continue
const boundShape = this.getShape(terminal.boundShapeId)
const isShapeInSamePageAsArrow =
this.getAncestorPageId(arrow) === this.getAncestorPageId(boundShape)
if (!boundShape || !isShapeInSamePageAsArrow) {
this._unbindArrowTerminal(arrow, handle)
}
}
// always check the arrow parents
this._reparentArrow(arrow.id)
}
/**
* _invalidParents is used to trigger the 'onChildrenChange' callback that shapes can have.
*
* @internal
*/
private readonly _invalidParents = new Set<TLShapeId>()
/** @internal */
private _complete() {
for (const parentId of this._invalidParents) {
this._invalidParents.delete(parentId)
const parent = this.getShape(parentId)
if (!parent) continue
const util = this.getShapeUtil(parent)
const changes = util.onChildrenChange?.(parent)
if (changes?.length) {
this.updateShapes(changes, true)
}
}
this.emit('update')
}
/** @internal */
private _shapeDidChange(prev: TLShape, next: TLShape) {
if (this.isShapeOfType<TLArrowShape>(next, 'arrow')) {
this._arrowDidUpdate(next)
}
// if the shape's parent changed and it is bound to an arrow, update the arrow's parent
if (prev.parentId !== next.parentId) {
const reparentBoundArrows = (id: TLShapeId) => {
const boundArrows = this._arrowBindingsIndex.value[id]
if (boundArrows?.length) {
for (const arrow of boundArrows) {
this._reparentArrow(arrow.arrowId)
}
}
}
reparentBoundArrows(next.id)
this.visitDescendants(next.id, reparentBoundArrows)
}
// if this shape moved to a new page, clean up any previous page's instance state
if (prev.parentId !== next.parentId && isPageId(next.parentId)) {
const allMovingIds = new Set([prev.id])
this.visitDescendants(prev.id, (id) => {
allMovingIds.add(id)
})
for (const instancePageState of this.pageStates) {
if (instancePageState.pageId === next.parentId) continue
const nextPageState = this._cleanupInstancePageState(instancePageState, allMovingIds)
if (nextPageState) {
this.store.put([nextPageState])
}
}
}
if (prev.parentId && isShapeId(prev.parentId)) {
this._invalidParents.add(prev.parentId)
}
if (next.parentId !== prev.parentId && isShapeId(next.parentId)) {
this._invalidParents.add(next.parentId)
}
}
/** @internal */
private _pageStateDidChange(prev: TLInstancePageState, next: TLInstancePageState) {
if (prev?.selectedShapeIds !== next?.selectedShapeIds) {
// ensure that descendants and ancestors are not selected at the same time
const filtered = next.selectedShapeIds.filter((id) => {
let parentId = this.getShape(id)?.parentId
while (isShapeId(parentId)) {
if (next.selectedShapeIds.includes(parentId)) {
return false
}
parentId = this.getShape(parentId)?.parentId
}
return true
})
let nextFocusedGroupId: null | TLShapeId = null
if (filtered.length > 0) {
const commonGroupAncestor = this.findCommonAncestor(
compact(filtered.map((id) => this.getShape(id))),
(shape) => this.isShapeOfType<TLGroupShape>(shape, 'group')
)
if (commonGroupAncestor) {
nextFocusedGroupId = commonGroupAncestor
}
} else {
if (next?.focusedGroupId) {
nextFocusedGroupId = next.focusedGroupId
}
}
if (
filtered.length !== next.selectedShapeIds.length ||
nextFocusedGroupId !== next.focusedGroupId
) {
this.store.put([
{ ...next, selectedShapeIds: filtered, focusedGroupId: nextFocusedGroupId ?? null },
])
}
}
}
/** @internal */
private _pageWillBeDeleted(page: TLPage) {
// page was deleted, need to check whether it's the current page and select another one if so
if (this.instanceState.currentPageId !== page.id) return
const backupPageId = this.pages.find((p) => p.id !== page.id)?.id
if (!backupPageId) return
this.store.put([{ ...this.instanceState, currentPageId: backupPageId }])
// delete the camera and state for the page if necessary
const cameraId = CameraRecordType.createId(page.id)
const instance_PageStateId = InstancePageStateRecordType.createId(page.id)
this.store.remove([cameraId, instance_PageStateId])
}
/* --------------------- Errors --------------------- */
/** @internal */
@ -1802,54 +1824,6 @@ export class Editor extends EventEmitter<TLEventMap> {
return this
}
/** @internal */
private _cleanupInstancePageState(
prevPageState: TLInstancePageState,
shapesNoLongerInPage: Set<TLShapeId>
) {
let nextPageState = null as null | TLInstancePageState
const selectedShapeIds = prevPageState.selectedShapeIds.filter(
(id) => !shapesNoLongerInPage.has(id)
)
if (selectedShapeIds.length !== prevPageState.selectedShapeIds.length) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.selectedShapeIds = selectedShapeIds
}
const erasingShapeIds = prevPageState.erasingShapeIds.filter(
(id) => !shapesNoLongerInPage.has(id)
)
if (erasingShapeIds.length !== prevPageState.erasingShapeIds.length) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.erasingShapeIds = erasingShapeIds
}
if (prevPageState.hoveredShapeId && shapesNoLongerInPage.has(prevPageState.hoveredShapeId)) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.hoveredShapeId = null
}
if (prevPageState.editingShapeId && shapesNoLongerInPage.has(prevPageState.editingShapeId)) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.editingShapeId = null
}
const hintingShapeIds = prevPageState.hintingShapeIds.filter(
(id) => !shapesNoLongerInPage.has(id)
)
if (hintingShapeIds.length !== prevPageState.hintingShapeIds.length) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.hintingShapeIds = hintingShapeIds
}
if (prevPageState.focusedGroupId && shapesNoLongerInPage.has(prevPageState.focusedGroupId)) {
if (!nextPageState) nextPageState = { ...prevPageState }
nextPageState.focusedGroupId = null
}
return nextPageState
}
/* --------------------- Camera --------------------- */
/** @internal */

View file

@ -2,13 +2,9 @@ import { HistoryManager } from './HistoryManager'
import { stack } from './Stack'
function createCounterHistoryManager() {
const manager = new HistoryManager(
{ emit: () => void null },
() => null,
() => {
return
}
)
const manager = new HistoryManager({ emit: () => void null }, () => {
return
})
const state = {
count: 0,
name: 'David',

View file

@ -29,10 +29,11 @@ export class HistoryManager<
constructor(
private readonly ctx: CTX,
private readonly onBatchComplete: () => void,
private readonly annotateError: (error: unknown) => void
) {}
onBatchComplete: () => void = () => void null
private _commands: Record<string, TLCommandHandler<any>> = {}
get numUndos() {

View file

@ -0,0 +1,20 @@
// let editor: Editor
// beforeEach(() => {
// editor = new Editor({
// shapeUtils: [],
// tools: [],
// store: createTLStore({ shapeUtils: [] }),
// getContainer: () => document.body,
// })
// })
describe('Side effect manager', () => {
it.todo('Registers an onBeforeCreate handler')
it.todo('Registers an onAfterCreate handler')
it.todo('Registers an onBeforeChange handler')
it.todo('Registers an onAfterChange handler')
it.todo('Registers an onBeforeDelete handler')
it.todo('Registers an onAfterDelete handler')
it.todo('Registers a batch start handler')
it.todo('Registers a batch complete handler')
})

View file

@ -0,0 +1,245 @@
import { TLRecord, TLStore } from '@tldraw/tlschema'
/** @public */
export type TLBeforeCreateHandler<R extends TLRecord> = (record: R, source: 'remote' | 'user') => R
/** @public */
export type TLAfterCreateHandler<R extends TLRecord> = (
record: R,
source: 'remote' | 'user'
) => void
/** @public */
export type TLBeforeChangeHandler<R extends TLRecord> = (
prev: R,
next: R,
source: 'remote' | 'user'
) => R
/** @public */
export type TLAfterChangeHandler<R extends TLRecord> = (
prev: R,
next: R,
source: 'remote' | 'user'
) => void
/** @public */
export type TLBeforeDeleteHandler<R extends TLRecord> = (
record: R,
source: 'remote' | 'user'
) => void | false
/** @public */
export type TLAfterDeleteHandler<R extends TLRecord> = (
record: R,
source: 'remote' | 'user'
) => void
/** @public */
export type TLBatchCompleteHandler = () => void
/**
* The side effect manager (aka a "correct state enforcer") is responsible
* for making sure that the editor's state is always correct. This includes
* things like: deleting a shape if its parent is deleted; unbinding
* arrows when their binding target is deleted; etc.
*
* @public
*/
export class SideEffectManager<
CTX extends {
store: TLStore
history: { onBatchComplete: () => void }
}
> {
constructor(public editor: CTX) {
editor.store.onBeforeCreate = (record, source) => {
const handlers = this._beforeCreateHandlers[
record.typeName
] as TLBeforeCreateHandler<TLRecord>[]
if (handlers) {
let r = record
for (const handler of handlers) {
r = handler(r, source)
}
return r
}
return record
}
editor.store.onAfterCreate = (record, source) => {
const handlers = this._afterCreateHandlers[
record.typeName
] as TLAfterCreateHandler<TLRecord>[]
if (handlers) {
for (const handler of handlers) {
handler(record, source)
}
}
}
editor.store.onBeforeChange = (prev, next, source) => {
const handlers = this._beforeChangeHandlers[
next.typeName
] as TLBeforeChangeHandler<TLRecord>[]
if (handlers) {
let r = next
for (const handler of handlers) {
r = handler(prev, r, source)
}
return r
}
return next
}
let updateDepth = 0
editor.store.onAfterChange = (prev, next, source) => {
updateDepth++
if (updateDepth > 1000) {
console.error('[CleanupManager.onAfterChange] Maximum update depth exceeded, bailing out.')
} else {
const handlers = this._afterChangeHandlers[
next.typeName
] as TLAfterChangeHandler<TLRecord>[]
if (handlers) {
for (const handler of handlers) {
handler(prev, next, source)
}
}
}
updateDepth--
}
editor.store.onBeforeDelete = (record, source) => {
const handlers = this._beforeDeleteHandlers[
record.typeName
] as TLBeforeDeleteHandler<TLRecord>[]
if (handlers) {
for (const handler of handlers) {
if (handler(record, source) === false) {
return false
}
}
}
}
editor.store.onAfterDelete = (record, source) => {
const handlers = this._afterDeleteHandlers[
record.typeName
] as TLAfterDeleteHandler<TLRecord>[]
if (handlers) {
for (const handler of handlers) {
handler(record, source)
}
}
}
editor.history.onBatchComplete = () => {
this._batchCompleteHandlers.forEach((fn) => fn())
}
}
private _beforeCreateHandlers: Partial<{
[K in TLRecord['typeName']]: TLBeforeCreateHandler<TLRecord & { typeName: K }>[]
}> = {}
private _afterCreateHandlers: Partial<{
[K in TLRecord['typeName']]: TLAfterCreateHandler<TLRecord & { typeName: K }>[]
}> = {}
private _beforeChangeHandlers: Partial<{
[K in TLRecord['typeName']]: TLBeforeChangeHandler<TLRecord & { typeName: K }>[]
}> = {}
private _afterChangeHandlers: Partial<{
[K in TLRecord['typeName']]: TLAfterChangeHandler<TLRecord & { typeName: K }>[]
}> = {}
private _beforeDeleteHandlers: Partial<{
[K in TLRecord['typeName']]: TLBeforeDeleteHandler<TLRecord & { typeName: K }>[]
}> = {}
private _afterDeleteHandlers: Partial<{
[K in TLRecord['typeName']]: TLAfterDeleteHandler<TLRecord & { typeName: K }>[]
}> = {}
private _batchCompleteHandlers: TLBatchCompleteHandler[] = []
registerBeforeCreateHandler<T extends TLRecord['typeName']>(
typeName: T,
handler: TLBeforeCreateHandler<TLRecord & { typeName: T }>
) {
const handlers = this._beforeCreateHandlers[typeName] as TLBeforeCreateHandler<any>[]
if (!handlers) this._beforeCreateHandlers[typeName] = []
this._beforeCreateHandlers[typeName]!.push(handler)
}
registerAfterCreateHandler<T extends TLRecord['typeName']>(
typeName: T,
handler: TLAfterCreateHandler<TLRecord & { typeName: T }>
) {
const handlers = this._afterCreateHandlers[typeName] as TLAfterCreateHandler<any>[]
if (!handlers) this._afterCreateHandlers[typeName] = []
this._afterCreateHandlers[typeName]!.push(handler)
}
registerBeforeChangeHandler<T extends TLRecord['typeName']>(
typeName: T,
handler: TLBeforeChangeHandler<TLRecord & { typeName: T }>
) {
const handlers = this._beforeChangeHandlers[typeName] as TLBeforeChangeHandler<any>[]
if (!handlers) this._beforeChangeHandlers[typeName] = []
this._beforeChangeHandlers[typeName]!.push(handler)
}
registerAfterChangeHandler<T extends TLRecord['typeName']>(
typeName: T,
handler: TLAfterChangeHandler<TLRecord & { typeName: T }>
) {
const handlers = this._afterChangeHandlers[typeName] as TLAfterChangeHandler<any>[]
if (!handlers) this._afterChangeHandlers[typeName] = []
this._afterChangeHandlers[typeName]!.push(handler as TLAfterChangeHandler<any>)
}
registerBeforeDeleteHandler<T extends TLRecord['typeName']>(
typeName: T,
handler: TLBeforeDeleteHandler<TLRecord & { typeName: T }>
) {
const handlers = this._beforeDeleteHandlers[typeName] as TLBeforeDeleteHandler<any>[]
if (!handlers) this._beforeDeleteHandlers[typeName] = []
this._beforeDeleteHandlers[typeName]!.push(handler as TLBeforeDeleteHandler<any>)
}
registerAfterDeleteHandler<T extends TLRecord['typeName']>(
typeName: T,
handler: TLAfterDeleteHandler<TLRecord & { typeName: T }>
) {
const handlers = this._afterDeleteHandlers[typeName] as TLAfterDeleteHandler<any>[]
if (!handlers) this._afterDeleteHandlers[typeName] = []
this._afterDeleteHandlers[typeName]!.push(handler as TLAfterDeleteHandler<any>)
}
/**
* Register a handler to be called when a store completes a batch.
*
* @example
* ```ts
* let count = 0
*
* editor.cleanup.registerBatchCompleteHandler(() => count++)
*
* editor.selectAll()
* expect(count).toBe(1)
*
* editor.batch(() => {
* editor.selectNone()
* editor.selectAll()
* })
*
* expect(count).toBe(2)
* ```
*
* @param handler - The handler to call
*
* @public
*/
registerBatchCompleteHandler(handler: TLBatchCompleteHandler) {
this._batchCompleteHandlers.push(handler)
}
}

View file

@ -244,6 +244,8 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
// (undocumented)
_flushHistory(): void;
get: <K extends IdOf<R>>(id: K) => RecFromId<K> | undefined;
// (undocumented)
getRecordType: <T extends R>(record: R) => T;
getSnapshot(scope?: 'all' | RecordScope): StoreSnapshot<R>;
has: <K extends IdOf<R>>(id: K) => boolean;
readonly history: Atom<number, RecordsDiff<R>>;
@ -255,10 +257,12 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
// @internal (undocumented)
markAsPossiblyCorrupted(): void;
mergeRemoteChanges: (fn: () => void) => void;
onAfterChange?: (prev: R, next: R) => void;
onAfterCreate?: (record: R) => void;
onAfterDelete?: (prev: R) => void;
onBeforeDelete?: (prev: R) => void;
onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void;
onAfterCreate?: (record: R, source: 'remote' | 'user') => void;
onAfterDelete?: (prev: R, source: 'remote' | 'user') => void;
onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R;
onBeforeCreate?: (next: R, source: 'remote' | 'user') => R;
onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void;
// (undocumented)
readonly props: Props;
put: (records: R[], phaseOverride?: 'initialize') => void;

View file

@ -297,13 +297,29 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
this.allRecords().forEach((record) => this.schema.validateRecord(this, record, phase, null))
}
/**
* A callback fired after each record's change.
*
* @param prev - The previous value, if any.
* @param next - The next value.
*/
onBeforeCreate?: (next: R, source: 'remote' | 'user') => R
/**
* A callback fired after a record is created. Use this to perform related updates to other
* records in the store.
*
* @param record - The record to be created
*/
onAfterCreate?: (record: R) => void
onAfterCreate?: (record: R, source: 'remote' | 'user') => void
/**
* A callback before after each record's change.
*
* @param prev - The previous value, if any.
* @param next - The next value.
*/
onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R
/**
* A callback fired after each record's change.
@ -311,21 +327,21 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
* @param prev - The previous value, if any.
* @param next - The next value.
*/
onAfterChange?: (prev: R, next: R) => void
onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void
/**
* A callback fired before a record is deleted.
*
* @param prev - The record that will be deleted.
*/
onBeforeDelete?: (prev: R) => void
onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void
/**
* A callback fired after a record is deleted.
*
* @param prev - The record that will be deleted.
*/
onAfterDelete?: (prev: R) => void
onAfterDelete?: (prev: R, source: 'remote' | 'user') => void
// used to avoid running callbacks when rolling back changes in sync client
private _runCallbacks = true
@ -353,12 +369,18 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
// changes (e.g. additions, deletions, or updates that produce a new value).
let didChange = false
const beforeCreate = this.onBeforeCreate && this._runCallbacks ? this.onBeforeCreate : null
const beforeUpdate = this.onBeforeChange && this._runCallbacks ? this.onBeforeChange : null
const source = this.isMergingRemoteChanges ? 'remote' : 'user'
for (let i = 0, n = records.length; i < n; i++) {
record = records[i]
const recordAtom = (map ?? currentMap)[record.id as IdOf<R>]
if (recordAtom) {
if (beforeUpdate) record = beforeUpdate(recordAtom.value, record, source)
// If we already have an atom for this record, update its value.
const initialValue = recordAtom.__unsafe__getWithoutCapture()
@ -382,6 +404,8 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
updates[record.id] = [initialValue, finalValue]
}
} else {
if (beforeCreate) record = beforeCreate(record, source)
didChange = true
// If we don't have an atom, create one.
@ -418,20 +442,22 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
removed: {} as Record<IdOf<R>, R>,
})
const { onAfterCreate, onAfterChange } = this
if (this._runCallbacks) {
const { onAfterCreate, onAfterChange } = this
if (onAfterCreate && this._runCallbacks) {
// Run the onAfterChange callback for addition.
Object.values(additions).forEach((record) => {
onAfterCreate(record)
})
}
if (onAfterCreate) {
// Run the onAfterChange callback for addition.
Object.values(additions).forEach((record) => {
onAfterCreate(record, source)
})
}
if (onAfterChange && this._runCallbacks) {
// Run the onAfterChange callback for update.
Object.values(updates).forEach(([from, to]) => {
onAfterChange(from, to)
})
if (onAfterChange) {
// Run the onAfterChange callback for update.
Object.values(updates).forEach(([from, to]) => {
onAfterChange(from, to, source)
})
}
}
})
}
@ -444,12 +470,17 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
*/
remove = (ids: IdOf<R>[]): void => {
transact(() => {
const cancelled = [] as IdOf<R>[]
const source = this.isMergingRemoteChanges ? 'remote' : 'user'
if (this.onBeforeDelete && this._runCallbacks) {
for (const id of ids) {
const atom = this.atoms.__unsafe__getWithoutCapture()[id]
if (!atom) continue
this.onBeforeDelete(atom.value)
if (this.onBeforeDelete(atom.value, source) === false) {
cancelled.push(id)
}
}
}
@ -460,6 +491,7 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
let result: typeof atoms | undefined = undefined
for (const id of ids) {
if (cancelled.includes(id)) continue
if (!(id in atoms)) continue
if (!result) result = { ...atoms }
if (!removed) removed = {} as Record<IdOf<R>, R>
@ -476,8 +508,12 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
// If we have an onAfterChange, run it for each removed record.
if (this.onAfterDelete && this._runCallbacks) {
let record: R
for (let i = 0, n = ids.length; i < n; i++) {
this.onAfterDelete(removed[ids[i]])
record = removed[ids[i]]
if (record) {
this.onAfterDelete(record, source)
}
}
}
})
@ -596,6 +632,7 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
console.error(`Record ${id} not found. This is probably an error`)
return
}
this.put([updater(atom.__unsafe__getWithoutCapture() as any as RecFromId<K>) as any])
}
@ -752,6 +789,14 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
}
}
getRecordType = <T extends R>(record: R): T => {
const type = this.schema.types[record.typeName as R['typeName']]
if (!type) {
throw new Error(`Record type ${record.typeName} not found`)
}
return type as unknown as T
}
private _integrityChecker?: () => void | undefined
/** @internal */