Store-level "operation end" event (#3748)

This adds a store-level "operation end" event which fires at the end of
atomic operations. It includes some other changes too:

- The `SideEffectManager` now lives in & is a property of the store as
`StoreSideEffects`. One benefit to this is that instead of overriding
methods on the store to register side effects (meaning the store can
only ever be used in one place) the store now calls directly into the
side effect manager, which is responsible for dealing with any other
callbacks
- The history manager's "batch complete" event is gone, in favour of
this new event. We were using the batch complete event for only one
thing, calling `onChildrenChange` - which meant it wasn't getting called
for undo/redo events, which aren't part of a batch. `onChildrenChange`
is now called after each atomic store operation affecting children.

I've also added a rough pin example which shows (kinda messily) how you
might use the operation complete handler to traverse a graph of bindings
and resolve constraints between them.

### Change Type

- [x] `sdk` — Changes the tldraw SDK
- [x] `feature` — New feature

### Release Notes

#### Breaking changes
`editor.registerBatchCompleteHandler` has been replaced with
`editor.registerOperationCompleteHandler`
This commit is contained in:
alex 2024-05-14 10:42:41 +01:00 committed by GitHub
parent 5a15c49d63
commit ab807afda3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 754 additions and 357 deletions

View file

@ -0,0 +1,335 @@
import {
BindingOnShapeChangeOptions,
BindingOnShapeDeleteOptions,
BindingUtil,
Box,
DefaultFillStyle,
DefaultToolbar,
DefaultToolbarContent,
RecordProps,
Rectangle2d,
ShapeUtil,
StateNode,
TLBaseBinding,
TLBaseShape,
TLEditorComponents,
TLEventHandlers,
TLOnTranslateEndHandler,
TLOnTranslateStartHandler,
TLShapeId,
TLUiComponents,
TLUiOverrides,
Tldraw,
TldrawUiMenuItem,
Vec,
VecModel,
createShapeId,
invLerp,
lerp,
useIsToolSelected,
useTools,
} from 'tldraw'
// eslint-disable-next-line @typescript-eslint/ban-types
type PinShape = TLBaseShape<'pin', {}>
const offsetX = -16
const offsetY = -26
class PinShapeUtil extends ShapeUtil<PinShape> {
static override type = 'pin' as const
static override props: RecordProps<PinShape> = {}
override getDefaultProps() {
return {}
}
override canBind = () => false
override canEdit = () => false
override canResize = () => false
override hideRotateHandle = () => true
override isAspectRatioLocked = () => true
override getGeometry() {
return new Rectangle2d({
width: 32,
height: 32,
x: offsetX,
y: offsetY,
isFilled: true,
})
}
override component() {
return (
<div
style={{
width: '100%',
height: '100%',
marginLeft: offsetX,
marginTop: offsetY,
fontSize: '26px',
textAlign: 'center',
}}
>
📍
</div>
)
}
override indicator() {
return <rect width={32} height={32} x={offsetX} y={offsetY} />
}
override onTranslateStart: TLOnTranslateStartHandler<PinShape> = (shape) => {
const bindings = this.editor.getBindingsFromShape(shape, 'pin')
this.editor.deleteBindings(bindings)
}
override onTranslateEnd: TLOnTranslateEndHandler<PinShape> = (initial, pin) => {
const pageAnchor = this.editor.getShapePageTransform(pin).applyToPoint({ x: 0, y: 0 })
const targets = this.editor
.getShapesAtPoint(pageAnchor, { hitInside: true })
.filter(
(shape) =>
shape.type !== 'pin' && shape.parentId === pin.parentId && shape.index < pin.index
)
for (const target of targets) {
const targetBounds = Box.ZeroFix(this.editor.getShapeGeometry(target)!.bounds)
const pointInTargetSpace = this.editor.getPointInShapeSpace(target, pageAnchor)
const anchor = {
x: invLerp(targetBounds.minX, targetBounds.maxX, pointInTargetSpace.x),
y: invLerp(targetBounds.minY, targetBounds.maxY, pointInTargetSpace.y),
}
this.editor.createBinding({
type: 'pin',
fromId: pin.id,
toId: target.id,
props: {
anchor,
},
})
}
}
}
type PinBinding = TLBaseBinding<
'pin',
{
anchor: VecModel
}
>
class PinBindingUtil extends BindingUtil<PinBinding> {
static override type = 'pin' as const
override getDefaultProps() {
return {
anchor: { x: 0.5, y: 0.5 },
}
}
private changedToShapes = new Set<TLShapeId>()
override onOperationComplete(): void {
if (this.changedToShapes.size === 0) return
const fixedShapes = this.changedToShapes
const toCheck = [...this.changedToShapes]
const initialPositions = new Map<TLShapeId, VecModel>()
const targetDeltas = new Map<TLShapeId, Map<TLShapeId, VecModel>>()
const addTargetDelta = (fromId: TLShapeId, toId: TLShapeId, delta: VecModel) => {
if (!targetDeltas.has(fromId)) targetDeltas.set(fromId, new Map())
targetDeltas.get(fromId)!.set(toId, delta)
if (!targetDeltas.has(toId)) targetDeltas.set(toId, new Map())
targetDeltas.get(toId)!.set(fromId, { x: -delta.x, y: -delta.y })
}
const allShapes = new Set<TLShapeId>()
while (toCheck.length) {
const shapeId = toCheck.pop()!
const shape = this.editor.getShape(shapeId)
if (!shape) continue
if (allShapes.has(shapeId)) continue
allShapes.add(shapeId)
const bindings = this.editor.getBindingsToShape<PinBinding>(shape, 'pin')
for (const binding of bindings) {
if (allShapes.has(binding.fromId)) continue
allShapes.add(binding.fromId)
const pin = this.editor.getShape<PinShape>(binding.fromId)
if (!pin) continue
const pinPosition = this.editor.getShapePageTransform(pin).applyToPoint({ x: 0, y: 0 })
initialPositions.set(pin.id, pinPosition)
for (const binding of this.editor.getBindingsFromShape<PinBinding>(pin.id, 'pin')) {
const shapeBounds = this.editor.getShapeGeometry(binding.toId)!.bounds
const shapeAnchor = {
x: lerp(shapeBounds.minX, shapeBounds.maxX, binding.props.anchor.x),
y: lerp(shapeBounds.minY, shapeBounds.maxY, binding.props.anchor.y),
}
const currentPageAnchor = this.editor
.getShapePageTransform(binding.toId)
.applyToPoint(shapeAnchor)
const shapeOrigin = this.editor
.getShapePageTransform(binding.toId)
.applyToPoint({ x: 0, y: 0 })
initialPositions.set(binding.toId, shapeOrigin)
addTargetDelta(pin.id, binding.toId, {
x: currentPageAnchor.x - shapeOrigin.x,
y: currentPageAnchor.y - shapeOrigin.y,
})
if (!allShapes.has(binding.toId)) toCheck.push(binding.toId)
}
}
}
const currentPositions = new Map(initialPositions)
const iterations = 30
for (let i = 0; i < iterations; i++) {
const movements = new Map<TLShapeId, VecModel[]>()
for (const [aId, deltas] of targetDeltas) {
if (fixedShapes.has(aId)) continue
const aPosition = currentPositions.get(aId)!
for (const [bId, targetDelta] of deltas) {
const bPosition = currentPositions.get(bId)!
const adjustmentDelta = {
x: targetDelta.x - (aPosition.x - bPosition.x),
y: targetDelta.y - (aPosition.y - bPosition.y),
}
if (!movements.has(aId)) movements.set(aId, [])
movements.get(aId)!.push(adjustmentDelta)
}
}
for (const [shapeId, deltas] of movements) {
const currentPosition = currentPositions.get(shapeId)!
currentPositions.set(shapeId, Vec.Average(deltas).add(currentPosition))
}
}
const updates = []
for (const [shapeId, position] of currentPositions) {
const delta = Vec.Sub(position, initialPositions.get(shapeId)!)
if (delta.len2() <= 0.01) continue
const newPosition = this.editor.getPointInParentSpace(shapeId, position)
updates.push({
id: shapeId,
type: this.editor.getShape(shapeId)!.type,
x: newPosition.x,
y: newPosition.y,
})
}
if (updates.length === 0) {
this.changedToShapes.clear()
} else {
this.editor.updateShapes(updates)
}
}
// when the shape we're stuck to changes, update the pin's position
override onAfterChangeToShape({ binding }: BindingOnShapeChangeOptions<PinBinding>): void {
this.changedToShapes.add(binding.toId)
}
// when the thing we're stuck to is deleted, delete the pin too
override onBeforeDeleteToShape({ binding }: BindingOnShapeDeleteOptions<PinBinding>): void {
const pin = this.editor.getShape<PinShape>(binding.fromId)
if (pin) this.editor.deleteShape(pin.id)
}
}
class PinTool extends StateNode {
static override id = 'pin'
override onEnter = () => {
this.editor.setCursor({ type: 'cross', rotation: 0 })
}
override onPointerDown: TLEventHandlers['onPointerDown'] = (info) => {
const { currentPagePoint } = this.editor.inputs
const pinId = createShapeId()
this.editor.mark(`creating:${pinId}`)
this.editor.createShape({
id: pinId,
type: 'pin',
x: currentPagePoint.x,
y: currentPagePoint.y,
})
this.editor.setSelectedShapes([pinId])
this.editor.setCurrentTool('select.translating', {
...info,
target: 'shape',
shape: this.editor.getShape(pinId),
isCreating: true,
onInteractionEnd: 'pin',
onCreate: () => {
this.editor.setCurrentTool('pin')
},
})
}
}
const overrides: TLUiOverrides = {
tools(editor, schema) {
schema['pin'] = {
id: 'pin',
label: 'Pin',
icon: 'heart-icon',
kbd: 'p',
onSelect: () => {
editor.setCurrentTool('pin')
},
}
return schema
},
}
const components: TLUiComponents & TLEditorComponents = {
Toolbar: (...props) => {
const pin = useTools().pin
const isPinSelected = useIsToolSelected(pin)
return (
<DefaultToolbar {...props}>
<TldrawUiMenuItem {...pin} isSelected={isPinSelected} />
<DefaultToolbarContent />
</DefaultToolbar>
)
},
}
export default function PinExample() {
return (
<div className="tldraw__editor">
<Tldraw
persistenceKey="pin-example"
onMount={(editor) => {
;(window as any).editor = editor
editor.setStyleForNextShapes(DefaultFillStyle, 'semi')
}}
shapeUtils={[PinShapeUtil]}
bindingUtils={[PinBindingUtil]}
tools={[PinTool]}
overrides={overrides}
components={components}
/>
</div>
)
}

View file

@ -0,0 +1,9 @@
---
title: Pin (bindings)
component: ./PinExample.tsx
category: shapes/tools
---
A pin, using bindings to pin together networks of shapes.
---

View file

@ -37,6 +37,7 @@ import { SerializedStore } from '@tldraw/store';
import { Signal } from '@tldraw/state';
import { Store } from '@tldraw/store';
import { StoreSchema } from '@tldraw/store';
import { StoreSideEffects } from '@tldraw/store';
import { StoreSnapshot } from '@tldraw/store';
import { StyleProp } from '@tldraw/tlschema';
import { StylePropValue } from '@tldraw/tlschema';
@ -239,6 +240,8 @@ export abstract class BindingUtil<Binding extends TLUnknownBinding = TLUnknownBi
// (undocumented)
onBeforeDeleteToShape?(options: BindingOnShapeDeleteOptions<Binding>): void;
// (undocumented)
onOperationComplete?(): void;
// (undocumented)
static props?: RecordProps<TLUnknownBinding>;
static type: string;
}
@ -1013,7 +1016,7 @@ export class Editor extends EventEmitter<TLEventMap> {
shapeUtils: {
readonly [K in string]?: ShapeUtil<TLUnknownShape>;
};
readonly sideEffects: SideEffectManager<this>;
readonly sideEffects: StoreSideEffects<TLRecord>;
slideCamera(opts?: {
direction: VecLike;
friction: number;
@ -1853,48 +1856,6 @@ export class SharedStyleMap extends ReadonlySharedStyleMap {
// @public
export function shortAngleDist(a0: number, a1: number): number;
// @public
export class SideEffectManager<CTX extends {
history: {
onBatchComplete: () => void;
};
store: TLStore;
}> {
constructor(editor: CTX);
// (undocumented)
editor: CTX;
// @internal
register(handlersByType: {
[R in TLRecord as R['typeName']]?: {
afterChange?: TLAfterChangeHandler<R>;
afterCreate?: TLAfterCreateHandler<R>;
afterDelete?: TLAfterDeleteHandler<R>;
beforeChange?: TLBeforeChangeHandler<R>;
beforeCreate?: TLBeforeCreateHandler<R>;
beforeDelete?: TLBeforeDeleteHandler<R>;
};
}): () => void;
registerAfterChangeHandler<T extends TLRecord['typeName']>(typeName: T, handler: TLAfterChangeHandler<TLRecord & {
typeName: T;
}>): () => void;
registerAfterCreateHandler<T extends TLRecord['typeName']>(typeName: T, handler: TLAfterCreateHandler<TLRecord & {
typeName: T;
}>): () => void;
registerAfterDeleteHandler<T extends TLRecord['typeName']>(typeName: T, handler: TLAfterDeleteHandler<TLRecord & {
typeName: T;
}>): () => void;
registerBatchCompleteHandler(handler: TLBatchCompleteHandler): () => void;
registerBeforeChangeHandler<T extends TLRecord['typeName']>(typeName: T, handler: TLBeforeChangeHandler<TLRecord & {
typeName: T;
}>): () => void;
registerBeforeCreateHandler<T extends TLRecord['typeName']>(typeName: T, handler: TLBeforeCreateHandler<TLRecord & {
typeName: T;
}>): () => void;
registerBeforeDeleteHandler<T extends TLRecord['typeName']>(typeName: T, handler: TLBeforeDeleteHandler<TLRecord & {
typeName: T;
}>): () => void;
}
// @public (undocumented)
export const SIDES: readonly ["top", "right", "bottom", "left"];
@ -2058,15 +2019,6 @@ export interface SvgExportDef {
// @public
export const TAB_ID: string;
// @public (undocumented)
export type TLAfterChangeHandler<R extends TLRecord> = (prev: R, next: R, source: 'remote' | 'user') => void;
// @public (undocumented)
export type TLAfterCreateHandler<R extends TLRecord> = (record: R, source: 'remote' | 'user') => void;
// @public (undocumented)
export type TLAfterDeleteHandler<R extends TLRecord> = (record: R, source: 'remote' | 'user') => void;
// @public (undocumented)
export type TLAnyBindingUtilConstructor = TLBindingUtilConstructor<any>;
@ -2091,18 +2043,6 @@ export interface TLBaseEventInfo {
type: UiEventType;
}
// @public (undocumented)
export type TLBatchCompleteHandler = () => void;
// @public (undocumented)
export type TLBeforeChangeHandler<R extends TLRecord> = (prev: R, next: R, source: 'remote' | 'user') => R;
// @public (undocumented)
export type TLBeforeCreateHandler<R extends TLRecord> = (record: R, source: 'remote' | 'user') => R;
// @public (undocumented)
export type TLBeforeDeleteHandler<R extends TLRecord> = (record: R, source: 'remote' | 'user') => false | void;
// @public (undocumented)
export interface TLBindingUtilConstructor<T extends TLUnknownBinding, U extends BindingUtil<T> = BindingUtil<T>> {
// (undocumented)

View file

@ -133,16 +133,6 @@ export {
type TLBindingUtilConstructor,
} from './lib/editor/bindings/BindingUtil'
export { HistoryManager } from './lib/editor/managers/HistoryManager'
export type {
SideEffectManager,
TLAfterChangeHandler,
TLAfterCreateHandler,
TLAfterDeleteHandler,
TLBatchCompleteHandler,
TLBeforeChangeHandler,
TLBeforeCreateHandler,
TLBeforeDeleteHandler,
} from './lib/editor/managers/SideEffectManager'
export {
type BoundsSnapGeometry,
type BoundsSnapPoint,

View file

@ -2,6 +2,7 @@ import { EMPTY_ARRAY, atom, computed, transact } from '@tldraw/state'
import {
ComputedCache,
RecordType,
StoreSideEffects,
StoreSnapshot,
UnknownRecord,
reverseRecordsDiff,
@ -127,7 +128,6 @@ import { ClickManager } from './managers/ClickManager'
import { EnvironmentManager } from './managers/EnvironmentManager'
import { HistoryManager } from './managers/HistoryManager'
import { ScribbleManager } from './managers/ScribbleManager'
import { SideEffectManager } from './managers/SideEffectManager'
import { SnapManager } from './managers/SnapManager/SnapManager'
import { TextManager } from './managers/TextManager'
import { TickManager } from './managers/TickManager'
@ -300,8 +300,6 @@ export class Editor extends EventEmitter<TLEventMap> {
// Cleanup
const invalidParents = new Set<TLShapeId>()
const cleanupInstancePageState = (
prevPageState: TLInstancePageState,
shapesNoLongerInPage: Set<TLShapeId>
@ -349,10 +347,12 @@ export class Editor extends EventEmitter<TLEventMap> {
return nextPageState
}
this.sideEffects = new SideEffectManager(this)
this.sideEffects = this.store.sideEffects
const invalidParents = new Set<TLShapeId>()
let invalidBindingTypes = new Set<string>()
this.disposables.add(
this.sideEffects.registerBatchCompleteHandler(() => {
this.sideEffects.registerOperationCompleteHandler(() => {
for (const parentId of invalidParents) {
invalidParents.delete(parentId)
const parent = this.getShape(parentId)
@ -366,6 +366,15 @@ export class Editor extends EventEmitter<TLEventMap> {
}
}
if (invalidBindingTypes.size) {
const t = invalidBindingTypes
invalidBindingTypes = new Set()
for (const type of t) {
const util = this.getBindingUtil(type)
util.onOperationComplete?.()
}
}
this.emit('update')
})
)
@ -375,6 +384,7 @@ export class Editor extends EventEmitter<TLEventMap> {
shape: {
afterChange: (shapeBefore, shapeAfter) => {
for (const binding of this.getBindingsInvolvingShape(shapeAfter)) {
invalidBindingTypes.add(binding.type)
if (binding.fromId === shapeAfter.id) {
this.getBindingUtil(binding).onAfterChangeFromShape?.({
binding,
@ -398,6 +408,8 @@ export class Editor extends EventEmitter<TLEventMap> {
if (!descendantShape) return
for (const binding of this.getBindingsInvolvingShape(descendantShape)) {
invalidBindingTypes.add(binding.type)
if (binding.fromId === descendantShape.id) {
this.getBindingUtil(binding).onAfterChangeFromShape?.({
binding,
@ -451,6 +463,7 @@ export class Editor extends EventEmitter<TLEventMap> {
const deleteBindingIds: TLBindingId[] = []
for (const binding of this.getBindingsInvolvingShape(shape)) {
invalidBindingTypes.add(binding.type)
if (binding.fromId === shape.id) {
this.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape })
deleteBindingIds.push(binding.id)
@ -481,6 +494,7 @@ export class Editor extends EventEmitter<TLEventMap> {
return binding
},
afterCreate: (binding) => {
invalidBindingTypes.add(binding.type)
this.getBindingUtil(binding).onAfterCreate?.({ binding })
},
beforeChange: (bindingBefore, bindingAfter) => {
@ -492,12 +506,14 @@ export class Editor extends EventEmitter<TLEventMap> {
return bindingAfter
},
afterChange: (bindingBefore, bindingAfter) => {
invalidBindingTypes.add(bindingAfter.type)
this.getBindingUtil(bindingAfter).onAfterChange?.({ bindingBefore, bindingAfter })
},
beforeDelete: (binding) => {
this.getBindingUtil(binding).onBeforeDelete?.({ binding })
},
afterDelete: (binding) => {
invalidBindingTypes.add(binding.type)
this.getBindingUtil(binding).onAfterDelete?.({ binding })
},
},
@ -705,11 +721,11 @@ export class Editor extends EventEmitter<TLEventMap> {
readonly scribbles: ScribbleManager
/**
* A manager for side effects and correct state enforcement. See {@link SideEffectManager} for details.
* A manager for side effects and correct state enforcement. See {@link @tldraw/store#StoreSideEffects} for details.
*
* @public
*/
readonly sideEffects: SideEffectManager<this>
readonly sideEffects: StoreSideEffects<TLRecord>
/**
* The current HTML element containing the editor.

View file

@ -61,6 +61,8 @@ export abstract class BindingUtil<Binding extends TLUnknownBinding = TLUnknownBi
*/
abstract getDefaultProps(): Partial<Binding['props']>
onOperationComplete?(): void
// self lifecycle hooks
onBeforeCreate?(options: BindingOnCreateOptions<Binding>): Binding | void
onAfterCreate?(options: BindingOnCreateOptions<Binding>): void

View file

@ -1,6 +1,5 @@
{
"extends": "../../config/tsconfig.base.json",
"include": ["src"],
"exclude": ["node_modules", "dist", "**/*.css", ".tsbuild*"],
"compilerOptions": {
"outDir": "./.tsbuild",

View file

@ -319,12 +319,6 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
markAsPossiblyCorrupted(): void;
mergeRemoteChanges: (fn: () => void) => void;
migrateSnapshot(snapshot: StoreSnapshot<R>): StoreSnapshot<R>;
onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void;
onAfterCreate?: (record: R, source: 'remote' | 'user') => void;
onAfterDelete?: (prev: R, source: 'remote' | 'user') => void;
onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R;
onBeforeCreate?: (next: R, source: 'remote' | 'user') => R;
onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void;
// (undocumented)
readonly props: Props;
put: (records: R[], phaseOverride?: 'initialize') => void;
@ -337,12 +331,32 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
readonly [K in RecordScope]: ReadonlySet<R['typeName']>;
};
serialize: (scope?: 'all' | RecordScope) => SerializedStore<R>;
// (undocumented)
readonly sideEffects: StoreSideEffects<R>;
unsafeGetWithoutCapture: <K extends IdOf<R>>(id: K) => RecFromId<K> | undefined;
update: <K extends IdOf<R>>(id: K, updater: (record: RecFromId<K>) => RecFromId<K>) => void;
// (undocumented)
validate(phase: 'createRecord' | 'initialize' | 'tests' | 'updateRecord'): void;
}
// @public (undocumented)
export type StoreAfterChangeHandler<R extends UnknownRecord> = (prev: R, next: R, source: 'remote' | 'user') => void;
// @public (undocumented)
export type StoreAfterCreateHandler<R extends UnknownRecord> = (record: R, source: 'remote' | 'user') => void;
// @public (undocumented)
export type StoreAfterDeleteHandler<R extends UnknownRecord> = (record: R, source: 'remote' | 'user') => void;
// @public (undocumented)
export type StoreBeforeChangeHandler<R extends UnknownRecord> = (prev: R, next: R, source: 'remote' | 'user') => R;
// @public (undocumented)
export type StoreBeforeCreateHandler<R extends UnknownRecord> = (record: R, source: 'remote' | 'user') => R;
// @public (undocumented)
export type StoreBeforeDeleteHandler<R extends UnknownRecord> = (record: R, source: 'remote' | 'user') => false | void;
// @public (undocumented)
export type StoreError = {
error: Error;
@ -355,6 +369,9 @@ export type StoreError = {
// @public
export type StoreListener<R extends UnknownRecord> = (entry: HistoryEntry<R>) => void;
// @public (undocumented)
export type StoreOperationCompleteHandler = (source: 'remote' | 'user') => void;
// @public (undocumented)
export class StoreSchema<R extends UnknownRecord, P = unknown> {
// (undocumented)
@ -402,6 +419,59 @@ export type StoreSchemaOptions<R extends UnknownRecord, P> = {
migrations?: MigrationSequence[];
};
// @public
export class StoreSideEffects<R extends UnknownRecord> {
constructor(store: Store<R>);
// @internal (undocumented)
handleAfterChange(prev: R, next: R, source: 'remote' | 'user'): void;
// @internal (undocumented)
handleAfterCreate(record: R, source: 'remote' | 'user'): void;
// @internal (undocumented)
handleAfterDelete(record: R, source: 'remote' | 'user'): void;
// @internal (undocumented)
handleBeforeChange(prev: R, next: R, source: 'remote' | 'user'): R;
// @internal (undocumented)
handleBeforeCreate(record: R, source: 'remote' | 'user'): R;
// @internal (undocumented)
handleBeforeDelete(record: R, source: 'remote' | 'user'): boolean;
// @internal (undocumented)
handleOperationComplete(source: 'remote' | 'user'): void;
// @internal (undocumented)
isEnabled(): boolean;
// @internal
register(handlersByType: {
[T in R as T['typeName']]?: {
afterChange?: StoreAfterChangeHandler<T>;
afterCreate?: StoreAfterCreateHandler<T>;
afterDelete?: StoreAfterDeleteHandler<T>;
beforeChange?: StoreBeforeChangeHandler<T>;
beforeCreate?: StoreBeforeCreateHandler<T>;
beforeDelete?: StoreBeforeDeleteHandler<T>;
};
}): () => void;
registerAfterChangeHandler<T extends R['typeName']>(typeName: T, handler: StoreAfterChangeHandler<R & {
typeName: T;
}>): () => void;
registerAfterCreateHandler<T extends R['typeName']>(typeName: T, handler: StoreAfterCreateHandler<R & {
typeName: T;
}>): () => void;
registerAfterDeleteHandler<T extends R['typeName']>(typeName: T, handler: StoreAfterDeleteHandler<R & {
typeName: T;
}>): () => void;
registerBeforeChangeHandler<T extends R['typeName']>(typeName: T, handler: StoreBeforeChangeHandler<R & {
typeName: T;
}>): () => void;
registerBeforeCreateHandler<T extends R['typeName']>(typeName: T, handler: StoreBeforeCreateHandler<R & {
typeName: T;
}>): () => void;
registerBeforeDeleteHandler<T extends R['typeName']>(typeName: T, handler: StoreBeforeDeleteHandler<R & {
typeName: T;
}>): () => void;
registerOperationCompleteHandler(handler: StoreOperationCompleteHandler): () => void;
// @internal (undocumented)
setIsEnabled(enabled: boolean): void;
}
// @public (undocumented)
export type StoreSnapshot<R extends UnknownRecord> = {
schema: SerializedSchema;

View file

@ -28,6 +28,16 @@ export type {
SerializedSchemaV2,
StoreSchemaOptions,
} from './lib/StoreSchema'
export {
StoreSideEffects,
type StoreAfterChangeHandler,
type StoreAfterCreateHandler,
type StoreAfterDeleteHandler,
type StoreBeforeChangeHandler,
type StoreBeforeCreateHandler,
type StoreBeforeDeleteHandler,
type StoreOperationCompleteHandler,
} from './lib/StoreSideEffects'
export { devFreeze } from './lib/devFreeze'
export {
MigrationFailureReason,

View file

@ -16,6 +16,7 @@ import { RecordScope } from './RecordType'
import { RecordsDiff, squashRecordDiffs } from './RecordsDiff'
import { StoreQueries } from './StoreQueries'
import { SerializedSchema, StoreSchema } from './StoreSchema'
import { StoreSideEffects } from './StoreSideEffects'
import { devFreeze } from './devFreeze'
type RecFromId<K extends RecordId<UnknownRecord>> = K extends RecordId<infer R> ? R : never
@ -160,6 +161,8 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
public readonly scopedTypes: { readonly [K in RecordScope]: ReadonlySet<R['typeName']> }
public readonly sideEffects = new StoreSideEffects<R>(this)
constructor(config: {
id?: string
/** The store's initial data. */
@ -295,55 +298,6 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
this.allRecords().forEach((record) => this.schema.validateRecord(this, record, phase, null))
}
/**
* A callback fired after each record's change.
*
* @param prev - The previous value, if any.
* @param next - The next value.
*/
onBeforeCreate?: (next: R, source: 'remote' | 'user') => R
/**
* A callback fired after a record is created. Use this to perform related updates to other
* records in the store.
*
* @param record - The record to be created
*/
onAfterCreate?: (record: R, source: 'remote' | 'user') => void
/**
* A callback fired before each record's change.
*
* @param prev - The previous value, if any.
* @param next - The next value.
*/
onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R
/**
* A callback fired after each record's change.
*
* @param prev - The previous value, if any.
* @param next - The next value.
*/
onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void
/**
* A callback fired before a record is deleted.
*
* @param prev - The record that will be deleted.
*/
onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void
/**
* A callback fired after a record is deleted.
*
* @param prev - The record that will be deleted.
*/
onAfterDelete?: (prev: R, source: 'remote' | 'user') => void
// used to avoid running callbacks when rolling back changes in sync client
private _runCallbacks = true
/**
* Add some records to the store. It's an error if they already exist.
*
@ -367,8 +321,6 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
// changes (e.g. additions, deletions, or updates that produce a new value).
let didChange = false
const beforeCreate = this.onBeforeCreate && this._runCallbacks ? this.onBeforeCreate : null
const beforeUpdate = this.onBeforeChange && this._runCallbacks ? this.onBeforeChange : null
const source = this.isMergingRemoteChanges ? 'remote' : 'user'
for (let i = 0, n = records.length; i < n; i++) {
@ -381,7 +333,7 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
const initialValue = recordAtom.__unsafe__getWithoutCapture()
// If we have a beforeUpdate callback, run it against the initial and next records
if (beforeUpdate) record = beforeUpdate(initialValue, record, source)
record = this.sideEffects.handleBeforeChange(initialValue, record, source)
// Validate the record
const validated = this.schema.validateRecord(
@ -398,9 +350,9 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
didChange = true
const updated = recordAtom.__unsafe__getWithoutCapture()
updates[record.id] = [initialValue, updated]
this.addDiffForAfterEvent(initialValue, updated, source)
this.addDiffForAfterEvent(initialValue, updated)
} else {
if (beforeCreate) record = beforeCreate(record, source)
record = this.sideEffects.handleBeforeCreate(record, source)
didChange = true
@ -416,7 +368,7 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
// Mark the change as a new addition.
additions[record.id] = record
this.addDiffForAfterEvent(null, record, source)
this.addDiffForAfterEvent(null, record)
// Assign the atom to the map under the record's id.
if (!map) {
@ -449,16 +401,16 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
*/
remove = (ids: IdOf<R>[]): void => {
this.atomic(() => {
const cancelled = [] as IdOf<R>[]
const cancelled = new Set<IdOf<R>>()
const source = this.isMergingRemoteChanges ? 'remote' : 'user'
if (this.onBeforeDelete && this._runCallbacks) {
if (this.sideEffects.isEnabled()) {
for (const id of ids) {
const atom = this.atoms.__unsafe__getWithoutCapture()[id]
if (!atom) continue
if (this.onBeforeDelete(atom.get(), source) === false) {
cancelled.push(id)
if (this.sideEffects.handleBeforeDelete(atom.get(), source) === false) {
cancelled.add(id)
}
}
}
@ -470,14 +422,14 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
let result: typeof atoms | undefined = undefined
for (const id of ids) {
if (cancelled.includes(id)) continue
if (cancelled.has(id)) continue
if (!(id in atoms)) continue
if (!result) result = { ...atoms }
if (!removed) removed = {} as Record<IdOf<R>, R>
delete result[id]
const record = atoms[id].get()
removed[id] = record
this.addDiffForAfterEvent(record, null, source)
this.addDiffForAfterEvent(record, null)
}
return result ?? atoms
@ -587,16 +539,16 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
throw new Error(`Failed to migrate snapshot: ${migrationResult.reason}`)
}
const prevRunCallbacks = this._runCallbacks
const prevSideEffectsEnabled = this.sideEffects.isEnabled()
try {
this._runCallbacks = false
this.sideEffects.setIsEnabled(false)
this.atomic(() => {
this.clear()
this.put(Object.values(migrationResult.value))
this.ensureStoreIsUsable()
})
} finally {
this._runCallbacks = prevRunCallbacks
this.sideEffects.setIsEnabled(prevSideEffectsEnabled)
}
}
@ -693,6 +645,10 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
return fn()
}
if (this._isInAtomicOp) {
throw new Error('Cannot merge remote changes while in atomic operation')
}
try {
this.isMergingRemoteChanges = true
transact(fn)
@ -844,11 +800,8 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
return this._isPossiblyCorrupted
}
private pendingAfterEvents: Map<
IdOf<R>,
{ before: R | null; after: R | null; source: 'remote' | 'user' }
> | null = null
private addDiffForAfterEvent(before: R | null, after: R | null, source: 'remote' | 'user') {
private pendingAfterEvents: Map<IdOf<R>, { before: R | null; after: R | null }> | null = null
private addDiffForAfterEvent(before: R | null, after: R | null) {
assert(this.pendingAfterEvents, 'must be in event operation')
if (before === after) return
if (before && after) assert(before.id === after.id)
@ -856,34 +809,38 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
const id = (before || after)!.id
const existing = this.pendingAfterEvents.get(id)
if (existing) {
assert(existing.source === source, 'source cannot change within a single event operation')
existing.after = after
} else {
this.pendingAfterEvents.set(id, { before, after, source })
this.pendingAfterEvents.set(id, { before, after })
}
}
private flushAtomicCallbacks() {
let updateDepth = 0
const source = this.isMergingRemoteChanges ? 'remote' : 'user'
while (this.pendingAfterEvents) {
const events = this.pendingAfterEvents
this.pendingAfterEvents = null
if (!this._runCallbacks) continue
if (!this.sideEffects.isEnabled()) continue
updateDepth++
if (updateDepth > 100) {
throw new Error('Maximum store update depth exceeded, bailing out')
}
for (const { before, after, source } of events.values()) {
for (const { before, after } of events.values()) {
if (before && after) {
this.onAfterChange?.(before, after, source)
this.sideEffects.handleAfterChange(before, after, source)
} else if (before && !after) {
this.onAfterDelete?.(before, source)
this.sideEffects.handleAfterDelete(before, source)
} else if (!before && after) {
this.onAfterCreate?.(after, source)
this.sideEffects.handleAfterCreate(after, source)
}
}
if (!this.pendingAfterEvents) {
this.sideEffects.handleOperationComplete(source)
}
}
}
private _isInAtomicOp = false
@ -896,8 +853,8 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
}
this.pendingAfterEvents = new Map()
const prevRunCallbacks = this._runCallbacks
this._runCallbacks = runCallbacks ?? prevRunCallbacks
const prevSideEffectsEnabled = this.sideEffects.isEnabled()
this.sideEffects.setIsEnabled(runCallbacks ?? prevSideEffectsEnabled)
this._isInAtomicOp = true
try {
const result = fn()
@ -907,7 +864,7 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
return result
} finally {
this.pendingAfterEvents = null
this._runCallbacks = prevRunCallbacks
this.sideEffects.setIsEnabled(prevSideEffectsEnabled)
this._isInAtomicOp = false
}
})

View file

@ -1,36 +1,41 @@
import { TLRecord, TLStore } from '@tldraw/tlschema'
import { UnknownRecord } from './BaseRecord'
import { Store } from './Store'
/** @public */
export type TLBeforeCreateHandler<R extends TLRecord> = (record: R, source: 'remote' | 'user') => R
export type StoreBeforeCreateHandler<R extends UnknownRecord> = (
record: R,
source: 'remote' | 'user'
) => R
/** @public */
export type TLAfterCreateHandler<R extends TLRecord> = (
export type StoreAfterCreateHandler<R extends UnknownRecord> = (
record: R,
source: 'remote' | 'user'
) => void
/** @public */
export type TLBeforeChangeHandler<R extends TLRecord> = (
export type StoreBeforeChangeHandler<R extends UnknownRecord> = (
prev: R,
next: R,
source: 'remote' | 'user'
) => R
/** @public */
export type TLAfterChangeHandler<R extends TLRecord> = (
export type StoreAfterChangeHandler<R extends UnknownRecord> = (
prev: R,
next: R,
source: 'remote' | 'user'
) => void
/** @public */
export type TLBeforeDeleteHandler<R extends TLRecord> = (
export type StoreBeforeDeleteHandler<R extends UnknownRecord> = (
record: R,
source: 'remote' | 'user'
) => void | false
/** @public */
export type TLAfterDeleteHandler<R extends TLRecord> = (
export type StoreAfterDeleteHandler<R extends UnknownRecord> = (
record: R,
source: 'remote' | 'user'
) => void
/** @public */
export type TLBatchCompleteHandler = () => void
export type StoreOperationCompleteHandler = (source: 'remote' | 'user') => void
/**
* The side effect manager (aka a "correct state enforcer") is responsible
@ -40,127 +45,131 @@ export type TLBatchCompleteHandler = () => void
*
* @public
*/
export class SideEffectManager<
CTX extends {
store: TLStore
history: { onBatchComplete: () => void }
},
> {
constructor(public editor: CTX) {
editor.store.onBeforeCreate = (record, source) => {
const handlers = this._beforeCreateHandlers[
record.typeName
] as TLBeforeCreateHandler<TLRecord>[]
if (handlers) {
let r = record
for (const handler of handlers) {
r = handler(r, source)
}
return r
}
export class StoreSideEffects<R extends UnknownRecord> {
constructor(private readonly store: Store<R>) {}
return record
private _beforeCreateHandlers: { [K in string]?: StoreBeforeCreateHandler<any>[] } = {}
private _afterCreateHandlers: { [K in string]?: StoreAfterCreateHandler<any>[] } = {}
private _beforeChangeHandlers: { [K in string]?: StoreBeforeChangeHandler<any>[] } = {}
private _afterChangeHandlers: { [K in string]?: StoreAfterChangeHandler<any>[] } = {}
private _beforeDeleteHandlers: { [K in string]?: StoreBeforeDeleteHandler<any>[] } = {}
private _afterDeleteHandlers: { [K in string]?: StoreAfterDeleteHandler<any>[] } = {}
private _operationCompleteHandlers: StoreOperationCompleteHandler[] = []
private _isEnabled = true
/** @internal */
isEnabled() {
return this._isEnabled
}
/** @internal */
setIsEnabled(enabled: boolean) {
this._isEnabled = enabled
}
/** @internal */
handleBeforeCreate(record: R, source: 'remote' | 'user') {
if (!this._isEnabled) return record
const handlers = this._beforeCreateHandlers[record.typeName] as StoreBeforeCreateHandler<R>[]
if (handlers) {
let r = record
for (const handler of handlers) {
r = handler(r, source)
}
return r
}
editor.store.onAfterCreate = (record, source) => {
const handlers = this._afterCreateHandlers[
record.typeName
] as TLAfterCreateHandler<TLRecord>[]
if (handlers) {
for (const handler of handlers) {
handler(record, source)
}
return record
}
/** @internal */
handleAfterCreate(record: R, source: 'remote' | 'user') {
if (!this._isEnabled) return
const handlers = this._afterCreateHandlers[record.typeName] as StoreAfterCreateHandler<R>[]
if (handlers) {
for (const handler of handlers) {
handler(record, source)
}
}
editor.store.onBeforeChange = (prev, next, source) => {
const handlers = this._beforeChangeHandlers[
next.typeName
] as TLBeforeChangeHandler<TLRecord>[]
if (handlers) {
let r = next
for (const handler of handlers) {
r = handler(prev, r, source)
}
return r
}
return next
}
editor.store.onAfterChange = (prev, next, source) => {
const handlers = this._afterChangeHandlers[next.typeName] as TLAfterChangeHandler<TLRecord>[]
if (handlers) {
for (const handler of handlers) {
handler(prev, next, source)
}
}
}
editor.store.onBeforeDelete = (record, source) => {
const handlers = this._beforeDeleteHandlers[
record.typeName
] as TLBeforeDeleteHandler<TLRecord>[]
if (handlers) {
for (const handler of handlers) {
if (handler(record, source) === false) {
return false
}
}
}
}
editor.store.onAfterDelete = (record, source) => {
const handlers = this._afterDeleteHandlers[
record.typeName
] as TLAfterDeleteHandler<TLRecord>[]
if (handlers) {
for (const handler of handlers) {
handler(record, source)
}
}
}
editor.history.onBatchComplete = () => {
this._batchCompleteHandlers.forEach((fn) => fn())
}
}
private _beforeCreateHandlers: Partial<{
[K in TLRecord['typeName']]: TLBeforeCreateHandler<TLRecord & { typeName: K }>[]
}> = {}
private _afterCreateHandlers: Partial<{
[K in TLRecord['typeName']]: TLAfterCreateHandler<TLRecord & { typeName: K }>[]
}> = {}
private _beforeChangeHandlers: Partial<{
[K in TLRecord['typeName']]: TLBeforeChangeHandler<TLRecord & { typeName: K }>[]
}> = {}
private _afterChangeHandlers: Partial<{
[K in TLRecord['typeName']]: TLAfterChangeHandler<TLRecord & { typeName: K }>[]
}> = {}
/** @internal */
handleBeforeChange(prev: R, next: R, source: 'remote' | 'user') {
if (!this._isEnabled) return next
private _beforeDeleteHandlers: Partial<{
[K in TLRecord['typeName']]: TLBeforeDeleteHandler<TLRecord & { typeName: K }>[]
}> = {}
const handlers = this._beforeChangeHandlers[next.typeName] as StoreBeforeChangeHandler<R>[]
if (handlers) {
let r = next
for (const handler of handlers) {
r = handler(prev, r, source)
}
return r
}
private _afterDeleteHandlers: Partial<{
[K in TLRecord['typeName']]: TLAfterDeleteHandler<TLRecord & { typeName: K }>[]
}> = {}
return next
}
private _batchCompleteHandlers: TLBatchCompleteHandler[] = []
/** @internal */
handleAfterChange(prev: R, next: R, source: 'remote' | 'user') {
if (!this._isEnabled) return
const handlers = this._afterChangeHandlers[next.typeName] as StoreAfterChangeHandler<R>[]
if (handlers) {
for (const handler of handlers) {
handler(prev, next, source)
}
}
}
/** @internal */
handleBeforeDelete(record: R, source: 'remote' | 'user') {
if (!this._isEnabled) return true
const handlers = this._beforeDeleteHandlers[record.typeName] as StoreBeforeDeleteHandler<R>[]
if (handlers) {
for (const handler of handlers) {
if (handler(record, source) === false) {
return false
}
}
}
return true
}
/** @internal */
handleAfterDelete(record: R, source: 'remote' | 'user') {
if (!this._isEnabled) return
const handlers = this._afterDeleteHandlers[record.typeName] as StoreAfterDeleteHandler<R>[]
if (handlers) {
for (const handler of handlers) {
handler(record, source)
}
}
}
/** @internal */
handleOperationComplete(source: 'remote' | 'user') {
if (!this._isEnabled) return
for (const handler of this._operationCompleteHandlers) {
handler(source)
}
}
/**
* Internal helper for registering a bunch of side effects at once and keeping them organized.
* @internal
*/
register(handlersByType: {
[R in TLRecord as R['typeName']]?: {
beforeCreate?: TLBeforeCreateHandler<R>
afterCreate?: TLAfterCreateHandler<R>
beforeChange?: TLBeforeChangeHandler<R>
afterChange?: TLAfterChangeHandler<R>
beforeDelete?: TLBeforeDeleteHandler<R>
afterDelete?: TLAfterDeleteHandler<R>
[T in R as T['typeName']]?: {
beforeCreate?: StoreBeforeCreateHandler<T>
afterCreate?: StoreAfterCreateHandler<T>
beforeChange?: StoreBeforeChangeHandler<T>
afterChange?: StoreAfterChangeHandler<T>
beforeDelete?: StoreBeforeDeleteHandler<T>
afterDelete?: StoreAfterDeleteHandler<T>
}
}) {
const disposes: (() => void)[] = []
@ -195,7 +204,7 @@ export class SideEffectManager<
*
* Use this handle only to modify the creation of the record itself. If you want to trigger a
* side-effect on a different record (for example, moving one shape when another is created),
* use {@link SideEffectManager.registerAfterCreateHandler} instead.
* use {@link StoreSideEffects.registerAfterCreateHandler} instead.
*
* @example
* ```ts
@ -216,11 +225,11 @@ export class SideEffectManager<
* @param typeName - The type of record to listen for
* @param handler - The handler to call
*/
registerBeforeCreateHandler<T extends TLRecord['typeName']>(
registerBeforeCreateHandler<T extends R['typeName']>(
typeName: T,
handler: TLBeforeCreateHandler<TLRecord & { typeName: T }>
handler: StoreBeforeCreateHandler<R & { typeName: T }>
) {
const handlers = this._beforeCreateHandlers[typeName] as TLBeforeCreateHandler<any>[]
const handlers = this._beforeCreateHandlers[typeName] as StoreBeforeCreateHandler<any>[]
if (!handlers) this._beforeCreateHandlers[typeName] = []
this._beforeCreateHandlers[typeName]!.push(handler)
return () => remove(this._beforeCreateHandlers[typeName]!, handler)
@ -229,7 +238,7 @@ export class SideEffectManager<
/**
* Register a handler to be called after a record is created. This is useful for side-effects
* that would update _other_ records. If you want to modify the record being created use
* {@link SideEffectManager.registerBeforeCreateHandler} instead.
* {@link StoreSideEffects.registerBeforeCreateHandler} instead.
*
* @example
* ```ts
@ -246,11 +255,11 @@ export class SideEffectManager<
* @param typeName - The type of record to listen for
* @param handler - The handler to call
*/
registerAfterCreateHandler<T extends TLRecord['typeName']>(
registerAfterCreateHandler<T extends R['typeName']>(
typeName: T,
handler: TLAfterCreateHandler<TLRecord & { typeName: T }>
handler: StoreAfterCreateHandler<R & { typeName: T }>
) {
const handlers = this._afterCreateHandlers[typeName] as TLAfterCreateHandler<any>[]
const handlers = this._afterCreateHandlers[typeName] as StoreAfterCreateHandler<any>[]
if (!handlers) this._afterCreateHandlers[typeName] = []
this._afterCreateHandlers[typeName]!.push(handler)
return () => remove(this._afterCreateHandlers[typeName]!, handler)
@ -263,7 +272,7 @@ export class SideEffectManager<
*
* Use this handler only for intercepting updates to the record itself. If you want to update
* other records in response to a change, use
* {@link SideEffectManager.registerAfterChangeHandler} instead.
* {@link StoreSideEffects.registerAfterChangeHandler} instead.
*
* @example
* ```ts
@ -280,11 +289,11 @@ export class SideEffectManager<
* @param typeName - The type of record to listen for
* @param handler - The handler to call
*/
registerBeforeChangeHandler<T extends TLRecord['typeName']>(
registerBeforeChangeHandler<T extends R['typeName']>(
typeName: T,
handler: TLBeforeChangeHandler<TLRecord & { typeName: T }>
handler: StoreBeforeChangeHandler<R & { typeName: T }>
) {
const handlers = this._beforeChangeHandlers[typeName] as TLBeforeChangeHandler<any>[]
const handlers = this._beforeChangeHandlers[typeName] as StoreBeforeChangeHandler<any>[]
if (!handlers) this._beforeChangeHandlers[typeName] = []
this._beforeChangeHandlers[typeName]!.push(handler)
return () => remove(this._beforeChangeHandlers[typeName]!, handler)
@ -293,7 +302,7 @@ export class SideEffectManager<
/**
* Register a handler to be called after a record is changed. This is useful for side-effects
* that would update _other_ records - if you want to modify the record being changed, use
* {@link SideEffectManager.registerBeforeChangeHandler} instead.
* {@link StoreSideEffects.registerBeforeChangeHandler} instead.
*
* @example
* ```ts
@ -309,13 +318,13 @@ export class SideEffectManager<
* @param typeName - The type of record to listen for
* @param handler - The handler to call
*/
registerAfterChangeHandler<T extends TLRecord['typeName']>(
registerAfterChangeHandler<T extends R['typeName']>(
typeName: T,
handler: TLAfterChangeHandler<TLRecord & { typeName: T }>
handler: StoreAfterChangeHandler<R & { typeName: T }>
) {
const handlers = this._afterChangeHandlers[typeName] as TLAfterChangeHandler<any>[]
const handlers = this._afterChangeHandlers[typeName] as StoreAfterChangeHandler<any>[]
if (!handlers) this._afterChangeHandlers[typeName] = []
this._afterChangeHandlers[typeName]!.push(handler as TLAfterChangeHandler<any>)
this._afterChangeHandlers[typeName]!.push(handler as StoreAfterChangeHandler<any>)
return () => remove(this._afterChangeHandlers[typeName]!, handler)
}
@ -325,7 +334,7 @@ export class SideEffectManager<
*
* Use this handler only for intercepting deletions of the record itself. If you want to do
* something to other records in response to a deletion, use
* {@link SideEffectManager.registerAfterDeleteHandler} instead.
* {@link StoreSideEffects.registerAfterDeleteHandler} instead.
*
* @example
* ```ts
@ -340,20 +349,20 @@ export class SideEffectManager<
* @param typeName - The type of record to listen for
* @param handler - The handler to call
*/
registerBeforeDeleteHandler<T extends TLRecord['typeName']>(
registerBeforeDeleteHandler<T extends R['typeName']>(
typeName: T,
handler: TLBeforeDeleteHandler<TLRecord & { typeName: T }>
handler: StoreBeforeDeleteHandler<R & { typeName: T }>
) {
const handlers = this._beforeDeleteHandlers[typeName] as TLBeforeDeleteHandler<any>[]
const handlers = this._beforeDeleteHandlers[typeName] as StoreBeforeDeleteHandler<any>[]
if (!handlers) this._beforeDeleteHandlers[typeName] = []
this._beforeDeleteHandlers[typeName]!.push(handler as TLBeforeDeleteHandler<any>)
this._beforeDeleteHandlers[typeName]!.push(handler as StoreBeforeDeleteHandler<any>)
return () => remove(this._beforeDeleteHandlers[typeName]!, handler)
}
/**
* Register a handler to be called after a record is deleted. This is useful for side-effects
* that would update _other_ records - if you want to block the deletion of the record itself,
* use {@link SideEffectManager.registerBeforeDeleteHandler} instead.
* use {@link StoreSideEffects.registerBeforeDeleteHandler} instead.
*
* @example
* ```ts
@ -372,29 +381,29 @@ export class SideEffectManager<
* @param typeName - The type of record to listen for
* @param handler - The handler to call
*/
registerAfterDeleteHandler<T extends TLRecord['typeName']>(
registerAfterDeleteHandler<T extends R['typeName']>(
typeName: T,
handler: TLAfterDeleteHandler<TLRecord & { typeName: T }>
handler: StoreAfterDeleteHandler<R & { typeName: T }>
) {
const handlers = this._afterDeleteHandlers[typeName] as TLAfterDeleteHandler<any>[]
const handlers = this._afterDeleteHandlers[typeName] as StoreAfterDeleteHandler<any>[]
if (!handlers) this._afterDeleteHandlers[typeName] = []
this._afterDeleteHandlers[typeName]!.push(handler as TLAfterDeleteHandler<any>)
this._afterDeleteHandlers[typeName]!.push(handler as StoreAfterDeleteHandler<any>)
return () => remove(this._afterDeleteHandlers[typeName]!, handler)
}
/**
* Register a handler to be called when a store completes a batch.
* Register a handler to be called when a store completes an atomic operation.
*
* @example
* ```ts
* let count = 0
*
* editor.cleanup.registerBatchCompleteHandler(() => count++)
* editor.sideEffects.registerOperationCompleteHandler(() => count++)
*
* editor.selectAll()
* expect(count).toBe(1)
*
* editor.batch(() => {
* editor.store.atomic(() => {
* editor.selectNone()
* editor.selectAll()
* })
@ -406,9 +415,9 @@ export class SideEffectManager<
*
* @public
*/
registerBatchCompleteHandler(handler: TLBatchCompleteHandler) {
this._batchCompleteHandlers.push(handler)
return () => remove(this._batchCompleteHandlers, handler)
registerOperationCompleteHandler(handler: StoreOperationCompleteHandler) {
this._operationCompleteHandlers.push(handler)
return () => remove(this._operationCompleteHandlers, handler)
}
}

View file

@ -206,40 +206,41 @@ describe('Store', () => {
it('allows adding onAfterChange callbacks that see the final state of the world', () => {
/* ADDING */
store.onAfterCreate = jest.fn((current) => {
const onAfterCreate = jest.fn((current) => {
expect(current).toEqual(
Author.create({ name: 'J.R.R Tolkein', id: Author.createId('tolkein') })
)
expect([...store.query.ids('author').get()]).toEqual([Author.createId('tolkein')])
})
store.sideEffects.registerAfterCreateHandler('author', onAfterCreate)
store.put([Author.create({ name: 'J.R.R Tolkein', id: Author.createId('tolkein') })])
expect(store.onAfterCreate).toHaveBeenCalledTimes(1)
expect(onAfterCreate).toHaveBeenCalledTimes(1)
/* UPDATING */
store.onAfterChange = jest.fn((prev, current) => {
if (prev.typeName === 'author' && current.typeName === 'author') {
expect(prev.name).toBe('J.R.R Tolkein')
expect(current.name).toBe('Butch Cassidy')
const onAfterChange = jest.fn((prev, current) => {
expect(prev.name).toBe('J.R.R Tolkein')
expect(current.name).toBe('Butch Cassidy')
expect(store.get(Author.createId('tolkein'))!.name).toBe('Butch Cassidy')
}
expect(store.get(Author.createId('tolkein'))!.name).toBe('Butch Cassidy')
})
store.sideEffects.registerAfterChangeHandler('author', onAfterChange)
store.update(Author.createId('tolkein'), (r) => ({ ...r, name: 'Butch Cassidy' }))
expect(store.onAfterChange).toHaveBeenCalledTimes(1)
expect(onAfterChange).toHaveBeenCalledTimes(1)
/* REMOVING */
store.onAfterDelete = jest.fn((prev) => {
const onAfterDelete = jest.fn((prev) => {
if (prev.typeName === 'author') {
expect(prev.name).toBe('Butch Cassidy')
}
})
store.sideEffects.registerAfterDeleteHandler('author', onAfterDelete)
store.remove([Author.createId('tolkein')])
expect(store.onAfterDelete).toHaveBeenCalledTimes(1)
expect(onAfterDelete).toHaveBeenCalledTimes(1)
})
it('allows finding and filtering records with a predicate', () => {
@ -1076,76 +1077,137 @@ describe('diffs', () => {
})
describe('after callbacks', () => {
let store: Store<LibraryType>
let store: Store<Book>
let callbacks: any[] = []
const authorId = Author.createId('tolkein')
const bookId = Book.createId('hobbit')
const book1Id = Book.createId('darkness')
const book1 = Book.create({
title: 'the left hand of darkness',
id: book1Id,
author: Author.createId('ursula'),
numPages: 1,
})
const book2Id = Book.createId('dispossessed')
const book2 = Book.create({
title: 'the dispossessed',
id: book2Id,
author: Author.createId('ursula'),
numPages: 1,
})
let onAfterCreate: jest.Mock
let onAfterChange: jest.Mock
let onAfterDelete: jest.Mock
let onOperationComplete: jest.Mock
beforeEach(() => {
store = new Store({
props: {},
schema: StoreSchema.create<LibraryType>({
schema: StoreSchema.create<Book>({
book: Book,
author: Author,
visit: Visit,
}),
})
store.onAfterCreate = jest.fn((record) => callbacks.push({ type: 'create', record }))
store.onAfterChange = jest.fn((from, to) => callbacks.push({ type: 'change', from, to }))
store.onAfterDelete = jest.fn((record) => callbacks.push({ type: 'delete', record }))
onAfterCreate = jest.fn((record) => callbacks.push({ type: 'create', record }))
onAfterChange = jest.fn((from, to) => callbacks.push({ type: 'change', from, to }))
onAfterDelete = jest.fn((record) => callbacks.push({ type: 'delete', record }))
onOperationComplete = jest.fn(() => callbacks.push({ type: 'complete' }))
callbacks = []
store.sideEffects.registerAfterCreateHandler('book', onAfterCreate)
store.sideEffects.registerAfterChangeHandler('book', onAfterChange)
store.sideEffects.registerAfterDeleteHandler('book', onAfterDelete)
store.sideEffects.registerOperationCompleteHandler(onOperationComplete)
})
it('fires callbacks at the end of an `atomic` op', () => {
store.atomic(() => {
expect(callbacks).toHaveLength(0)
store.put([
Author.create({ name: 'J.R.R Tolkein', id: authorId }),
Book.create({ title: 'The Hobbit', id: bookId, author: authorId, numPages: 300 }),
])
store.put([book1, book2])
expect(callbacks).toHaveLength(0)
})
expect(callbacks).toMatchObject([
{ type: 'create', record: { id: authorId } },
{ type: 'create', record: { id: bookId } },
{ type: 'create', record: { id: book1Id } },
{ type: 'create', record: { id: book2Id } },
{ type: 'complete' },
])
})
it('doesnt fire callback for a record created then deleted', () => {
store.atomic(() => {
store.put([Author.create({ name: 'J.R.R Tolkein', id: authorId })])
store.remove([authorId])
store.put([book1])
store.remove([book1Id])
})
expect(callbacks).toHaveLength(0)
expect(callbacks).toMatchObject([{ type: 'complete' }])
})
it('bails out if too many callbacks are fired', () => {
let limit = 10
store.onAfterCreate = (record) => {
if (record.typeName === 'book' && record.numPages < limit) {
onAfterCreate.mockImplementation((record) => {
if (record.numPages < limit) {
store.put([{ ...record, numPages: record.numPages + 1 }])
}
}
store.onAfterChange = (from, to) => {
if (to.typeName === 'book' && to.numPages < limit) {
})
onAfterChange.mockImplementation((from, to) => {
if (to.numPages < limit) {
store.put([{ ...to, numPages: to.numPages + 1 }])
}
}
})
// this should be fine:
store.put([Book.create({ title: 'The Hobbit', id: bookId, author: authorId, numPages: 0 })])
expect(store.get(bookId)!.numPages).toBe(limit)
store.put([book1])
expect(store.get(book1Id)!.numPages).toBe(limit)
// if we increase the limit thought, it should crash:
limit = 10000
store.clear()
expect(() => {
store.put([Book.create({ title: 'The Hobbit', id: bookId, author: authorId, numPages: 0 })])
store.put([book2])
}).toThrowErrorMatchingInlineSnapshot(`"Maximum store update depth exceeded, bailing out"`)
})
it('keeps firing operation complete callbacks until all are cleared', () => {
// steps:
// 0, 1, 2: after change increment pages
// 3: after change, do nothing
// 4: operation complete, increment pages by 1000
// 5, 6: after change increment pages
// 7: after change, do nothing
// 8: operation complete, do nothing
// 9: done!
let step = 0
store.put([book1])
onAfterChange.mockImplementation((prev, next) => {
if ([0, 1, 2, 5, 6].includes(step)) {
step++
store.put([{ ...next, numPages: next.numPages + 1 }])
} else if ([3, 7].includes(step)) {
step++
} else {
throw new Error(`Wrong step: ${step}`)
}
})
onOperationComplete.mockImplementation(() => {
if (step === 4) {
step++
const book = store.get(book1Id)!
store.put([{ ...book, numPages: book.numPages + 1000 }])
} else if (step === 8) {
step++
} else {
throw new Error(`Wrong step: ${step}`)
}
})
store.put([{ ...book1, numPages: 2 }])
expect(store.get(book1Id)!.numPages).toBe(1007)
expect(step).toBe(9)
})
})

View file

@ -333,12 +333,10 @@ function runTest(seed: number) {
author: Author,
}),
})
store.onBeforeDelete = (record) => {
if (record.typeName === 'author') {
const books = store.query.index('book', 'authorId').get().get(record.id)
if (books) store.remove([...books])
}
}
store.sideEffects.registerBeforeDeleteHandler('author', (record) => {
const books = store.query.index('book', 'authorId').get().get(record.id)
if (books) store.remove([...books])
})
const getRandomNumber = rng(seed)
const authorNameIndex = store.query.index('author', 'name')
const authorIdIndex = store.query.index('book', 'authorId')