SideEffectManager (#1785)
This PR extracts the side effect manager from #1778. ### Change Type - [x] `major` — Breaking change
This commit is contained in:
parent
c478d75117
commit
507bba82fd
8 changed files with 664 additions and 378 deletions
|
@ -944,6 +944,7 @@ export class Editor extends EventEmitter<TLEventMap> {
|
|||
};
|
||||
get sharedOpacity(): SharedStyle<number>;
|
||||
get sharedStyles(): ReadonlySharedStyleMap;
|
||||
readonly sideEffects: SideEffectManager<this>;
|
||||
slideCamera(opts?: {
|
||||
speed: number;
|
||||
direction: VecLike;
|
||||
|
|
|
@ -105,6 +105,7 @@ import { deriveShapeIdsInCurrentPage } from './derivations/shapeIdsInCurrentPage
|
|||
import { ClickManager } from './managers/ClickManager'
|
||||
import { EnvironmentManager } from './managers/EnvironmentManager'
|
||||
import { HistoryManager } from './managers/HistoryManager'
|
||||
import { SideEffectManager } from './managers/SideEffectManager'
|
||||
import { SnapManager } from './managers/SnapManager'
|
||||
import { TextManager } from './managers/TextManager'
|
||||
import { TickManager } from './managers/TickManager'
|
||||
|
@ -242,35 +243,317 @@ export class Editor extends EventEmitter<TLEventMap> {
|
|||
|
||||
this.environment = new EnvironmentManager(this)
|
||||
|
||||
this.store.onBeforeDelete = (record) => {
|
||||
if (record.typeName === 'shape') {
|
||||
this._shapeWillBeDeleted(record)
|
||||
} else if (record.typeName === 'page') {
|
||||
this._pageWillBeDeleted(record)
|
||||
}
|
||||
// Cleanup
|
||||
|
||||
const invalidParents = new Set<TLShapeId>()
|
||||
|
||||
const reparentArrow = (arrowId: TLArrowShape['id']) => {
|
||||
const arrow = this.getShape<TLArrowShape>(arrowId)
|
||||
if (!arrow) return
|
||||
const { start, end } = arrow.props
|
||||
const startShape = start.type === 'binding' ? this.getShape(start.boundShapeId) : undefined
|
||||
const endShape = end.type === 'binding' ? this.getShape(end.boundShapeId) : undefined
|
||||
|
||||
const parentPageId = this.getAncestorPageId(arrow)
|
||||
if (!parentPageId) return
|
||||
|
||||
let nextParentId: TLParentId
|
||||
if (startShape && endShape) {
|
||||
// if arrow has two bindings, always parent arrow to closest common ancestor of the bindings
|
||||
nextParentId = this.findCommonAncestor([startShape, endShape]) ?? parentPageId
|
||||
} else if (startShape || endShape) {
|
||||
// if arrow has one binding, keep arrow on its own page
|
||||
nextParentId = parentPageId
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
this.store.onAfterChange = (prev, next) => {
|
||||
this._updateDepth++
|
||||
if (this._updateDepth > 1000) {
|
||||
console.error('[onAfterChange] Maximum update depth exceeded, bailing out.')
|
||||
if (nextParentId && nextParentId !== arrow.parentId) {
|
||||
this.reparentShapes([arrowId], nextParentId)
|
||||
}
|
||||
if (prev.typeName === 'shape' && next.typeName === 'shape') {
|
||||
this._shapeDidChange(prev, next)
|
||||
} else if (
|
||||
prev.typeName === 'instance_page_state' &&
|
||||
next.typeName === 'instance_page_state'
|
||||
|
||||
const reparentedArrow = this.getShape<TLArrowShape>(arrowId)
|
||||
if (!reparentedArrow) throw Error('no reparented arrow')
|
||||
|
||||
const startSibling = this.getShapeNearestSibling(reparentedArrow, startShape)
|
||||
const endSibling = this.getShapeNearestSibling(reparentedArrow, endShape)
|
||||
|
||||
let highestSibling: TLShape | undefined
|
||||
|
||||
if (startSibling && endSibling) {
|
||||
highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling
|
||||
} else if (startSibling && !endSibling) {
|
||||
highestSibling = startSibling
|
||||
} else if (endSibling && !startSibling) {
|
||||
highestSibling = endSibling
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
let finalIndex: string
|
||||
|
||||
const higherSiblings = this.getSortedChildIdsForParent(highestSibling.parentId)
|
||||
.map((id) => this.getShape(id)!)
|
||||
.filter((sibling) => sibling.index > highestSibling!.index)
|
||||
|
||||
if (higherSiblings.length) {
|
||||
// there are siblings above the highest bound sibling, we need to
|
||||
// insert between them.
|
||||
|
||||
// if the next sibling is also a bound arrow though, we can end up
|
||||
// all fighting for the same indexes. so lets find the next
|
||||
// non-arrow sibling...
|
||||
const nextHighestNonArrowSibling = higherSiblings.find(
|
||||
(sibling) => sibling.type !== 'arrow'
|
||||
)
|
||||
|
||||
if (
|
||||
// ...then, if we're above the last shape we want to be above...
|
||||
reparentedArrow.index > highestSibling.index &&
|
||||
// ...but below the next non-arrow sibling...
|
||||
(!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index)
|
||||
) {
|
||||
this._pageStateDidChange(prev, next)
|
||||
// ...then we're already in the right place. no need to update!
|
||||
return
|
||||
}
|
||||
|
||||
this._updateDepth--
|
||||
// otherwise, we need to find the index between the highest sibling
|
||||
// we want to be above, and the next highest sibling we want to be
|
||||
// below:
|
||||
finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index)
|
||||
} else {
|
||||
// if there are no siblings above us, we can just get the next index:
|
||||
finalIndex = getIndexAbove(highestSibling.index)
|
||||
}
|
||||
this.store.onAfterCreate = (record) => {
|
||||
if (record.typeName === 'shape' && this.isShapeOfType<TLArrowShape>(record, 'arrow')) {
|
||||
this._arrowDidUpdate(record)
|
||||
|
||||
if (finalIndex !== reparentedArrow.index) {
|
||||
this.updateShapes<TLArrowShape>([{ id: arrowId, type: 'arrow', index: finalIndex }])
|
||||
}
|
||||
if (record.typeName === 'page') {
|
||||
}
|
||||
|
||||
const unbindArrowTerminal = (arrow: TLArrowShape, handleId: 'start' | 'end') => {
|
||||
const { x, y } = getArrowTerminalsInArrowSpace(this, arrow)[handleId]
|
||||
this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }])
|
||||
}
|
||||
|
||||
const arrowDidUpdate = (arrow: TLArrowShape) => {
|
||||
// if the shape is an arrow and its bound shape is on another page
|
||||
// or was deleted, unbind it
|
||||
for (const handle of ['start', 'end'] as const) {
|
||||
const terminal = arrow.props[handle]
|
||||
if (terminal.type !== 'binding') continue
|
||||
const boundShape = this.getShape(terminal.boundShapeId)
|
||||
const isShapeInSamePageAsArrow =
|
||||
this.getAncestorPageId(arrow) === this.getAncestorPageId(boundShape)
|
||||
if (!boundShape || !isShapeInSamePageAsArrow) {
|
||||
unbindArrowTerminal(arrow, handle)
|
||||
}
|
||||
}
|
||||
|
||||
// always check the arrow parents
|
||||
reparentArrow(arrow.id)
|
||||
}
|
||||
|
||||
const cleanupInstancePageState = (
|
||||
prevPageState: TLInstancePageState,
|
||||
shapesNoLongerInPage: Set<TLShapeId>
|
||||
) => {
|
||||
let nextPageState = null as null | TLInstancePageState
|
||||
|
||||
const selectedShapeIds = prevPageState.selectedShapeIds.filter(
|
||||
(id) => !shapesNoLongerInPage.has(id)
|
||||
)
|
||||
if (selectedShapeIds.length !== prevPageState.selectedShapeIds.length) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.selectedShapeIds = selectedShapeIds
|
||||
}
|
||||
|
||||
const erasingShapeIds = prevPageState.erasingShapeIds.filter(
|
||||
(id) => !shapesNoLongerInPage.has(id)
|
||||
)
|
||||
if (erasingShapeIds.length !== prevPageState.erasingShapeIds.length) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.erasingShapeIds = erasingShapeIds
|
||||
}
|
||||
|
||||
if (prevPageState.hoveredShapeId && shapesNoLongerInPage.has(prevPageState.hoveredShapeId)) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.hoveredShapeId = null
|
||||
}
|
||||
|
||||
if (prevPageState.editingShapeId && shapesNoLongerInPage.has(prevPageState.editingShapeId)) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.editingShapeId = null
|
||||
}
|
||||
|
||||
const hintingShapeIds = prevPageState.hintingShapeIds.filter(
|
||||
(id) => !shapesNoLongerInPage.has(id)
|
||||
)
|
||||
if (hintingShapeIds.length !== prevPageState.hintingShapeIds.length) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.hintingShapeIds = hintingShapeIds
|
||||
}
|
||||
|
||||
if (prevPageState.focusedGroupId && shapesNoLongerInPage.has(prevPageState.focusedGroupId)) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.focusedGroupId = null
|
||||
}
|
||||
return nextPageState
|
||||
}
|
||||
|
||||
this.sideEffects = new SideEffectManager(this)
|
||||
|
||||
this.sideEffects.registerBatchCompleteHandler(() => {
|
||||
for (const parentId of invalidParents) {
|
||||
invalidParents.delete(parentId)
|
||||
const parent = this.getShape(parentId)
|
||||
if (!parent) continue
|
||||
|
||||
const util = this.getShapeUtil(parent)
|
||||
const changes = util.onChildrenChange?.(parent)
|
||||
|
||||
if (changes?.length) {
|
||||
this.updateShapes(changes, true)
|
||||
}
|
||||
}
|
||||
|
||||
this.emit('update')
|
||||
})
|
||||
|
||||
this.sideEffects.registerBeforeDeleteHandler('shape', (record) => {
|
||||
// if the deleted shape has a parent shape make sure we call it's onChildrenChange callback
|
||||
if (record.parentId && isShapeId(record.parentId)) {
|
||||
invalidParents.add(record.parentId)
|
||||
}
|
||||
// clean up any arrows bound to this shape
|
||||
const bindings = this._arrowBindingsIndex.value[record.id]
|
||||
if (bindings?.length) {
|
||||
for (const { arrowId, handleId } of bindings) {
|
||||
const arrow = this.getShape<TLArrowShape>(arrowId)
|
||||
if (!arrow) continue
|
||||
unbindArrowTerminal(arrow, handleId)
|
||||
}
|
||||
}
|
||||
const deletedIds = new Set([record.id])
|
||||
const updates = compact(
|
||||
this.pageStates.map((pageState) => {
|
||||
return cleanupInstancePageState(pageState, deletedIds)
|
||||
})
|
||||
)
|
||||
|
||||
if (updates.length) {
|
||||
this.store.put(updates)
|
||||
}
|
||||
})
|
||||
|
||||
this.sideEffects.registerBeforeDeleteHandler('page', (record) => {
|
||||
// page was deleted, need to check whether it's the current page and select another one if so
|
||||
if (this.instanceState.currentPageId !== record.id) return
|
||||
|
||||
const backupPageId = this.pages.find((p) => p.id !== record.id)?.id
|
||||
if (!backupPageId) return
|
||||
this.store.put([{ ...this.instanceState, currentPageId: backupPageId }])
|
||||
|
||||
// delete the camera and state for the page if necessary
|
||||
const cameraId = CameraRecordType.createId(record.id)
|
||||
const instance_PageStateId = InstancePageStateRecordType.createId(record.id)
|
||||
this.store.remove([cameraId, instance_PageStateId])
|
||||
})
|
||||
|
||||
this.sideEffects.registerAfterChangeHandler('shape', (prev, next) => {
|
||||
if (this.isShapeOfType<TLArrowShape>(next, 'arrow')) {
|
||||
arrowDidUpdate(next)
|
||||
}
|
||||
|
||||
// if the shape's parent changed and it is bound to an arrow, update the arrow's parent
|
||||
if (prev.parentId !== next.parentId) {
|
||||
const reparentBoundArrows = (id: TLShapeId) => {
|
||||
const boundArrows = this._arrowBindingsIndex.value[id]
|
||||
if (boundArrows?.length) {
|
||||
for (const arrow of boundArrows) {
|
||||
reparentArrow(arrow.arrowId)
|
||||
}
|
||||
}
|
||||
}
|
||||
reparentBoundArrows(next.id)
|
||||
this.visitDescendants(next.id, reparentBoundArrows)
|
||||
}
|
||||
|
||||
// if this shape moved to a new page, clean up any previous page's instance state
|
||||
if (prev.parentId !== next.parentId && isPageId(next.parentId)) {
|
||||
const allMovingIds = new Set([prev.id])
|
||||
this.visitDescendants(prev.id, (id) => {
|
||||
allMovingIds.add(id)
|
||||
})
|
||||
|
||||
for (const instancePageState of this.pageStates) {
|
||||
if (instancePageState.pageId === next.parentId) continue
|
||||
const nextPageState = cleanupInstancePageState(instancePageState, allMovingIds)
|
||||
|
||||
if (nextPageState) {
|
||||
this.store.put([nextPageState])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (prev.parentId && isShapeId(prev.parentId)) {
|
||||
invalidParents.add(prev.parentId)
|
||||
}
|
||||
|
||||
if (next.parentId !== prev.parentId && isShapeId(next.parentId)) {
|
||||
invalidParents.add(next.parentId)
|
||||
}
|
||||
})
|
||||
|
||||
this.sideEffects.registerAfterChangeHandler('instance_page_state', (prev, next) => {
|
||||
if (prev?.selectedShapeIds !== next?.selectedShapeIds) {
|
||||
// ensure that descendants and ancestors are not selected at the same time
|
||||
const filtered = next.selectedShapeIds.filter((id) => {
|
||||
let parentId = this.getShape(id)?.parentId
|
||||
while (isShapeId(parentId)) {
|
||||
if (next.selectedShapeIds.includes(parentId)) {
|
||||
return false
|
||||
}
|
||||
parentId = this.getShape(parentId)?.parentId
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
let nextFocusedGroupId: null | TLShapeId = null
|
||||
|
||||
if (filtered.length > 0) {
|
||||
const commonGroupAncestor = this.findCommonAncestor(
|
||||
compact(filtered.map((id) => this.getShape(id))),
|
||||
(shape) => this.isShapeOfType<TLGroupShape>(shape, 'group')
|
||||
)
|
||||
|
||||
if (commonGroupAncestor) {
|
||||
nextFocusedGroupId = commonGroupAncestor
|
||||
}
|
||||
} else {
|
||||
if (next?.focusedGroupId) {
|
||||
nextFocusedGroupId = next.focusedGroupId
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
filtered.length !== next.selectedShapeIds.length ||
|
||||
nextFocusedGroupId !== next.focusedGroupId
|
||||
) {
|
||||
this.store.put([
|
||||
{ ...next, selectedShapeIds: filtered, focusedGroupId: nextFocusedGroupId ?? null },
|
||||
])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
this.sideEffects.registerAfterCreateHandler('shape', (record) => {
|
||||
if (this.isShapeOfType<TLArrowShape>(record, 'arrow')) {
|
||||
arrowDidUpdate(record)
|
||||
}
|
||||
})
|
||||
|
||||
this.sideEffects.registerAfterCreateHandler('page', (record) => {
|
||||
const cameraId = CameraRecordType.createId(record.id)
|
||||
const _pageStateId = InstancePageStateRecordType.createId(record.id)
|
||||
if (!this.store.has(cameraId)) {
|
||||
|
@ -281,8 +564,7 @@ export class Editor extends EventEmitter<TLEventMap> {
|
|||
InstancePageStateRecordType.create({ id: _pageStateId, pageId: record.id }),
|
||||
])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
this._shapeIdsOnCurrentPage = deriveShapeIdsInCurrentPage(this.store, () => this.currentPageId)
|
||||
this._parentIdsToChildIds = parentsToChildren(this.store)
|
||||
|
@ -376,9 +658,6 @@ export class Editor extends EventEmitter<TLEventMap> {
|
|||
/** @internal */
|
||||
private _tickManager = new TickManager(this)
|
||||
|
||||
/** @internal */
|
||||
private _updateDepth = 0
|
||||
|
||||
/**
|
||||
* A manager for the app's snapping feature.
|
||||
*
|
||||
|
@ -419,6 +698,13 @@ export class Editor extends EventEmitter<TLEventMap> {
|
|||
*/
|
||||
getContainer: () => HTMLElement
|
||||
|
||||
/**
|
||||
* A manager for side effects and correct state enforcement.
|
||||
*
|
||||
* @public
|
||||
*/
|
||||
readonly sideEffects: SideEffectManager<this>
|
||||
|
||||
/**
|
||||
* Dispose the editor.
|
||||
*
|
||||
|
@ -474,7 +760,7 @@ export class Editor extends EventEmitter<TLEventMap> {
|
|||
*/
|
||||
readonly history = new HistoryManager(
|
||||
this,
|
||||
() => this._complete(),
|
||||
// () => this._complete(),
|
||||
(error) => {
|
||||
this.annotateError(error, { origin: 'history.batch', willCrashApp: true })
|
||||
this.crash(error)
|
||||
|
@ -620,95 +906,6 @@ export class Editor extends EventEmitter<TLEventMap> {
|
|||
return this._arrowBindingsIndex.value[shapeId] || EMPTY_ARRAY
|
||||
}
|
||||
|
||||
/** @internal */
|
||||
private _reparentArrow(arrowId: TLShapeId) {
|
||||
const arrow = this.getShape<TLArrowShape>(arrowId)
|
||||
if (!arrow) return
|
||||
const { start, end } = arrow.props
|
||||
const startShape = start.type === 'binding' ? this.getShape(start.boundShapeId) : undefined
|
||||
const endShape = end.type === 'binding' ? this.getShape(end.boundShapeId) : undefined
|
||||
|
||||
const parentPageId = this.getAncestorPageId(arrow)
|
||||
if (!parentPageId) return
|
||||
|
||||
let nextParentId: TLParentId
|
||||
if (startShape && endShape) {
|
||||
// if arrow has two bindings, always parent arrow to closest common ancestor of the bindings
|
||||
nextParentId = this.findCommonAncestor([startShape, endShape]) ?? parentPageId
|
||||
} else if (startShape || endShape) {
|
||||
// if arrow has one binding, keep arrow on its own page
|
||||
nextParentId = parentPageId
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
if (nextParentId && nextParentId !== arrow.parentId) {
|
||||
this.reparentShapes([arrowId], nextParentId)
|
||||
}
|
||||
|
||||
const reparentedArrow = this.getShape<TLArrowShape>(arrowId)
|
||||
if (!reparentedArrow) throw Error('no reparented arrow')
|
||||
|
||||
const startSibling = this.getShapeNearestSibling(reparentedArrow, startShape)
|
||||
const endSibling = this.getShapeNearestSibling(reparentedArrow, endShape)
|
||||
|
||||
let highestSibling: TLShape | undefined
|
||||
|
||||
if (startSibling && endSibling) {
|
||||
highestSibling = startSibling.index > endSibling.index ? startSibling : endSibling
|
||||
} else if (startSibling && !endSibling) {
|
||||
highestSibling = startSibling
|
||||
} else if (endSibling && !startSibling) {
|
||||
highestSibling = endSibling
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
let finalIndex: string
|
||||
|
||||
const higherSiblings = this.getSortedChildIdsForParent(highestSibling.parentId)
|
||||
.map((id) => this.getShape(id)!)
|
||||
.filter((sibling) => sibling.index > highestSibling!.index)
|
||||
|
||||
if (higherSiblings.length) {
|
||||
// there are siblings above the highest bound sibling, we need to
|
||||
// insert between them.
|
||||
|
||||
// if the next sibling is also a bound arrow though, we can end up
|
||||
// all fighting for the same indexes. so lets find the next
|
||||
// non-arrow sibling...
|
||||
const nextHighestNonArrowSibling = higherSiblings.find((sibling) => sibling.type !== 'arrow')
|
||||
|
||||
if (
|
||||
// ...then, if we're above the last shape we want to be above...
|
||||
reparentedArrow.index > highestSibling.index &&
|
||||
// ...but below the next non-arrow sibling...
|
||||
(!nextHighestNonArrowSibling || reparentedArrow.index < nextHighestNonArrowSibling.index)
|
||||
) {
|
||||
// ...then we're already in the right place. no need to update!
|
||||
return
|
||||
}
|
||||
|
||||
// otherwise, we need to find the index between the highest sibling
|
||||
// we want to be above, and the next highest sibling we want to be
|
||||
// below:
|
||||
finalIndex = getIndexBetween(highestSibling.index, higherSiblings[0].index)
|
||||
} else {
|
||||
// if there are no siblings above us, we can just get the next index:
|
||||
finalIndex = getIndexAbove(highestSibling.index)
|
||||
}
|
||||
|
||||
if (finalIndex !== reparentedArrow.index) {
|
||||
this.updateShapes<TLArrowShape>([{ id: arrowId, type: 'arrow', index: finalIndex }])
|
||||
}
|
||||
}
|
||||
|
||||
/** @internal */
|
||||
private _unbindArrowTerminal(arrow: TLArrowShape, handleId: 'start' | 'end') {
|
||||
const { x, y } = getArrowTerminalsInArrowSpace(this, arrow)[handleId]
|
||||
this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }])
|
||||
}
|
||||
|
||||
@computed
|
||||
private get arrowInfoCache() {
|
||||
return this.store.createComputedCache<ArrowInfo, TLArrowShape>('arrow infoCache', (shape) => {
|
||||
|
@ -727,181 +924,6 @@ export class Editor extends EventEmitter<TLEventMap> {
|
|||
// return update ?? next
|
||||
// }
|
||||
|
||||
/** @internal */
|
||||
private _shapeWillBeDeleted(deletedShape: TLShape) {
|
||||
// if the deleted shape has a parent shape make sure we call it's onChildrenChange callback
|
||||
if (deletedShape.parentId && isShapeId(deletedShape.parentId)) {
|
||||
this._invalidParents.add(deletedShape.parentId)
|
||||
}
|
||||
// clean up any arrows bound to this shape
|
||||
const bindings = this._arrowBindingsIndex.value[deletedShape.id]
|
||||
if (bindings?.length) {
|
||||
for (const { arrowId, handleId } of bindings) {
|
||||
const arrow = this.getShape<TLArrowShape>(arrowId)
|
||||
if (!arrow) continue
|
||||
this._unbindArrowTerminal(arrow, handleId)
|
||||
}
|
||||
}
|
||||
const deletedIds = new Set([deletedShape.id])
|
||||
const updates = compact(
|
||||
this.pageStates.map((pageState) => {
|
||||
return this._cleanupInstancePageState(pageState, deletedIds)
|
||||
})
|
||||
)
|
||||
|
||||
if (updates.length) {
|
||||
this.store.put(updates)
|
||||
}
|
||||
}
|
||||
|
||||
/** @internal */
|
||||
private _arrowDidUpdate(arrow: TLArrowShape) {
|
||||
// if the shape is an arrow and its bound shape is on another page
|
||||
// or was deleted, unbind it
|
||||
for (const handle of ['start', 'end'] as const) {
|
||||
const terminal = arrow.props[handle]
|
||||
if (terminal.type !== 'binding') continue
|
||||
const boundShape = this.getShape(terminal.boundShapeId)
|
||||
const isShapeInSamePageAsArrow =
|
||||
this.getAncestorPageId(arrow) === this.getAncestorPageId(boundShape)
|
||||
if (!boundShape || !isShapeInSamePageAsArrow) {
|
||||
this._unbindArrowTerminal(arrow, handle)
|
||||
}
|
||||
}
|
||||
|
||||
// always check the arrow parents
|
||||
this._reparentArrow(arrow.id)
|
||||
}
|
||||
|
||||
/**
|
||||
* _invalidParents is used to trigger the 'onChildrenChange' callback that shapes can have.
|
||||
*
|
||||
* @internal
|
||||
*/
|
||||
private readonly _invalidParents = new Set<TLShapeId>()
|
||||
|
||||
/** @internal */
|
||||
private _complete() {
|
||||
for (const parentId of this._invalidParents) {
|
||||
this._invalidParents.delete(parentId)
|
||||
const parent = this.getShape(parentId)
|
||||
if (!parent) continue
|
||||
|
||||
const util = this.getShapeUtil(parent)
|
||||
const changes = util.onChildrenChange?.(parent)
|
||||
|
||||
if (changes?.length) {
|
||||
this.updateShapes(changes, true)
|
||||
}
|
||||
}
|
||||
|
||||
this.emit('update')
|
||||
}
|
||||
|
||||
/** @internal */
|
||||
private _shapeDidChange(prev: TLShape, next: TLShape) {
|
||||
if (this.isShapeOfType<TLArrowShape>(next, 'arrow')) {
|
||||
this._arrowDidUpdate(next)
|
||||
}
|
||||
|
||||
// if the shape's parent changed and it is bound to an arrow, update the arrow's parent
|
||||
if (prev.parentId !== next.parentId) {
|
||||
const reparentBoundArrows = (id: TLShapeId) => {
|
||||
const boundArrows = this._arrowBindingsIndex.value[id]
|
||||
if (boundArrows?.length) {
|
||||
for (const arrow of boundArrows) {
|
||||
this._reparentArrow(arrow.arrowId)
|
||||
}
|
||||
}
|
||||
}
|
||||
reparentBoundArrows(next.id)
|
||||
this.visitDescendants(next.id, reparentBoundArrows)
|
||||
}
|
||||
|
||||
// if this shape moved to a new page, clean up any previous page's instance state
|
||||
if (prev.parentId !== next.parentId && isPageId(next.parentId)) {
|
||||
const allMovingIds = new Set([prev.id])
|
||||
this.visitDescendants(prev.id, (id) => {
|
||||
allMovingIds.add(id)
|
||||
})
|
||||
|
||||
for (const instancePageState of this.pageStates) {
|
||||
if (instancePageState.pageId === next.parentId) continue
|
||||
const nextPageState = this._cleanupInstancePageState(instancePageState, allMovingIds)
|
||||
|
||||
if (nextPageState) {
|
||||
this.store.put([nextPageState])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (prev.parentId && isShapeId(prev.parentId)) {
|
||||
this._invalidParents.add(prev.parentId)
|
||||
}
|
||||
|
||||
if (next.parentId !== prev.parentId && isShapeId(next.parentId)) {
|
||||
this._invalidParents.add(next.parentId)
|
||||
}
|
||||
}
|
||||
|
||||
/** @internal */
|
||||
private _pageStateDidChange(prev: TLInstancePageState, next: TLInstancePageState) {
|
||||
if (prev?.selectedShapeIds !== next?.selectedShapeIds) {
|
||||
// ensure that descendants and ancestors are not selected at the same time
|
||||
const filtered = next.selectedShapeIds.filter((id) => {
|
||||
let parentId = this.getShape(id)?.parentId
|
||||
while (isShapeId(parentId)) {
|
||||
if (next.selectedShapeIds.includes(parentId)) {
|
||||
return false
|
||||
}
|
||||
parentId = this.getShape(parentId)?.parentId
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
let nextFocusedGroupId: null | TLShapeId = null
|
||||
|
||||
if (filtered.length > 0) {
|
||||
const commonGroupAncestor = this.findCommonAncestor(
|
||||
compact(filtered.map((id) => this.getShape(id))),
|
||||
(shape) => this.isShapeOfType<TLGroupShape>(shape, 'group')
|
||||
)
|
||||
|
||||
if (commonGroupAncestor) {
|
||||
nextFocusedGroupId = commonGroupAncestor
|
||||
}
|
||||
} else {
|
||||
if (next?.focusedGroupId) {
|
||||
nextFocusedGroupId = next.focusedGroupId
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
filtered.length !== next.selectedShapeIds.length ||
|
||||
nextFocusedGroupId !== next.focusedGroupId
|
||||
) {
|
||||
this.store.put([
|
||||
{ ...next, selectedShapeIds: filtered, focusedGroupId: nextFocusedGroupId ?? null },
|
||||
])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @internal */
|
||||
private _pageWillBeDeleted(page: TLPage) {
|
||||
// page was deleted, need to check whether it's the current page and select another one if so
|
||||
if (this.instanceState.currentPageId !== page.id) return
|
||||
|
||||
const backupPageId = this.pages.find((p) => p.id !== page.id)?.id
|
||||
if (!backupPageId) return
|
||||
this.store.put([{ ...this.instanceState, currentPageId: backupPageId }])
|
||||
|
||||
// delete the camera and state for the page if necessary
|
||||
const cameraId = CameraRecordType.createId(page.id)
|
||||
const instance_PageStateId = InstancePageStateRecordType.createId(page.id)
|
||||
this.store.remove([cameraId, instance_PageStateId])
|
||||
}
|
||||
|
||||
/* --------------------- Errors --------------------- */
|
||||
|
||||
/** @internal */
|
||||
|
@ -1802,54 +1824,6 @@ export class Editor extends EventEmitter<TLEventMap> {
|
|||
return this
|
||||
}
|
||||
|
||||
/** @internal */
|
||||
private _cleanupInstancePageState(
|
||||
prevPageState: TLInstancePageState,
|
||||
shapesNoLongerInPage: Set<TLShapeId>
|
||||
) {
|
||||
let nextPageState = null as null | TLInstancePageState
|
||||
|
||||
const selectedShapeIds = prevPageState.selectedShapeIds.filter(
|
||||
(id) => !shapesNoLongerInPage.has(id)
|
||||
)
|
||||
if (selectedShapeIds.length !== prevPageState.selectedShapeIds.length) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.selectedShapeIds = selectedShapeIds
|
||||
}
|
||||
|
||||
const erasingShapeIds = prevPageState.erasingShapeIds.filter(
|
||||
(id) => !shapesNoLongerInPage.has(id)
|
||||
)
|
||||
if (erasingShapeIds.length !== prevPageState.erasingShapeIds.length) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.erasingShapeIds = erasingShapeIds
|
||||
}
|
||||
|
||||
if (prevPageState.hoveredShapeId && shapesNoLongerInPage.has(prevPageState.hoveredShapeId)) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.hoveredShapeId = null
|
||||
}
|
||||
|
||||
if (prevPageState.editingShapeId && shapesNoLongerInPage.has(prevPageState.editingShapeId)) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.editingShapeId = null
|
||||
}
|
||||
|
||||
const hintingShapeIds = prevPageState.hintingShapeIds.filter(
|
||||
(id) => !shapesNoLongerInPage.has(id)
|
||||
)
|
||||
if (hintingShapeIds.length !== prevPageState.hintingShapeIds.length) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.hintingShapeIds = hintingShapeIds
|
||||
}
|
||||
|
||||
if (prevPageState.focusedGroupId && shapesNoLongerInPage.has(prevPageState.focusedGroupId)) {
|
||||
if (!nextPageState) nextPageState = { ...prevPageState }
|
||||
nextPageState.focusedGroupId = null
|
||||
}
|
||||
return nextPageState
|
||||
}
|
||||
|
||||
/* --------------------- Camera --------------------- */
|
||||
|
||||
/** @internal */
|
||||
|
|
|
@ -2,13 +2,9 @@ import { HistoryManager } from './HistoryManager'
|
|||
import { stack } from './Stack'
|
||||
|
||||
function createCounterHistoryManager() {
|
||||
const manager = new HistoryManager(
|
||||
{ emit: () => void null },
|
||||
() => null,
|
||||
() => {
|
||||
const manager = new HistoryManager({ emit: () => void null }, () => {
|
||||
return
|
||||
}
|
||||
)
|
||||
})
|
||||
const state = {
|
||||
count: 0,
|
||||
name: 'David',
|
||||
|
|
|
@ -29,10 +29,11 @@ export class HistoryManager<
|
|||
|
||||
constructor(
|
||||
private readonly ctx: CTX,
|
||||
private readonly onBatchComplete: () => void,
|
||||
private readonly annotateError: (error: unknown) => void
|
||||
) {}
|
||||
|
||||
onBatchComplete: () => void = () => void null
|
||||
|
||||
private _commands: Record<string, TLCommandHandler<any>> = {}
|
||||
|
||||
get numUndos() {
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
// let editor: Editor
|
||||
// beforeEach(() => {
|
||||
// editor = new Editor({
|
||||
// shapeUtils: [],
|
||||
// tools: [],
|
||||
// store: createTLStore({ shapeUtils: [] }),
|
||||
// getContainer: () => document.body,
|
||||
// })
|
||||
// })
|
||||
|
||||
describe('Side effect manager', () => {
|
||||
it.todo('Registers an onBeforeCreate handler')
|
||||
it.todo('Registers an onAfterCreate handler')
|
||||
it.todo('Registers an onBeforeChange handler')
|
||||
it.todo('Registers an onAfterChange handler')
|
||||
it.todo('Registers an onBeforeDelete handler')
|
||||
it.todo('Registers an onAfterDelete handler')
|
||||
it.todo('Registers a batch start handler')
|
||||
it.todo('Registers a batch complete handler')
|
||||
})
|
245
packages/editor/src/lib/editor/managers/SideEffectManager.ts
Normal file
245
packages/editor/src/lib/editor/managers/SideEffectManager.ts
Normal file
|
@ -0,0 +1,245 @@
|
|||
import { TLRecord, TLStore } from '@tldraw/tlschema'
|
||||
|
||||
/** @public */
|
||||
export type TLBeforeCreateHandler<R extends TLRecord> = (record: R, source: 'remote' | 'user') => R
|
||||
/** @public */
|
||||
export type TLAfterCreateHandler<R extends TLRecord> = (
|
||||
record: R,
|
||||
source: 'remote' | 'user'
|
||||
) => void
|
||||
/** @public */
|
||||
export type TLBeforeChangeHandler<R extends TLRecord> = (
|
||||
prev: R,
|
||||
next: R,
|
||||
source: 'remote' | 'user'
|
||||
) => R
|
||||
/** @public */
|
||||
export type TLAfterChangeHandler<R extends TLRecord> = (
|
||||
prev: R,
|
||||
next: R,
|
||||
source: 'remote' | 'user'
|
||||
) => void
|
||||
/** @public */
|
||||
export type TLBeforeDeleteHandler<R extends TLRecord> = (
|
||||
record: R,
|
||||
source: 'remote' | 'user'
|
||||
) => void | false
|
||||
/** @public */
|
||||
export type TLAfterDeleteHandler<R extends TLRecord> = (
|
||||
record: R,
|
||||
source: 'remote' | 'user'
|
||||
) => void
|
||||
/** @public */
|
||||
export type TLBatchCompleteHandler = () => void
|
||||
|
||||
/**
|
||||
* The side effect manager (aka a "correct state enforcer") is responsible
|
||||
* for making sure that the editor's state is always correct. This includes
|
||||
* things like: deleting a shape if its parent is deleted; unbinding
|
||||
* arrows when their binding target is deleted; etc.
|
||||
*
|
||||
* @public
|
||||
*/
|
||||
export class SideEffectManager<
|
||||
CTX extends {
|
||||
store: TLStore
|
||||
history: { onBatchComplete: () => void }
|
||||
}
|
||||
> {
|
||||
constructor(public editor: CTX) {
|
||||
editor.store.onBeforeCreate = (record, source) => {
|
||||
const handlers = this._beforeCreateHandlers[
|
||||
record.typeName
|
||||
] as TLBeforeCreateHandler<TLRecord>[]
|
||||
if (handlers) {
|
||||
let r = record
|
||||
for (const handler of handlers) {
|
||||
r = handler(r, source)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
editor.store.onAfterCreate = (record, source) => {
|
||||
const handlers = this._afterCreateHandlers[
|
||||
record.typeName
|
||||
] as TLAfterCreateHandler<TLRecord>[]
|
||||
if (handlers) {
|
||||
for (const handler of handlers) {
|
||||
handler(record, source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
editor.store.onBeforeChange = (prev, next, source) => {
|
||||
const handlers = this._beforeChangeHandlers[
|
||||
next.typeName
|
||||
] as TLBeforeChangeHandler<TLRecord>[]
|
||||
if (handlers) {
|
||||
let r = next
|
||||
for (const handler of handlers) {
|
||||
r = handler(prev, r, source)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
return next
|
||||
}
|
||||
|
||||
let updateDepth = 0
|
||||
|
||||
editor.store.onAfterChange = (prev, next, source) => {
|
||||
updateDepth++
|
||||
|
||||
if (updateDepth > 1000) {
|
||||
console.error('[CleanupManager.onAfterChange] Maximum update depth exceeded, bailing out.')
|
||||
} else {
|
||||
const handlers = this._afterChangeHandlers[
|
||||
next.typeName
|
||||
] as TLAfterChangeHandler<TLRecord>[]
|
||||
if (handlers) {
|
||||
for (const handler of handlers) {
|
||||
handler(prev, next, source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updateDepth--
|
||||
}
|
||||
|
||||
editor.store.onBeforeDelete = (record, source) => {
|
||||
const handlers = this._beforeDeleteHandlers[
|
||||
record.typeName
|
||||
] as TLBeforeDeleteHandler<TLRecord>[]
|
||||
if (handlers) {
|
||||
for (const handler of handlers) {
|
||||
if (handler(record, source) === false) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
editor.store.onAfterDelete = (record, source) => {
|
||||
const handlers = this._afterDeleteHandlers[
|
||||
record.typeName
|
||||
] as TLAfterDeleteHandler<TLRecord>[]
|
||||
if (handlers) {
|
||||
for (const handler of handlers) {
|
||||
handler(record, source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
editor.history.onBatchComplete = () => {
|
||||
this._batchCompleteHandlers.forEach((fn) => fn())
|
||||
}
|
||||
}
|
||||
|
||||
private _beforeCreateHandlers: Partial<{
|
||||
[K in TLRecord['typeName']]: TLBeforeCreateHandler<TLRecord & { typeName: K }>[]
|
||||
}> = {}
|
||||
private _afterCreateHandlers: Partial<{
|
||||
[K in TLRecord['typeName']]: TLAfterCreateHandler<TLRecord & { typeName: K }>[]
|
||||
}> = {}
|
||||
private _beforeChangeHandlers: Partial<{
|
||||
[K in TLRecord['typeName']]: TLBeforeChangeHandler<TLRecord & { typeName: K }>[]
|
||||
}> = {}
|
||||
private _afterChangeHandlers: Partial<{
|
||||
[K in TLRecord['typeName']]: TLAfterChangeHandler<TLRecord & { typeName: K }>[]
|
||||
}> = {}
|
||||
|
||||
private _beforeDeleteHandlers: Partial<{
|
||||
[K in TLRecord['typeName']]: TLBeforeDeleteHandler<TLRecord & { typeName: K }>[]
|
||||
}> = {}
|
||||
|
||||
private _afterDeleteHandlers: Partial<{
|
||||
[K in TLRecord['typeName']]: TLAfterDeleteHandler<TLRecord & { typeName: K }>[]
|
||||
}> = {}
|
||||
|
||||
private _batchCompleteHandlers: TLBatchCompleteHandler[] = []
|
||||
|
||||
registerBeforeCreateHandler<T extends TLRecord['typeName']>(
|
||||
typeName: T,
|
||||
handler: TLBeforeCreateHandler<TLRecord & { typeName: T }>
|
||||
) {
|
||||
const handlers = this._beforeCreateHandlers[typeName] as TLBeforeCreateHandler<any>[]
|
||||
if (!handlers) this._beforeCreateHandlers[typeName] = []
|
||||
this._beforeCreateHandlers[typeName]!.push(handler)
|
||||
}
|
||||
|
||||
registerAfterCreateHandler<T extends TLRecord['typeName']>(
|
||||
typeName: T,
|
||||
handler: TLAfterCreateHandler<TLRecord & { typeName: T }>
|
||||
) {
|
||||
const handlers = this._afterCreateHandlers[typeName] as TLAfterCreateHandler<any>[]
|
||||
if (!handlers) this._afterCreateHandlers[typeName] = []
|
||||
this._afterCreateHandlers[typeName]!.push(handler)
|
||||
}
|
||||
|
||||
registerBeforeChangeHandler<T extends TLRecord['typeName']>(
|
||||
typeName: T,
|
||||
handler: TLBeforeChangeHandler<TLRecord & { typeName: T }>
|
||||
) {
|
||||
const handlers = this._beforeChangeHandlers[typeName] as TLBeforeChangeHandler<any>[]
|
||||
if (!handlers) this._beforeChangeHandlers[typeName] = []
|
||||
this._beforeChangeHandlers[typeName]!.push(handler)
|
||||
}
|
||||
|
||||
registerAfterChangeHandler<T extends TLRecord['typeName']>(
|
||||
typeName: T,
|
||||
handler: TLAfterChangeHandler<TLRecord & { typeName: T }>
|
||||
) {
|
||||
const handlers = this._afterChangeHandlers[typeName] as TLAfterChangeHandler<any>[]
|
||||
if (!handlers) this._afterChangeHandlers[typeName] = []
|
||||
this._afterChangeHandlers[typeName]!.push(handler as TLAfterChangeHandler<any>)
|
||||
}
|
||||
|
||||
registerBeforeDeleteHandler<T extends TLRecord['typeName']>(
|
||||
typeName: T,
|
||||
handler: TLBeforeDeleteHandler<TLRecord & { typeName: T }>
|
||||
) {
|
||||
const handlers = this._beforeDeleteHandlers[typeName] as TLBeforeDeleteHandler<any>[]
|
||||
if (!handlers) this._beforeDeleteHandlers[typeName] = []
|
||||
this._beforeDeleteHandlers[typeName]!.push(handler as TLBeforeDeleteHandler<any>)
|
||||
}
|
||||
|
||||
registerAfterDeleteHandler<T extends TLRecord['typeName']>(
|
||||
typeName: T,
|
||||
handler: TLAfterDeleteHandler<TLRecord & { typeName: T }>
|
||||
) {
|
||||
const handlers = this._afterDeleteHandlers[typeName] as TLAfterDeleteHandler<any>[]
|
||||
if (!handlers) this._afterDeleteHandlers[typeName] = []
|
||||
this._afterDeleteHandlers[typeName]!.push(handler as TLAfterDeleteHandler<any>)
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a handler to be called when a store completes a batch.
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* let count = 0
|
||||
*
|
||||
* editor.cleanup.registerBatchCompleteHandler(() => count++)
|
||||
*
|
||||
* editor.selectAll()
|
||||
* expect(count).toBe(1)
|
||||
*
|
||||
* editor.batch(() => {
|
||||
* editor.selectNone()
|
||||
* editor.selectAll()
|
||||
* })
|
||||
*
|
||||
* expect(count).toBe(2)
|
||||
* ```
|
||||
*
|
||||
* @param handler - The handler to call
|
||||
*
|
||||
* @public
|
||||
*/
|
||||
registerBatchCompleteHandler(handler: TLBatchCompleteHandler) {
|
||||
this._batchCompleteHandlers.push(handler)
|
||||
}
|
||||
}
|
|
@ -244,6 +244,8 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
// (undocumented)
|
||||
_flushHistory(): void;
|
||||
get: <K extends IdOf<R>>(id: K) => RecFromId<K> | undefined;
|
||||
// (undocumented)
|
||||
getRecordType: <T extends R>(record: R) => T;
|
||||
getSnapshot(scope?: 'all' | RecordScope): StoreSnapshot<R>;
|
||||
has: <K extends IdOf<R>>(id: K) => boolean;
|
||||
readonly history: Atom<number, RecordsDiff<R>>;
|
||||
|
@ -255,10 +257,12 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
// @internal (undocumented)
|
||||
markAsPossiblyCorrupted(): void;
|
||||
mergeRemoteChanges: (fn: () => void) => void;
|
||||
onAfterChange?: (prev: R, next: R) => void;
|
||||
onAfterCreate?: (record: R) => void;
|
||||
onAfterDelete?: (prev: R) => void;
|
||||
onBeforeDelete?: (prev: R) => void;
|
||||
onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void;
|
||||
onAfterCreate?: (record: R, source: 'remote' | 'user') => void;
|
||||
onAfterDelete?: (prev: R, source: 'remote' | 'user') => void;
|
||||
onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R;
|
||||
onBeforeCreate?: (next: R, source: 'remote' | 'user') => R;
|
||||
onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void;
|
||||
// (undocumented)
|
||||
readonly props: Props;
|
||||
put: (records: R[], phaseOverride?: 'initialize') => void;
|
||||
|
|
|
@ -297,13 +297,29 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
this.allRecords().forEach((record) => this.schema.validateRecord(this, record, phase, null))
|
||||
}
|
||||
|
||||
/**
|
||||
* A callback fired after each record's change.
|
||||
*
|
||||
* @param prev - The previous value, if any.
|
||||
* @param next - The next value.
|
||||
*/
|
||||
onBeforeCreate?: (next: R, source: 'remote' | 'user') => R
|
||||
|
||||
/**
|
||||
* A callback fired after a record is created. Use this to perform related updates to other
|
||||
* records in the store.
|
||||
*
|
||||
* @param record - The record to be created
|
||||
*/
|
||||
onAfterCreate?: (record: R) => void
|
||||
onAfterCreate?: (record: R, source: 'remote' | 'user') => void
|
||||
|
||||
/**
|
||||
* A callback before after each record's change.
|
||||
*
|
||||
* @param prev - The previous value, if any.
|
||||
* @param next - The next value.
|
||||
*/
|
||||
onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R
|
||||
|
||||
/**
|
||||
* A callback fired after each record's change.
|
||||
|
@ -311,21 +327,21 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
* @param prev - The previous value, if any.
|
||||
* @param next - The next value.
|
||||
*/
|
||||
onAfterChange?: (prev: R, next: R) => void
|
||||
onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void
|
||||
|
||||
/**
|
||||
* A callback fired before a record is deleted.
|
||||
*
|
||||
* @param prev - The record that will be deleted.
|
||||
*/
|
||||
onBeforeDelete?: (prev: R) => void
|
||||
onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void
|
||||
|
||||
/**
|
||||
* A callback fired after a record is deleted.
|
||||
*
|
||||
* @param prev - The record that will be deleted.
|
||||
*/
|
||||
onAfterDelete?: (prev: R) => void
|
||||
onAfterDelete?: (prev: R, source: 'remote' | 'user') => void
|
||||
|
||||
// used to avoid running callbacks when rolling back changes in sync client
|
||||
private _runCallbacks = true
|
||||
|
@ -353,12 +369,18 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
// changes (e.g. additions, deletions, or updates that produce a new value).
|
||||
let didChange = false
|
||||
|
||||
const beforeCreate = this.onBeforeCreate && this._runCallbacks ? this.onBeforeCreate : null
|
||||
const beforeUpdate = this.onBeforeChange && this._runCallbacks ? this.onBeforeChange : null
|
||||
const source = this.isMergingRemoteChanges ? 'remote' : 'user'
|
||||
|
||||
for (let i = 0, n = records.length; i < n; i++) {
|
||||
record = records[i]
|
||||
|
||||
const recordAtom = (map ?? currentMap)[record.id as IdOf<R>]
|
||||
|
||||
if (recordAtom) {
|
||||
if (beforeUpdate) record = beforeUpdate(recordAtom.value, record, source)
|
||||
|
||||
// If we already have an atom for this record, update its value.
|
||||
|
||||
const initialValue = recordAtom.__unsafe__getWithoutCapture()
|
||||
|
@ -382,6 +404,8 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
updates[record.id] = [initialValue, finalValue]
|
||||
}
|
||||
} else {
|
||||
if (beforeCreate) record = beforeCreate(record, source)
|
||||
|
||||
didChange = true
|
||||
|
||||
// If we don't have an atom, create one.
|
||||
|
@ -418,21 +442,23 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
removed: {} as Record<IdOf<R>, R>,
|
||||
})
|
||||
|
||||
if (this._runCallbacks) {
|
||||
const { onAfterCreate, onAfterChange } = this
|
||||
|
||||
if (onAfterCreate && this._runCallbacks) {
|
||||
if (onAfterCreate) {
|
||||
// Run the onAfterChange callback for addition.
|
||||
Object.values(additions).forEach((record) => {
|
||||
onAfterCreate(record)
|
||||
onAfterCreate(record, source)
|
||||
})
|
||||
}
|
||||
|
||||
if (onAfterChange && this._runCallbacks) {
|
||||
if (onAfterChange) {
|
||||
// Run the onAfterChange callback for update.
|
||||
Object.values(updates).forEach(([from, to]) => {
|
||||
onAfterChange(from, to)
|
||||
onAfterChange(from, to, source)
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -444,12 +470,17 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
*/
|
||||
remove = (ids: IdOf<R>[]): void => {
|
||||
transact(() => {
|
||||
const cancelled = [] as IdOf<R>[]
|
||||
const source = this.isMergingRemoteChanges ? 'remote' : 'user'
|
||||
|
||||
if (this.onBeforeDelete && this._runCallbacks) {
|
||||
for (const id of ids) {
|
||||
const atom = this.atoms.__unsafe__getWithoutCapture()[id]
|
||||
if (!atom) continue
|
||||
|
||||
this.onBeforeDelete(atom.value)
|
||||
if (this.onBeforeDelete(atom.value, source) === false) {
|
||||
cancelled.push(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -460,6 +491,7 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
let result: typeof atoms | undefined = undefined
|
||||
|
||||
for (const id of ids) {
|
||||
if (cancelled.includes(id)) continue
|
||||
if (!(id in atoms)) continue
|
||||
if (!result) result = { ...atoms }
|
||||
if (!removed) removed = {} as Record<IdOf<R>, R>
|
||||
|
@ -476,8 +508,12 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
|
||||
// If we have an onAfterChange, run it for each removed record.
|
||||
if (this.onAfterDelete && this._runCallbacks) {
|
||||
let record: R
|
||||
for (let i = 0, n = ids.length; i < n; i++) {
|
||||
this.onAfterDelete(removed[ids[i]])
|
||||
record = removed[ids[i]]
|
||||
if (record) {
|
||||
this.onAfterDelete(record, source)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -596,6 +632,7 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
console.error(`Record ${id} not found. This is probably an error`)
|
||||
return
|
||||
}
|
||||
|
||||
this.put([updater(atom.__unsafe__getWithoutCapture() as any as RecFromId<K>) as any])
|
||||
}
|
||||
|
||||
|
@ -752,6 +789,14 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
|
|||
}
|
||||
}
|
||||
|
||||
getRecordType = <T extends R>(record: R): T => {
|
||||
const type = this.schema.types[record.typeName as R['typeName']]
|
||||
if (!type) {
|
||||
throw new Error(`Record type ${record.typeName} not found`)
|
||||
}
|
||||
return type as unknown as T
|
||||
}
|
||||
|
||||
private _integrityChecker?: () => void | undefined
|
||||
|
||||
/** @internal */
|
||||
|
|
Loading…
Reference in a new issue