diff --git a/apps/examples/src/examples/pin-bindings/PinExample.tsx b/apps/examples/src/examples/pin-bindings/PinExample.tsx new file mode 100644 index 000000000..e6b8431dd --- /dev/null +++ b/apps/examples/src/examples/pin-bindings/PinExample.tsx @@ -0,0 +1,335 @@ +import { + BindingOnShapeChangeOptions, + BindingOnShapeDeleteOptions, + BindingUtil, + Box, + DefaultFillStyle, + DefaultToolbar, + DefaultToolbarContent, + RecordProps, + Rectangle2d, + ShapeUtil, + StateNode, + TLBaseBinding, + TLBaseShape, + TLEditorComponents, + TLEventHandlers, + TLOnTranslateEndHandler, + TLOnTranslateStartHandler, + TLShapeId, + TLUiComponents, + TLUiOverrides, + Tldraw, + TldrawUiMenuItem, + Vec, + VecModel, + createShapeId, + invLerp, + lerp, + useIsToolSelected, + useTools, +} from 'tldraw' + +// eslint-disable-next-line @typescript-eslint/ban-types +type PinShape = TLBaseShape<'pin', {}> + +const offsetX = -16 +const offsetY = -26 +class PinShapeUtil extends ShapeUtil { + static override type = 'pin' as const + static override props: RecordProps = {} + + override getDefaultProps() { + return {} + } + + override canBind = () => false + override canEdit = () => false + override canResize = () => false + override hideRotateHandle = () => true + override isAspectRatioLocked = () => true + + override getGeometry() { + return new Rectangle2d({ + width: 32, + height: 32, + x: offsetX, + y: offsetY, + isFilled: true, + }) + } + + override component() { + return ( +
+ 📍 +
+ ) + } + + override indicator() { + return + } + + override onTranslateStart: TLOnTranslateStartHandler = (shape) => { + const bindings = this.editor.getBindingsFromShape(shape, 'pin') + this.editor.deleteBindings(bindings) + } + + override onTranslateEnd: TLOnTranslateEndHandler = (initial, pin) => { + const pageAnchor = this.editor.getShapePageTransform(pin).applyToPoint({ x: 0, y: 0 }) + + const targets = this.editor + .getShapesAtPoint(pageAnchor, { hitInside: true }) + .filter( + (shape) => + shape.type !== 'pin' && shape.parentId === pin.parentId && shape.index < pin.index + ) + + for (const target of targets) { + const targetBounds = Box.ZeroFix(this.editor.getShapeGeometry(target)!.bounds) + const pointInTargetSpace = this.editor.getPointInShapeSpace(target, pageAnchor) + + const anchor = { + x: invLerp(targetBounds.minX, targetBounds.maxX, pointInTargetSpace.x), + y: invLerp(targetBounds.minY, targetBounds.maxY, pointInTargetSpace.y), + } + + this.editor.createBinding({ + type: 'pin', + fromId: pin.id, + toId: target.id, + props: { + anchor, + }, + }) + } + } +} + +type PinBinding = TLBaseBinding< + 'pin', + { + anchor: VecModel + } +> +class PinBindingUtil extends BindingUtil { + static override type = 'pin' as const + + override getDefaultProps() { + return { + anchor: { x: 0.5, y: 0.5 }, + } + } + + private changedToShapes = new Set() + + override onOperationComplete(): void { + if (this.changedToShapes.size === 0) return + + const fixedShapes = this.changedToShapes + const toCheck = [...this.changedToShapes] + + const initialPositions = new Map() + const targetDeltas = new Map>() + + const addTargetDelta = (fromId: TLShapeId, toId: TLShapeId, delta: VecModel) => { + if (!targetDeltas.has(fromId)) targetDeltas.set(fromId, new Map()) + targetDeltas.get(fromId)!.set(toId, delta) + + if (!targetDeltas.has(toId)) targetDeltas.set(toId, new Map()) + targetDeltas.get(toId)!.set(fromId, { x: -delta.x, y: -delta.y }) + } + + const allShapes = new Set() + while (toCheck.length) { + const shapeId = toCheck.pop()! + + const shape = this.editor.getShape(shapeId) + if (!shape) continue + + if (allShapes.has(shapeId)) continue + allShapes.add(shapeId) + + const bindings = this.editor.getBindingsToShape(shape, 'pin') + for (const binding of bindings) { + if (allShapes.has(binding.fromId)) continue + allShapes.add(binding.fromId) + + const pin = this.editor.getShape(binding.fromId) + if (!pin) continue + + const pinPosition = this.editor.getShapePageTransform(pin).applyToPoint({ x: 0, y: 0 }) + initialPositions.set(pin.id, pinPosition) + + for (const binding of this.editor.getBindingsFromShape(pin.id, 'pin')) { + const shapeBounds = this.editor.getShapeGeometry(binding.toId)!.bounds + const shapeAnchor = { + x: lerp(shapeBounds.minX, shapeBounds.maxX, binding.props.anchor.x), + y: lerp(shapeBounds.minY, shapeBounds.maxY, binding.props.anchor.y), + } + const currentPageAnchor = this.editor + .getShapePageTransform(binding.toId) + .applyToPoint(shapeAnchor) + + const shapeOrigin = this.editor + .getShapePageTransform(binding.toId) + .applyToPoint({ x: 0, y: 0 }) + initialPositions.set(binding.toId, shapeOrigin) + + addTargetDelta(pin.id, binding.toId, { + x: currentPageAnchor.x - shapeOrigin.x, + y: currentPageAnchor.y - shapeOrigin.y, + }) + + if (!allShapes.has(binding.toId)) toCheck.push(binding.toId) + } + } + } + + const currentPositions = new Map(initialPositions) + + const iterations = 30 + for (let i = 0; i < iterations; i++) { + const movements = new Map() + for (const [aId, deltas] of targetDeltas) { + if (fixedShapes.has(aId)) continue + const aPosition = currentPositions.get(aId)! + for (const [bId, targetDelta] of deltas) { + const bPosition = currentPositions.get(bId)! + + const adjustmentDelta = { + x: targetDelta.x - (aPosition.x - bPosition.x), + y: targetDelta.y - (aPosition.y - bPosition.y), + } + + if (!movements.has(aId)) movements.set(aId, []) + movements.get(aId)!.push(adjustmentDelta) + } + } + + for (const [shapeId, deltas] of movements) { + const currentPosition = currentPositions.get(shapeId)! + currentPositions.set(shapeId, Vec.Average(deltas).add(currentPosition)) + } + } + + const updates = [] + for (const [shapeId, position] of currentPositions) { + const delta = Vec.Sub(position, initialPositions.get(shapeId)!) + if (delta.len2() <= 0.01) continue + + const newPosition = this.editor.getPointInParentSpace(shapeId, position) + updates.push({ + id: shapeId, + type: this.editor.getShape(shapeId)!.type, + x: newPosition.x, + y: newPosition.y, + }) + } + + if (updates.length === 0) { + this.changedToShapes.clear() + } else { + this.editor.updateShapes(updates) + } + } + + // when the shape we're stuck to changes, update the pin's position + override onAfterChangeToShape({ binding }: BindingOnShapeChangeOptions): void { + this.changedToShapes.add(binding.toId) + } + + // when the thing we're stuck to is deleted, delete the pin too + override onBeforeDeleteToShape({ binding }: BindingOnShapeDeleteOptions): void { + const pin = this.editor.getShape(binding.fromId) + if (pin) this.editor.deleteShape(pin.id) + } +} + +class PinTool extends StateNode { + static override id = 'pin' + + override onEnter = () => { + this.editor.setCursor({ type: 'cross', rotation: 0 }) + } + + override onPointerDown: TLEventHandlers['onPointerDown'] = (info) => { + const { currentPagePoint } = this.editor.inputs + const pinId = createShapeId() + this.editor.mark(`creating:${pinId}`) + this.editor.createShape({ + id: pinId, + type: 'pin', + x: currentPagePoint.x, + y: currentPagePoint.y, + }) + this.editor.setSelectedShapes([pinId]) + this.editor.setCurrentTool('select.translating', { + ...info, + target: 'shape', + shape: this.editor.getShape(pinId), + isCreating: true, + onInteractionEnd: 'pin', + onCreate: () => { + this.editor.setCurrentTool('pin') + }, + }) + } +} + +const overrides: TLUiOverrides = { + tools(editor, schema) { + schema['pin'] = { + id: 'pin', + label: 'Pin', + icon: 'heart-icon', + kbd: 'p', + onSelect: () => { + editor.setCurrentTool('pin') + }, + } + return schema + }, +} + +const components: TLUiComponents & TLEditorComponents = { + Toolbar: (...props) => { + const pin = useTools().pin + const isPinSelected = useIsToolSelected(pin) + return ( + + + + + ) + }, +} + +export default function PinExample() { + return ( +
+ { + ;(window as any).editor = editor + editor.setStyleForNextShapes(DefaultFillStyle, 'semi') + }} + shapeUtils={[PinShapeUtil]} + bindingUtils={[PinBindingUtil]} + tools={[PinTool]} + overrides={overrides} + components={components} + /> +
+ ) +} diff --git a/apps/examples/src/examples/pin-bindings/README.md b/apps/examples/src/examples/pin-bindings/README.md new file mode 100644 index 000000000..17a8d5183 --- /dev/null +++ b/apps/examples/src/examples/pin-bindings/README.md @@ -0,0 +1,9 @@ +--- +title: Pin (bindings) +component: ./PinExample.tsx +category: shapes/tools +--- + +A pin, using bindings to pin together networks of shapes. + +--- diff --git a/packages/editor/api-report.md b/packages/editor/api-report.md index 28ec3a425..674692ca4 100644 --- a/packages/editor/api-report.md +++ b/packages/editor/api-report.md @@ -37,6 +37,7 @@ import { SerializedStore } from '@tldraw/store'; import { Signal } from '@tldraw/state'; import { Store } from '@tldraw/store'; import { StoreSchema } from '@tldraw/store'; +import { StoreSideEffects } from '@tldraw/store'; import { StoreSnapshot } from '@tldraw/store'; import { StyleProp } from '@tldraw/tlschema'; import { StylePropValue } from '@tldraw/tlschema'; @@ -239,6 +240,8 @@ export abstract class BindingUtil): void; // (undocumented) + onOperationComplete?(): void; + // (undocumented) static props?: RecordProps; static type: string; } @@ -1013,7 +1016,7 @@ export class Editor extends EventEmitter { shapeUtils: { readonly [K in string]?: ShapeUtil; }; - readonly sideEffects: SideEffectManager; + readonly sideEffects: StoreSideEffects; slideCamera(opts?: { direction: VecLike; friction: number; @@ -1853,48 +1856,6 @@ export class SharedStyleMap extends ReadonlySharedStyleMap { // @public export function shortAngleDist(a0: number, a1: number): number; -// @public -export class SideEffectManager void; - }; - store: TLStore; -}> { - constructor(editor: CTX); - // (undocumented) - editor: CTX; - // @internal - register(handlersByType: { - [R in TLRecord as R['typeName']]?: { - afterChange?: TLAfterChangeHandler; - afterCreate?: TLAfterCreateHandler; - afterDelete?: TLAfterDeleteHandler; - beforeChange?: TLBeforeChangeHandler; - beforeCreate?: TLBeforeCreateHandler; - beforeDelete?: TLBeforeDeleteHandler; - }; - }): () => void; - registerAfterChangeHandler(typeName: T, handler: TLAfterChangeHandler): () => void; - registerAfterCreateHandler(typeName: T, handler: TLAfterCreateHandler): () => void; - registerAfterDeleteHandler(typeName: T, handler: TLAfterDeleteHandler): () => void; - registerBatchCompleteHandler(handler: TLBatchCompleteHandler): () => void; - registerBeforeChangeHandler(typeName: T, handler: TLBeforeChangeHandler): () => void; - registerBeforeCreateHandler(typeName: T, handler: TLBeforeCreateHandler): () => void; - registerBeforeDeleteHandler(typeName: T, handler: TLBeforeDeleteHandler): () => void; -} - // @public (undocumented) export const SIDES: readonly ["top", "right", "bottom", "left"]; @@ -2058,15 +2019,6 @@ export interface SvgExportDef { // @public export const TAB_ID: string; -// @public (undocumented) -export type TLAfterChangeHandler = (prev: R, next: R, source: 'remote' | 'user') => void; - -// @public (undocumented) -export type TLAfterCreateHandler = (record: R, source: 'remote' | 'user') => void; - -// @public (undocumented) -export type TLAfterDeleteHandler = (record: R, source: 'remote' | 'user') => void; - // @public (undocumented) export type TLAnyBindingUtilConstructor = TLBindingUtilConstructor; @@ -2091,18 +2043,6 @@ export interface TLBaseEventInfo { type: UiEventType; } -// @public (undocumented) -export type TLBatchCompleteHandler = () => void; - -// @public (undocumented) -export type TLBeforeChangeHandler = (prev: R, next: R, source: 'remote' | 'user') => R; - -// @public (undocumented) -export type TLBeforeCreateHandler = (record: R, source: 'remote' | 'user') => R; - -// @public (undocumented) -export type TLBeforeDeleteHandler = (record: R, source: 'remote' | 'user') => false | void; - // @public (undocumented) export interface TLBindingUtilConstructor = BindingUtil> { // (undocumented) diff --git a/packages/editor/src/index.ts b/packages/editor/src/index.ts index 4ef5a26ea..87d751ba7 100644 --- a/packages/editor/src/index.ts +++ b/packages/editor/src/index.ts @@ -133,16 +133,6 @@ export { type TLBindingUtilConstructor, } from './lib/editor/bindings/BindingUtil' export { HistoryManager } from './lib/editor/managers/HistoryManager' -export type { - SideEffectManager, - TLAfterChangeHandler, - TLAfterCreateHandler, - TLAfterDeleteHandler, - TLBatchCompleteHandler, - TLBeforeChangeHandler, - TLBeforeCreateHandler, - TLBeforeDeleteHandler, -} from './lib/editor/managers/SideEffectManager' export { type BoundsSnapGeometry, type BoundsSnapPoint, diff --git a/packages/editor/src/lib/editor/Editor.ts b/packages/editor/src/lib/editor/Editor.ts index 7d9092173..5453ef912 100644 --- a/packages/editor/src/lib/editor/Editor.ts +++ b/packages/editor/src/lib/editor/Editor.ts @@ -2,6 +2,7 @@ import { EMPTY_ARRAY, atom, computed, transact } from '@tldraw/state' import { ComputedCache, RecordType, + StoreSideEffects, StoreSnapshot, UnknownRecord, reverseRecordsDiff, @@ -127,7 +128,6 @@ import { ClickManager } from './managers/ClickManager' import { EnvironmentManager } from './managers/EnvironmentManager' import { HistoryManager } from './managers/HistoryManager' import { ScribbleManager } from './managers/ScribbleManager' -import { SideEffectManager } from './managers/SideEffectManager' import { SnapManager } from './managers/SnapManager/SnapManager' import { TextManager } from './managers/TextManager' import { TickManager } from './managers/TickManager' @@ -300,8 +300,6 @@ export class Editor extends EventEmitter { // Cleanup - const invalidParents = new Set() - const cleanupInstancePageState = ( prevPageState: TLInstancePageState, shapesNoLongerInPage: Set @@ -349,10 +347,12 @@ export class Editor extends EventEmitter { return nextPageState } - this.sideEffects = new SideEffectManager(this) + this.sideEffects = this.store.sideEffects + const invalidParents = new Set() + let invalidBindingTypes = new Set() this.disposables.add( - this.sideEffects.registerBatchCompleteHandler(() => { + this.sideEffects.registerOperationCompleteHandler(() => { for (const parentId of invalidParents) { invalidParents.delete(parentId) const parent = this.getShape(parentId) @@ -366,6 +366,15 @@ export class Editor extends EventEmitter { } } + if (invalidBindingTypes.size) { + const t = invalidBindingTypes + invalidBindingTypes = new Set() + for (const type of t) { + const util = this.getBindingUtil(type) + util.onOperationComplete?.() + } + } + this.emit('update') }) ) @@ -375,6 +384,7 @@ export class Editor extends EventEmitter { shape: { afterChange: (shapeBefore, shapeAfter) => { for (const binding of this.getBindingsInvolvingShape(shapeAfter)) { + invalidBindingTypes.add(binding.type) if (binding.fromId === shapeAfter.id) { this.getBindingUtil(binding).onAfterChangeFromShape?.({ binding, @@ -398,6 +408,8 @@ export class Editor extends EventEmitter { if (!descendantShape) return for (const binding of this.getBindingsInvolvingShape(descendantShape)) { + invalidBindingTypes.add(binding.type) + if (binding.fromId === descendantShape.id) { this.getBindingUtil(binding).onAfterChangeFromShape?.({ binding, @@ -451,6 +463,7 @@ export class Editor extends EventEmitter { const deleteBindingIds: TLBindingId[] = [] for (const binding of this.getBindingsInvolvingShape(shape)) { + invalidBindingTypes.add(binding.type) if (binding.fromId === shape.id) { this.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape }) deleteBindingIds.push(binding.id) @@ -481,6 +494,7 @@ export class Editor extends EventEmitter { return binding }, afterCreate: (binding) => { + invalidBindingTypes.add(binding.type) this.getBindingUtil(binding).onAfterCreate?.({ binding }) }, beforeChange: (bindingBefore, bindingAfter) => { @@ -492,12 +506,14 @@ export class Editor extends EventEmitter { return bindingAfter }, afterChange: (bindingBefore, bindingAfter) => { + invalidBindingTypes.add(bindingAfter.type) this.getBindingUtil(bindingAfter).onAfterChange?.({ bindingBefore, bindingAfter }) }, beforeDelete: (binding) => { this.getBindingUtil(binding).onBeforeDelete?.({ binding }) }, afterDelete: (binding) => { + invalidBindingTypes.add(binding.type) this.getBindingUtil(binding).onAfterDelete?.({ binding }) }, }, @@ -705,11 +721,11 @@ export class Editor extends EventEmitter { readonly scribbles: ScribbleManager /** - * A manager for side effects and correct state enforcement. See {@link SideEffectManager} for details. + * A manager for side effects and correct state enforcement. See {@link @tldraw/store#StoreSideEffects} for details. * * @public */ - readonly sideEffects: SideEffectManager + readonly sideEffects: StoreSideEffects /** * The current HTML element containing the editor. diff --git a/packages/editor/src/lib/editor/bindings/BindingUtil.ts b/packages/editor/src/lib/editor/bindings/BindingUtil.ts index 3289e408c..a8b9f5454 100644 --- a/packages/editor/src/lib/editor/bindings/BindingUtil.ts +++ b/packages/editor/src/lib/editor/bindings/BindingUtil.ts @@ -61,6 +61,8 @@ export abstract class BindingUtil + onOperationComplete?(): void + // self lifecycle hooks onBeforeCreate?(options: BindingOnCreateOptions): Binding | void onAfterCreate?(options: BindingOnCreateOptions): void diff --git a/packages/editor/tsconfig.json b/packages/editor/tsconfig.json index 1ed85b35b..f19f69e45 100644 --- a/packages/editor/tsconfig.json +++ b/packages/editor/tsconfig.json @@ -1,6 +1,5 @@ { "extends": "../../config/tsconfig.base.json", - "include": ["src"], "exclude": ["node_modules", "dist", "**/*.css", ".tsbuild*"], "compilerOptions": { "outDir": "./.tsbuild", diff --git a/packages/store/api-report.md b/packages/store/api-report.md index eb829a135..c60d9fafb 100644 --- a/packages/store/api-report.md +++ b/packages/store/api-report.md @@ -319,12 +319,6 @@ export class Store { markAsPossiblyCorrupted(): void; mergeRemoteChanges: (fn: () => void) => void; migrateSnapshot(snapshot: StoreSnapshot): StoreSnapshot; - onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void; - onAfterCreate?: (record: R, source: 'remote' | 'user') => void; - onAfterDelete?: (prev: R, source: 'remote' | 'user') => void; - onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R; - onBeforeCreate?: (next: R, source: 'remote' | 'user') => R; - onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void; // (undocumented) readonly props: Props; put: (records: R[], phaseOverride?: 'initialize') => void; @@ -337,12 +331,32 @@ export class Store { readonly [K in RecordScope]: ReadonlySet; }; serialize: (scope?: 'all' | RecordScope) => SerializedStore; + // (undocumented) + readonly sideEffects: StoreSideEffects; unsafeGetWithoutCapture: >(id: K) => RecFromId | undefined; update: >(id: K, updater: (record: RecFromId) => RecFromId) => void; // (undocumented) validate(phase: 'createRecord' | 'initialize' | 'tests' | 'updateRecord'): void; } +// @public (undocumented) +export type StoreAfterChangeHandler = (prev: R, next: R, source: 'remote' | 'user') => void; + +// @public (undocumented) +export type StoreAfterCreateHandler = (record: R, source: 'remote' | 'user') => void; + +// @public (undocumented) +export type StoreAfterDeleteHandler = (record: R, source: 'remote' | 'user') => void; + +// @public (undocumented) +export type StoreBeforeChangeHandler = (prev: R, next: R, source: 'remote' | 'user') => R; + +// @public (undocumented) +export type StoreBeforeCreateHandler = (record: R, source: 'remote' | 'user') => R; + +// @public (undocumented) +export type StoreBeforeDeleteHandler = (record: R, source: 'remote' | 'user') => false | void; + // @public (undocumented) export type StoreError = { error: Error; @@ -355,6 +369,9 @@ export type StoreError = { // @public export type StoreListener = (entry: HistoryEntry) => void; +// @public (undocumented) +export type StoreOperationCompleteHandler = (source: 'remote' | 'user') => void; + // @public (undocumented) export class StoreSchema { // (undocumented) @@ -402,6 +419,59 @@ export type StoreSchemaOptions = { migrations?: MigrationSequence[]; }; +// @public +export class StoreSideEffects { + constructor(store: Store); + // @internal (undocumented) + handleAfterChange(prev: R, next: R, source: 'remote' | 'user'): void; + // @internal (undocumented) + handleAfterCreate(record: R, source: 'remote' | 'user'): void; + // @internal (undocumented) + handleAfterDelete(record: R, source: 'remote' | 'user'): void; + // @internal (undocumented) + handleBeforeChange(prev: R, next: R, source: 'remote' | 'user'): R; + // @internal (undocumented) + handleBeforeCreate(record: R, source: 'remote' | 'user'): R; + // @internal (undocumented) + handleBeforeDelete(record: R, source: 'remote' | 'user'): boolean; + // @internal (undocumented) + handleOperationComplete(source: 'remote' | 'user'): void; + // @internal (undocumented) + isEnabled(): boolean; + // @internal + register(handlersByType: { + [T in R as T['typeName']]?: { + afterChange?: StoreAfterChangeHandler; + afterCreate?: StoreAfterCreateHandler; + afterDelete?: StoreAfterDeleteHandler; + beforeChange?: StoreBeforeChangeHandler; + beforeCreate?: StoreBeforeCreateHandler; + beforeDelete?: StoreBeforeDeleteHandler; + }; + }): () => void; + registerAfterChangeHandler(typeName: T, handler: StoreAfterChangeHandler): () => void; + registerAfterCreateHandler(typeName: T, handler: StoreAfterCreateHandler): () => void; + registerAfterDeleteHandler(typeName: T, handler: StoreAfterDeleteHandler): () => void; + registerBeforeChangeHandler(typeName: T, handler: StoreBeforeChangeHandler): () => void; + registerBeforeCreateHandler(typeName: T, handler: StoreBeforeCreateHandler): () => void; + registerBeforeDeleteHandler(typeName: T, handler: StoreBeforeDeleteHandler): () => void; + registerOperationCompleteHandler(handler: StoreOperationCompleteHandler): () => void; + // @internal (undocumented) + setIsEnabled(enabled: boolean): void; +} + // @public (undocumented) export type StoreSnapshot = { schema: SerializedSchema; diff --git a/packages/store/src/index.ts b/packages/store/src/index.ts index cfb301c5d..2023bea4e 100644 --- a/packages/store/src/index.ts +++ b/packages/store/src/index.ts @@ -28,6 +28,16 @@ export type { SerializedSchemaV2, StoreSchemaOptions, } from './lib/StoreSchema' +export { + StoreSideEffects, + type StoreAfterChangeHandler, + type StoreAfterCreateHandler, + type StoreAfterDeleteHandler, + type StoreBeforeChangeHandler, + type StoreBeforeCreateHandler, + type StoreBeforeDeleteHandler, + type StoreOperationCompleteHandler, +} from './lib/StoreSideEffects' export { devFreeze } from './lib/devFreeze' export { MigrationFailureReason, diff --git a/packages/store/src/lib/Store.ts b/packages/store/src/lib/Store.ts index 6bffe9489..ace5e9217 100644 --- a/packages/store/src/lib/Store.ts +++ b/packages/store/src/lib/Store.ts @@ -16,6 +16,7 @@ import { RecordScope } from './RecordType' import { RecordsDiff, squashRecordDiffs } from './RecordsDiff' import { StoreQueries } from './StoreQueries' import { SerializedSchema, StoreSchema } from './StoreSchema' +import { StoreSideEffects } from './StoreSideEffects' import { devFreeze } from './devFreeze' type RecFromId> = K extends RecordId ? R : never @@ -160,6 +161,8 @@ export class Store { public readonly scopedTypes: { readonly [K in RecordScope]: ReadonlySet } + public readonly sideEffects = new StoreSideEffects(this) + constructor(config: { id?: string /** The store's initial data. */ @@ -295,55 +298,6 @@ export class Store { this.allRecords().forEach((record) => this.schema.validateRecord(this, record, phase, null)) } - /** - * A callback fired after each record's change. - * - * @param prev - The previous value, if any. - * @param next - The next value. - */ - onBeforeCreate?: (next: R, source: 'remote' | 'user') => R - - /** - * A callback fired after a record is created. Use this to perform related updates to other - * records in the store. - * - * @param record - The record to be created - */ - onAfterCreate?: (record: R, source: 'remote' | 'user') => void - - /** - * A callback fired before each record's change. - * - * @param prev - The previous value, if any. - * @param next - The next value. - */ - onBeforeChange?: (prev: R, next: R, source: 'remote' | 'user') => R - - /** - * A callback fired after each record's change. - * - * @param prev - The previous value, if any. - * @param next - The next value. - */ - onAfterChange?: (prev: R, next: R, source: 'remote' | 'user') => void - - /** - * A callback fired before a record is deleted. - * - * @param prev - The record that will be deleted. - */ - onBeforeDelete?: (prev: R, source: 'remote' | 'user') => false | void - - /** - * A callback fired after a record is deleted. - * - * @param prev - The record that will be deleted. - */ - onAfterDelete?: (prev: R, source: 'remote' | 'user') => void - - // used to avoid running callbacks when rolling back changes in sync client - private _runCallbacks = true - /** * Add some records to the store. It's an error if they already exist. * @@ -367,8 +321,6 @@ export class Store { // changes (e.g. additions, deletions, or updates that produce a new value). let didChange = false - const beforeCreate = this.onBeforeCreate && this._runCallbacks ? this.onBeforeCreate : null - const beforeUpdate = this.onBeforeChange && this._runCallbacks ? this.onBeforeChange : null const source = this.isMergingRemoteChanges ? 'remote' : 'user' for (let i = 0, n = records.length; i < n; i++) { @@ -381,7 +333,7 @@ export class Store { const initialValue = recordAtom.__unsafe__getWithoutCapture() // If we have a beforeUpdate callback, run it against the initial and next records - if (beforeUpdate) record = beforeUpdate(initialValue, record, source) + record = this.sideEffects.handleBeforeChange(initialValue, record, source) // Validate the record const validated = this.schema.validateRecord( @@ -398,9 +350,9 @@ export class Store { didChange = true const updated = recordAtom.__unsafe__getWithoutCapture() updates[record.id] = [initialValue, updated] - this.addDiffForAfterEvent(initialValue, updated, source) + this.addDiffForAfterEvent(initialValue, updated) } else { - if (beforeCreate) record = beforeCreate(record, source) + record = this.sideEffects.handleBeforeCreate(record, source) didChange = true @@ -416,7 +368,7 @@ export class Store { // Mark the change as a new addition. additions[record.id] = record - this.addDiffForAfterEvent(null, record, source) + this.addDiffForAfterEvent(null, record) // Assign the atom to the map under the record's id. if (!map) { @@ -449,16 +401,16 @@ export class Store { */ remove = (ids: IdOf[]): void => { this.atomic(() => { - const cancelled = [] as IdOf[] + const cancelled = new Set>() const source = this.isMergingRemoteChanges ? 'remote' : 'user' - if (this.onBeforeDelete && this._runCallbacks) { + if (this.sideEffects.isEnabled()) { for (const id of ids) { const atom = this.atoms.__unsafe__getWithoutCapture()[id] if (!atom) continue - if (this.onBeforeDelete(atom.get(), source) === false) { - cancelled.push(id) + if (this.sideEffects.handleBeforeDelete(atom.get(), source) === false) { + cancelled.add(id) } } } @@ -470,14 +422,14 @@ export class Store { let result: typeof atoms | undefined = undefined for (const id of ids) { - if (cancelled.includes(id)) continue + if (cancelled.has(id)) continue if (!(id in atoms)) continue if (!result) result = { ...atoms } if (!removed) removed = {} as Record, R> delete result[id] const record = atoms[id].get() removed[id] = record - this.addDiffForAfterEvent(record, null, source) + this.addDiffForAfterEvent(record, null) } return result ?? atoms @@ -587,16 +539,16 @@ export class Store { throw new Error(`Failed to migrate snapshot: ${migrationResult.reason}`) } - const prevRunCallbacks = this._runCallbacks + const prevSideEffectsEnabled = this.sideEffects.isEnabled() try { - this._runCallbacks = false + this.sideEffects.setIsEnabled(false) this.atomic(() => { this.clear() this.put(Object.values(migrationResult.value)) this.ensureStoreIsUsable() }) } finally { - this._runCallbacks = prevRunCallbacks + this.sideEffects.setIsEnabled(prevSideEffectsEnabled) } } @@ -693,6 +645,10 @@ export class Store { return fn() } + if (this._isInAtomicOp) { + throw new Error('Cannot merge remote changes while in atomic operation') + } + try { this.isMergingRemoteChanges = true transact(fn) @@ -844,11 +800,8 @@ export class Store { return this._isPossiblyCorrupted } - private pendingAfterEvents: Map< - IdOf, - { before: R | null; after: R | null; source: 'remote' | 'user' } - > | null = null - private addDiffForAfterEvent(before: R | null, after: R | null, source: 'remote' | 'user') { + private pendingAfterEvents: Map, { before: R | null; after: R | null }> | null = null + private addDiffForAfterEvent(before: R | null, after: R | null) { assert(this.pendingAfterEvents, 'must be in event operation') if (before === after) return if (before && after) assert(before.id === after.id) @@ -856,34 +809,38 @@ export class Store { const id = (before || after)!.id const existing = this.pendingAfterEvents.get(id) if (existing) { - assert(existing.source === source, 'source cannot change within a single event operation') existing.after = after } else { - this.pendingAfterEvents.set(id, { before, after, source }) + this.pendingAfterEvents.set(id, { before, after }) } } private flushAtomicCallbacks() { let updateDepth = 0 + const source = this.isMergingRemoteChanges ? 'remote' : 'user' while (this.pendingAfterEvents) { const events = this.pendingAfterEvents this.pendingAfterEvents = null - if (!this._runCallbacks) continue + if (!this.sideEffects.isEnabled()) continue updateDepth++ if (updateDepth > 100) { throw new Error('Maximum store update depth exceeded, bailing out') } - for (const { before, after, source } of events.values()) { + for (const { before, after } of events.values()) { if (before && after) { - this.onAfterChange?.(before, after, source) + this.sideEffects.handleAfterChange(before, after, source) } else if (before && !after) { - this.onAfterDelete?.(before, source) + this.sideEffects.handleAfterDelete(before, source) } else if (!before && after) { - this.onAfterCreate?.(after, source) + this.sideEffects.handleAfterCreate(after, source) } } + + if (!this.pendingAfterEvents) { + this.sideEffects.handleOperationComplete(source) + } } } private _isInAtomicOp = false @@ -896,8 +853,8 @@ export class Store { } this.pendingAfterEvents = new Map() - const prevRunCallbacks = this._runCallbacks - this._runCallbacks = runCallbacks ?? prevRunCallbacks + const prevSideEffectsEnabled = this.sideEffects.isEnabled() + this.sideEffects.setIsEnabled(runCallbacks ?? prevSideEffectsEnabled) this._isInAtomicOp = true try { const result = fn() @@ -907,7 +864,7 @@ export class Store { return result } finally { this.pendingAfterEvents = null - this._runCallbacks = prevRunCallbacks + this.sideEffects.setIsEnabled(prevSideEffectsEnabled) this._isInAtomicOp = false } }) diff --git a/packages/editor/src/lib/editor/managers/SIdeEffectManager.test.ts b/packages/store/src/lib/StoreSideEffects.test.ts similarity index 100% rename from packages/editor/src/lib/editor/managers/SIdeEffectManager.test.ts rename to packages/store/src/lib/StoreSideEffects.test.ts diff --git a/packages/editor/src/lib/editor/managers/SideEffectManager.ts b/packages/store/src/lib/StoreSideEffects.ts similarity index 56% rename from packages/editor/src/lib/editor/managers/SideEffectManager.ts rename to packages/store/src/lib/StoreSideEffects.ts index 51d742bf9..b6c27a705 100644 --- a/packages/editor/src/lib/editor/managers/SideEffectManager.ts +++ b/packages/store/src/lib/StoreSideEffects.ts @@ -1,36 +1,41 @@ -import { TLRecord, TLStore } from '@tldraw/tlschema' +import { UnknownRecord } from './BaseRecord' +import { Store } from './Store' /** @public */ -export type TLBeforeCreateHandler = (record: R, source: 'remote' | 'user') => R +export type StoreBeforeCreateHandler = ( + record: R, + source: 'remote' | 'user' +) => R /** @public */ -export type TLAfterCreateHandler = ( +export type StoreAfterCreateHandler = ( record: R, source: 'remote' | 'user' ) => void /** @public */ -export type TLBeforeChangeHandler = ( +export type StoreBeforeChangeHandler = ( prev: R, next: R, source: 'remote' | 'user' ) => R /** @public */ -export type TLAfterChangeHandler = ( +export type StoreAfterChangeHandler = ( prev: R, next: R, source: 'remote' | 'user' ) => void /** @public */ -export type TLBeforeDeleteHandler = ( +export type StoreBeforeDeleteHandler = ( record: R, source: 'remote' | 'user' ) => void | false /** @public */ -export type TLAfterDeleteHandler = ( +export type StoreAfterDeleteHandler = ( record: R, source: 'remote' | 'user' ) => void + /** @public */ -export type TLBatchCompleteHandler = () => void +export type StoreOperationCompleteHandler = (source: 'remote' | 'user') => void /** * The side effect manager (aka a "correct state enforcer") is responsible @@ -40,127 +45,131 @@ export type TLBatchCompleteHandler = () => void * * @public */ -export class SideEffectManager< - CTX extends { - store: TLStore - history: { onBatchComplete: () => void } - }, -> { - constructor(public editor: CTX) { - editor.store.onBeforeCreate = (record, source) => { - const handlers = this._beforeCreateHandlers[ - record.typeName - ] as TLBeforeCreateHandler[] - if (handlers) { - let r = record - for (const handler of handlers) { - r = handler(r, source) - } - return r - } +export class StoreSideEffects { + constructor(private readonly store: Store) {} - return record + private _beforeCreateHandlers: { [K in string]?: StoreBeforeCreateHandler[] } = {} + private _afterCreateHandlers: { [K in string]?: StoreAfterCreateHandler[] } = {} + private _beforeChangeHandlers: { [K in string]?: StoreBeforeChangeHandler[] } = {} + private _afterChangeHandlers: { [K in string]?: StoreAfterChangeHandler[] } = {} + private _beforeDeleteHandlers: { [K in string]?: StoreBeforeDeleteHandler[] } = {} + private _afterDeleteHandlers: { [K in string]?: StoreAfterDeleteHandler[] } = {} + private _operationCompleteHandlers: StoreOperationCompleteHandler[] = [] + + private _isEnabled = true + /** @internal */ + isEnabled() { + return this._isEnabled + } + /** @internal */ + setIsEnabled(enabled: boolean) { + this._isEnabled = enabled + } + + /** @internal */ + handleBeforeCreate(record: R, source: 'remote' | 'user') { + if (!this._isEnabled) return record + + const handlers = this._beforeCreateHandlers[record.typeName] as StoreBeforeCreateHandler[] + if (handlers) { + let r = record + for (const handler of handlers) { + r = handler(r, source) + } + return r } - editor.store.onAfterCreate = (record, source) => { - const handlers = this._afterCreateHandlers[ - record.typeName - ] as TLAfterCreateHandler[] - if (handlers) { - for (const handler of handlers) { - handler(record, source) - } + return record + } + + /** @internal */ + handleAfterCreate(record: R, source: 'remote' | 'user') { + if (!this._isEnabled) return + + const handlers = this._afterCreateHandlers[record.typeName] as StoreAfterCreateHandler[] + if (handlers) { + for (const handler of handlers) { + handler(record, source) } } - - editor.store.onBeforeChange = (prev, next, source) => { - const handlers = this._beforeChangeHandlers[ - next.typeName - ] as TLBeforeChangeHandler[] - if (handlers) { - let r = next - for (const handler of handlers) { - r = handler(prev, r, source) - } - return r - } - - return next - } - - editor.store.onAfterChange = (prev, next, source) => { - const handlers = this._afterChangeHandlers[next.typeName] as TLAfterChangeHandler[] - if (handlers) { - for (const handler of handlers) { - handler(prev, next, source) - } - } - } - - editor.store.onBeforeDelete = (record, source) => { - const handlers = this._beforeDeleteHandlers[ - record.typeName - ] as TLBeforeDeleteHandler[] - if (handlers) { - for (const handler of handlers) { - if (handler(record, source) === false) { - return false - } - } - } - } - - editor.store.onAfterDelete = (record, source) => { - const handlers = this._afterDeleteHandlers[ - record.typeName - ] as TLAfterDeleteHandler[] - if (handlers) { - for (const handler of handlers) { - handler(record, source) - } - } - } - - editor.history.onBatchComplete = () => { - this._batchCompleteHandlers.forEach((fn) => fn()) - } } - private _beforeCreateHandlers: Partial<{ - [K in TLRecord['typeName']]: TLBeforeCreateHandler[] - }> = {} - private _afterCreateHandlers: Partial<{ - [K in TLRecord['typeName']]: TLAfterCreateHandler[] - }> = {} - private _beforeChangeHandlers: Partial<{ - [K in TLRecord['typeName']]: TLBeforeChangeHandler[] - }> = {} - private _afterChangeHandlers: Partial<{ - [K in TLRecord['typeName']]: TLAfterChangeHandler[] - }> = {} + /** @internal */ + handleBeforeChange(prev: R, next: R, source: 'remote' | 'user') { + if (!this._isEnabled) return next - private _beforeDeleteHandlers: Partial<{ - [K in TLRecord['typeName']]: TLBeforeDeleteHandler[] - }> = {} + const handlers = this._beforeChangeHandlers[next.typeName] as StoreBeforeChangeHandler[] + if (handlers) { + let r = next + for (const handler of handlers) { + r = handler(prev, r, source) + } + return r + } - private _afterDeleteHandlers: Partial<{ - [K in TLRecord['typeName']]: TLAfterDeleteHandler[] - }> = {} + return next + } - private _batchCompleteHandlers: TLBatchCompleteHandler[] = [] + /** @internal */ + handleAfterChange(prev: R, next: R, source: 'remote' | 'user') { + if (!this._isEnabled) return + + const handlers = this._afterChangeHandlers[next.typeName] as StoreAfterChangeHandler[] + if (handlers) { + for (const handler of handlers) { + handler(prev, next, source) + } + } + } + + /** @internal */ + handleBeforeDelete(record: R, source: 'remote' | 'user') { + if (!this._isEnabled) return true + + const handlers = this._beforeDeleteHandlers[record.typeName] as StoreBeforeDeleteHandler[] + if (handlers) { + for (const handler of handlers) { + if (handler(record, source) === false) { + return false + } + } + } + return true + } + + /** @internal */ + handleAfterDelete(record: R, source: 'remote' | 'user') { + if (!this._isEnabled) return + + const handlers = this._afterDeleteHandlers[record.typeName] as StoreAfterDeleteHandler[] + if (handlers) { + for (const handler of handlers) { + handler(record, source) + } + } + } + + /** @internal */ + handleOperationComplete(source: 'remote' | 'user') { + if (!this._isEnabled) return + + for (const handler of this._operationCompleteHandlers) { + handler(source) + } + } /** * Internal helper for registering a bunch of side effects at once and keeping them organized. * @internal */ register(handlersByType: { - [R in TLRecord as R['typeName']]?: { - beforeCreate?: TLBeforeCreateHandler - afterCreate?: TLAfterCreateHandler - beforeChange?: TLBeforeChangeHandler - afterChange?: TLAfterChangeHandler - beforeDelete?: TLBeforeDeleteHandler - afterDelete?: TLAfterDeleteHandler + [T in R as T['typeName']]?: { + beforeCreate?: StoreBeforeCreateHandler + afterCreate?: StoreAfterCreateHandler + beforeChange?: StoreBeforeChangeHandler + afterChange?: StoreAfterChangeHandler + beforeDelete?: StoreBeforeDeleteHandler + afterDelete?: StoreAfterDeleteHandler } }) { const disposes: (() => void)[] = [] @@ -195,7 +204,7 @@ export class SideEffectManager< * * Use this handle only to modify the creation of the record itself. If you want to trigger a * side-effect on a different record (for example, moving one shape when another is created), - * use {@link SideEffectManager.registerAfterCreateHandler} instead. + * use {@link StoreSideEffects.registerAfterCreateHandler} instead. * * @example * ```ts @@ -216,11 +225,11 @@ export class SideEffectManager< * @param typeName - The type of record to listen for * @param handler - The handler to call */ - registerBeforeCreateHandler( + registerBeforeCreateHandler( typeName: T, - handler: TLBeforeCreateHandler + handler: StoreBeforeCreateHandler ) { - const handlers = this._beforeCreateHandlers[typeName] as TLBeforeCreateHandler[] + const handlers = this._beforeCreateHandlers[typeName] as StoreBeforeCreateHandler[] if (!handlers) this._beforeCreateHandlers[typeName] = [] this._beforeCreateHandlers[typeName]!.push(handler) return () => remove(this._beforeCreateHandlers[typeName]!, handler) @@ -229,7 +238,7 @@ export class SideEffectManager< /** * Register a handler to be called after a record is created. This is useful for side-effects * that would update _other_ records. If you want to modify the record being created use - * {@link SideEffectManager.registerBeforeCreateHandler} instead. + * {@link StoreSideEffects.registerBeforeCreateHandler} instead. * * @example * ```ts @@ -246,11 +255,11 @@ export class SideEffectManager< * @param typeName - The type of record to listen for * @param handler - The handler to call */ - registerAfterCreateHandler( + registerAfterCreateHandler( typeName: T, - handler: TLAfterCreateHandler + handler: StoreAfterCreateHandler ) { - const handlers = this._afterCreateHandlers[typeName] as TLAfterCreateHandler[] + const handlers = this._afterCreateHandlers[typeName] as StoreAfterCreateHandler[] if (!handlers) this._afterCreateHandlers[typeName] = [] this._afterCreateHandlers[typeName]!.push(handler) return () => remove(this._afterCreateHandlers[typeName]!, handler) @@ -263,7 +272,7 @@ export class SideEffectManager< * * Use this handler only for intercepting updates to the record itself. If you want to update * other records in response to a change, use - * {@link SideEffectManager.registerAfterChangeHandler} instead. + * {@link StoreSideEffects.registerAfterChangeHandler} instead. * * @example * ```ts @@ -280,11 +289,11 @@ export class SideEffectManager< * @param typeName - The type of record to listen for * @param handler - The handler to call */ - registerBeforeChangeHandler( + registerBeforeChangeHandler( typeName: T, - handler: TLBeforeChangeHandler + handler: StoreBeforeChangeHandler ) { - const handlers = this._beforeChangeHandlers[typeName] as TLBeforeChangeHandler[] + const handlers = this._beforeChangeHandlers[typeName] as StoreBeforeChangeHandler[] if (!handlers) this._beforeChangeHandlers[typeName] = [] this._beforeChangeHandlers[typeName]!.push(handler) return () => remove(this._beforeChangeHandlers[typeName]!, handler) @@ -293,7 +302,7 @@ export class SideEffectManager< /** * Register a handler to be called after a record is changed. This is useful for side-effects * that would update _other_ records - if you want to modify the record being changed, use - * {@link SideEffectManager.registerBeforeChangeHandler} instead. + * {@link StoreSideEffects.registerBeforeChangeHandler} instead. * * @example * ```ts @@ -309,13 +318,13 @@ export class SideEffectManager< * @param typeName - The type of record to listen for * @param handler - The handler to call */ - registerAfterChangeHandler( + registerAfterChangeHandler( typeName: T, - handler: TLAfterChangeHandler + handler: StoreAfterChangeHandler ) { - const handlers = this._afterChangeHandlers[typeName] as TLAfterChangeHandler[] + const handlers = this._afterChangeHandlers[typeName] as StoreAfterChangeHandler[] if (!handlers) this._afterChangeHandlers[typeName] = [] - this._afterChangeHandlers[typeName]!.push(handler as TLAfterChangeHandler) + this._afterChangeHandlers[typeName]!.push(handler as StoreAfterChangeHandler) return () => remove(this._afterChangeHandlers[typeName]!, handler) } @@ -325,7 +334,7 @@ export class SideEffectManager< * * Use this handler only for intercepting deletions of the record itself. If you want to do * something to other records in response to a deletion, use - * {@link SideEffectManager.registerAfterDeleteHandler} instead. + * {@link StoreSideEffects.registerAfterDeleteHandler} instead. * * @example * ```ts @@ -340,20 +349,20 @@ export class SideEffectManager< * @param typeName - The type of record to listen for * @param handler - The handler to call */ - registerBeforeDeleteHandler( + registerBeforeDeleteHandler( typeName: T, - handler: TLBeforeDeleteHandler + handler: StoreBeforeDeleteHandler ) { - const handlers = this._beforeDeleteHandlers[typeName] as TLBeforeDeleteHandler[] + const handlers = this._beforeDeleteHandlers[typeName] as StoreBeforeDeleteHandler[] if (!handlers) this._beforeDeleteHandlers[typeName] = [] - this._beforeDeleteHandlers[typeName]!.push(handler as TLBeforeDeleteHandler) + this._beforeDeleteHandlers[typeName]!.push(handler as StoreBeforeDeleteHandler) return () => remove(this._beforeDeleteHandlers[typeName]!, handler) } /** * Register a handler to be called after a record is deleted. This is useful for side-effects * that would update _other_ records - if you want to block the deletion of the record itself, - * use {@link SideEffectManager.registerBeforeDeleteHandler} instead. + * use {@link StoreSideEffects.registerBeforeDeleteHandler} instead. * * @example * ```ts @@ -372,29 +381,29 @@ export class SideEffectManager< * @param typeName - The type of record to listen for * @param handler - The handler to call */ - registerAfterDeleteHandler( + registerAfterDeleteHandler( typeName: T, - handler: TLAfterDeleteHandler + handler: StoreAfterDeleteHandler ) { - const handlers = this._afterDeleteHandlers[typeName] as TLAfterDeleteHandler[] + const handlers = this._afterDeleteHandlers[typeName] as StoreAfterDeleteHandler[] if (!handlers) this._afterDeleteHandlers[typeName] = [] - this._afterDeleteHandlers[typeName]!.push(handler as TLAfterDeleteHandler) + this._afterDeleteHandlers[typeName]!.push(handler as StoreAfterDeleteHandler) return () => remove(this._afterDeleteHandlers[typeName]!, handler) } /** - * Register a handler to be called when a store completes a batch. + * Register a handler to be called when a store completes an atomic operation. * * @example * ```ts * let count = 0 * - * editor.cleanup.registerBatchCompleteHandler(() => count++) + * editor.sideEffects.registerOperationCompleteHandler(() => count++) * * editor.selectAll() * expect(count).toBe(1) * - * editor.batch(() => { + * editor.store.atomic(() => { * editor.selectNone() * editor.selectAll() * }) @@ -406,9 +415,9 @@ export class SideEffectManager< * * @public */ - registerBatchCompleteHandler(handler: TLBatchCompleteHandler) { - this._batchCompleteHandlers.push(handler) - return () => remove(this._batchCompleteHandlers, handler) + registerOperationCompleteHandler(handler: StoreOperationCompleteHandler) { + this._operationCompleteHandlers.push(handler) + return () => remove(this._operationCompleteHandlers, handler) } } diff --git a/packages/store/src/lib/test/recordStore.test.ts b/packages/store/src/lib/test/recordStore.test.ts index 6366e840d..f08166b04 100644 --- a/packages/store/src/lib/test/recordStore.test.ts +++ b/packages/store/src/lib/test/recordStore.test.ts @@ -206,40 +206,41 @@ describe('Store', () => { it('allows adding onAfterChange callbacks that see the final state of the world', () => { /* ADDING */ - store.onAfterCreate = jest.fn((current) => { + const onAfterCreate = jest.fn((current) => { expect(current).toEqual( Author.create({ name: 'J.R.R Tolkein', id: Author.createId('tolkein') }) ) expect([...store.query.ids('author').get()]).toEqual([Author.createId('tolkein')]) }) + store.sideEffects.registerAfterCreateHandler('author', onAfterCreate) store.put([Author.create({ name: 'J.R.R Tolkein', id: Author.createId('tolkein') })]) - expect(store.onAfterCreate).toHaveBeenCalledTimes(1) + expect(onAfterCreate).toHaveBeenCalledTimes(1) /* UPDATING */ - store.onAfterChange = jest.fn((prev, current) => { - if (prev.typeName === 'author' && current.typeName === 'author') { - expect(prev.name).toBe('J.R.R Tolkein') - expect(current.name).toBe('Butch Cassidy') + const onAfterChange = jest.fn((prev, current) => { + expect(prev.name).toBe('J.R.R Tolkein') + expect(current.name).toBe('Butch Cassidy') - expect(store.get(Author.createId('tolkein'))!.name).toBe('Butch Cassidy') - } + expect(store.get(Author.createId('tolkein'))!.name).toBe('Butch Cassidy') }) + store.sideEffects.registerAfterChangeHandler('author', onAfterChange) store.update(Author.createId('tolkein'), (r) => ({ ...r, name: 'Butch Cassidy' })) - expect(store.onAfterChange).toHaveBeenCalledTimes(1) + expect(onAfterChange).toHaveBeenCalledTimes(1) /* REMOVING */ - store.onAfterDelete = jest.fn((prev) => { + const onAfterDelete = jest.fn((prev) => { if (prev.typeName === 'author') { expect(prev.name).toBe('Butch Cassidy') } }) + store.sideEffects.registerAfterDeleteHandler('author', onAfterDelete) store.remove([Author.createId('tolkein')]) - expect(store.onAfterDelete).toHaveBeenCalledTimes(1) + expect(onAfterDelete).toHaveBeenCalledTimes(1) }) it('allows finding and filtering records with a predicate', () => { @@ -1076,76 +1077,137 @@ describe('diffs', () => { }) describe('after callbacks', () => { - let store: Store + let store: Store let callbacks: any[] = [] - const authorId = Author.createId('tolkein') - const bookId = Book.createId('hobbit') + const book1Id = Book.createId('darkness') + const book1 = Book.create({ + title: 'the left hand of darkness', + id: book1Id, + author: Author.createId('ursula'), + numPages: 1, + }) + const book2Id = Book.createId('dispossessed') + const book2 = Book.create({ + title: 'the dispossessed', + id: book2Id, + author: Author.createId('ursula'), + numPages: 1, + }) + + let onAfterCreate: jest.Mock + let onAfterChange: jest.Mock + let onAfterDelete: jest.Mock + let onOperationComplete: jest.Mock beforeEach(() => { store = new Store({ props: {}, - schema: StoreSchema.create({ + schema: StoreSchema.create({ book: Book, - author: Author, - visit: Visit, }), }) - store.onAfterCreate = jest.fn((record) => callbacks.push({ type: 'create', record })) - store.onAfterChange = jest.fn((from, to) => callbacks.push({ type: 'change', from, to })) - store.onAfterDelete = jest.fn((record) => callbacks.push({ type: 'delete', record })) + onAfterCreate = jest.fn((record) => callbacks.push({ type: 'create', record })) + onAfterChange = jest.fn((from, to) => callbacks.push({ type: 'change', from, to })) + onAfterDelete = jest.fn((record) => callbacks.push({ type: 'delete', record })) + onOperationComplete = jest.fn(() => callbacks.push({ type: 'complete' })) callbacks = [] + + store.sideEffects.registerAfterCreateHandler('book', onAfterCreate) + store.sideEffects.registerAfterChangeHandler('book', onAfterChange) + store.sideEffects.registerAfterDeleteHandler('book', onAfterDelete) + store.sideEffects.registerOperationCompleteHandler(onOperationComplete) }) it('fires callbacks at the end of an `atomic` op', () => { store.atomic(() => { expect(callbacks).toHaveLength(0) - store.put([ - Author.create({ name: 'J.R.R Tolkein', id: authorId }), - Book.create({ title: 'The Hobbit', id: bookId, author: authorId, numPages: 300 }), - ]) + store.put([book1, book2]) expect(callbacks).toHaveLength(0) }) expect(callbacks).toMatchObject([ - { type: 'create', record: { id: authorId } }, - { type: 'create', record: { id: bookId } }, + { type: 'create', record: { id: book1Id } }, + { type: 'create', record: { id: book2Id } }, + { type: 'complete' }, ]) }) it('doesnt fire callback for a record created then deleted', () => { store.atomic(() => { - store.put([Author.create({ name: 'J.R.R Tolkein', id: authorId })]) - store.remove([authorId]) + store.put([book1]) + store.remove([book1Id]) }) - expect(callbacks).toHaveLength(0) + expect(callbacks).toMatchObject([{ type: 'complete' }]) }) it('bails out if too many callbacks are fired', () => { let limit = 10 - store.onAfterCreate = (record) => { - if (record.typeName === 'book' && record.numPages < limit) { + onAfterCreate.mockImplementation((record) => { + if (record.numPages < limit) { store.put([{ ...record, numPages: record.numPages + 1 }]) } - } - store.onAfterChange = (from, to) => { - if (to.typeName === 'book' && to.numPages < limit) { + }) + onAfterChange.mockImplementation((from, to) => { + if (to.numPages < limit) { store.put([{ ...to, numPages: to.numPages + 1 }]) } - } + }) // this should be fine: - store.put([Book.create({ title: 'The Hobbit', id: bookId, author: authorId, numPages: 0 })]) - expect(store.get(bookId)!.numPages).toBe(limit) + store.put([book1]) + expect(store.get(book1Id)!.numPages).toBe(limit) // if we increase the limit thought, it should crash: limit = 10000 store.clear() expect(() => { - store.put([Book.create({ title: 'The Hobbit', id: bookId, author: authorId, numPages: 0 })]) + store.put([book2]) }).toThrowErrorMatchingInlineSnapshot(`"Maximum store update depth exceeded, bailing out"`) }) + + it('keeps firing operation complete callbacks until all are cleared', () => { + // steps: + // 0, 1, 2: after change increment pages + // 3: after change, do nothing + // 4: operation complete, increment pages by 1000 + // 5, 6: after change increment pages + // 7: after change, do nothing + // 8: operation complete, do nothing + // 9: done! + let step = 0 + + store.put([book1]) + + onAfterChange.mockImplementation((prev, next) => { + if ([0, 1, 2, 5, 6].includes(step)) { + step++ + store.put([{ ...next, numPages: next.numPages + 1 }]) + } else if ([3, 7].includes(step)) { + step++ + } else { + throw new Error(`Wrong step: ${step}`) + } + }) + + onOperationComplete.mockImplementation(() => { + if (step === 4) { + step++ + const book = store.get(book1Id)! + store.put([{ ...book, numPages: book.numPages + 1000 }]) + } else if (step === 8) { + step++ + } else { + throw new Error(`Wrong step: ${step}`) + } + }) + + store.put([{ ...book1, numPages: 2 }]) + + expect(store.get(book1Id)!.numPages).toBe(1007) + expect(step).toBe(9) + }) }) diff --git a/packages/store/src/lib/test/recordStoreFuzzing.test.ts b/packages/store/src/lib/test/recordStoreFuzzing.test.ts index 55b40bb5f..6fb3f40c9 100644 --- a/packages/store/src/lib/test/recordStoreFuzzing.test.ts +++ b/packages/store/src/lib/test/recordStoreFuzzing.test.ts @@ -333,12 +333,10 @@ function runTest(seed: number) { author: Author, }), }) - store.onBeforeDelete = (record) => { - if (record.typeName === 'author') { - const books = store.query.index('book', 'authorId').get().get(record.id) - if (books) store.remove([...books]) - } - } + store.sideEffects.registerBeforeDeleteHandler('author', (record) => { + const books = store.query.index('book', 'authorId').get().get(record.id) + if (books) store.remove([...books]) + }) const getRandomNumber = rng(seed) const authorNameIndex = store.query.index('author', 'name') const authorIdIndex = store.query.index('book', 'authorId')