Move arrow helpers from editor to tldraw (#3721)

With the new work on bindings, we no longer need to keep any arrows
stuff hard-coded in `editor`, so let's move it to `tldraw` with the rest
of the shapes.

Couple other changes as part of this:
- We had two different types of `WeakMap` backed cache, but we now only
have one
- There's a new free-standing version of `createComputedCache` that
doesn't need access to the editor/store in order to create the cache.
instead, it returns a `{get(editor, id)}` object and instantiates the
cache on a per-editor basis for each call.
- Fixed a bug in `createSelectedComputedCache` where the selector
derivation would get re-created on every call to `get`

### Change Type

- [x] `sdk` — Changes the tldraw SDK
- [x] `improvement` — Improving existing features

### Release Notes

#### Breaking changes
- `editor.getArrowInfo(shape)` has been replaced with
`getArrowInfo(editor, shape)`
- `editor.getArrowsBoundTo(shape)` has been removed. Instead, use
`editor.getBindingsToShape(shape, 'arrow')` and follow the `fromId` of
each binding to the corresponding arrow shape
- These types have moved from `@tldraw/editor` to `tldraw`:
    - `TLArcInfo`
    - `TLArrowInfo`
    - `TLArrowPoint`
- `WeakMapCache` has been removed
This commit is contained in:
alex 2024-05-09 10:48:01 +01:00 committed by GitHub
parent 7b99c8532b
commit 91903c9761
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 251 additions and 291 deletions

View file

@ -40,10 +40,6 @@ import { StoreSchema } from '@tldraw/store';
import { StoreSnapshot } from '@tldraw/store';
import { StyleProp } from '@tldraw/tlschema';
import { StylePropValue } from '@tldraw/tlschema';
import { TLArrowBinding } from '@tldraw/tlschema';
import { TLArrowBindingProps } from '@tldraw/tlschema';
import { TLArrowShape } from '@tldraw/tlschema';
import { TLArrowShapeArrowheadStyle } from '@tldraw/tlschema';
import { TLAsset } from '@tldraw/tlschema';
import { TLAssetId } from '@tldraw/tlschema';
import { TLAssetPartial } from '@tldraw/tlschema';
@ -451,9 +447,6 @@ export const coreShapes: readonly [typeof GroupShapeUtil];
// @public
export function counterClockwiseAngleDist(a0: number, a1: number): number;
// @internal
export function createOrUpdateArrowBinding(editor: Editor, arrow: TLArrowShape | TLShapeId, target: TLShape | TLShapeId, props: TLArrowBindingProps): void;
// @public
export function createSessionStateSnapshotSignal(store: TLStore): Signal<null | TLSessionStateSnapshot>;
@ -779,8 +772,6 @@ export class Editor extends EventEmitter<TLEventMap> {
findShapeAncestor(shape: TLShape | TLShapeId, predicate: (parent: TLShape) => boolean): TLShape | undefined;
flipShapes(shapes: TLShape[] | TLShapeId[], operation: 'horizontal' | 'vertical'): this;
getAncestorPageId(shape?: TLShape | TLShapeId): TLPageId | undefined;
getArrowInfo(shape: TLArrowShape | TLShapeId): TLArrowInfo | undefined;
getArrowsBoundTo(shapeId: TLShapeId): TLArrowShape[];
getAsset(asset: TLAsset | TLAssetId): TLAsset | undefined;
getAssetForExternalContent(info: TLExternalAssetContent): Promise<TLAsset | undefined>;
getAssets(): (TLBookmarkAsset | TLImageAsset | TLVideoAsset)[];
@ -1194,15 +1185,6 @@ export abstract class Geometry2d {
// @public
export function getArcMeasure(A: number, B: number, sweepFlag: number, largeArcFlag: number): number;
// @public (undocumented)
export function getArrowBindings(editor: Editor, shape: TLArrowShape): TLArrowBindings;
// @public (undocumented)
export function getArrowTerminalsInArrowSpace(editor: Editor, shape: TLArrowShape, bindings: TLArrowBindings): {
end: Vec;
start: Vec;
};
// @public (undocumented)
export function getCursor(cursor: TLCursorType, rotation?: number, color?: string): string;
@ -1712,9 +1694,6 @@ export function refreshPage(): void;
// @public (undocumented)
export function releasePointerCapture(element: Element, event: PointerEvent | React_2.PointerEvent<Element>): void;
// @internal
export function removeArrowBinding(editor: Editor, arrow: TLArrowShape, terminal: 'end' | 'start'): void;
// @public (undocumented)
export type RequiredKeys<T, K extends keyof T> = Partial<Omit<T, K>> & Pick<T, K>;
@ -2094,57 +2073,6 @@ export type TLAnyBindingUtilConstructor = TLBindingUtilConstructor<any>;
// @public (undocumented)
export type TLAnyShapeUtilConstructor = TLShapeUtilConstructor<any>;
// @public (undocumented)
export interface TLArcInfo {
// (undocumented)
center: VecLike;
// (undocumented)
largeArcFlag: number;
// (undocumented)
length: number;
// (undocumented)
radius: number;
// (undocumented)
size: number;
// (undocumented)
sweepFlag: number;
}
// @public (undocumented)
export interface TLArrowBindings {
// (undocumented)
end: TLArrowBinding | undefined;
// (undocumented)
start: TLArrowBinding | undefined;
}
// @public (undocumented)
export type TLArrowInfo = {
bindings: TLArrowBindings;
bodyArc: TLArcInfo;
end: TLArrowPoint;
handleArc: TLArcInfo;
isStraight: false;
isValid: boolean;
middle: VecLike;
start: TLArrowPoint;
} | {
bindings: TLArrowBindings;
end: TLArrowPoint;
isStraight: true;
isValid: boolean;
length: number;
middle: VecLike;
start: TLArrowPoint;
};
// @public (undocumented)
export type TLArrowPoint = {
arrowhead: TLArrowShapeArrowheadStyle;
handle: VecLike;
point: VecLike;
};
// @public (undocumented)
export type TLBaseBoxShape = TLBaseShape<string, {
h: number;
@ -3214,24 +3142,6 @@ export class Vec {
// @public (undocumented)
export type VecLike = Vec | VecModel;
// @public (undocumented)
export class WeakMapCache<T extends object, K> {
// (undocumented)
access(item: T): K | undefined;
// (undocumented)
bust(): void;
// (undocumented)
get<P extends T>(item: P, cb: (item: P) => K): NonNullable<K>;
// (undocumented)
has(item: T): boolean;
// (undocumented)
invalidate(item: T): void;
// (undocumented)
items: WeakMap<T, K>;
// (undocumented)
set(item: T, value: K): void;
}
export { whyAmIRunning }

View file

@ -183,18 +183,6 @@ export {
type TLShapeUtilFlag,
} from './lib/editor/shapes/ShapeUtil'
export { GroupShapeUtil } from './lib/editor/shapes/group/GroupShapeUtil'
export {
type TLArcInfo,
type TLArrowInfo,
type TLArrowPoint,
} from './lib/editor/shapes/shared/arrow/arrow-types'
export {
createOrUpdateArrowBinding,
getArrowBindings,
getArrowTerminalsInArrowSpace,
removeArrowBinding,
type TLArrowBindings,
} from './lib/editor/shapes/shared/arrow/shared'
export { resizeBox, type ResizeBoxOptions } from './lib/editor/shapes/shared/resizeBox'
export { BaseBoxShapeTool } from './lib/editor/tools/BaseBoxShapeTool/BaseBoxShapeTool'
export { StateNode, type TLStateNodeConstructor } from './lib/editor/tools/StateNode'
@ -343,7 +331,6 @@ export {
SharedStyleMap,
type SharedStyle,
} from './lib/utils/SharedStylesMap'
export { WeakMapCache } from './lib/utils/WeakMapCache'
export { dataUrlToFile } from './lib/utils/assets'
export { debugFlags, featureFlags, type DebugFlag } from './lib/utils/debug-flags'
export {

View file

@ -12,7 +12,6 @@ import {
PageRecordType,
StyleProp,
StylePropValue,
TLArrowBinding,
TLArrowShape,
TLAsset,
TLAssetId,
@ -112,7 +111,6 @@ import { Group2d } from '../primitives/geometry/Group2d'
import { intersectPolygonPolygon } from '../primitives/intersect'
import { PI2, approximately, areAnglesCompatible, clamp, pointInPolygon } from '../primitives/utils'
import { ReadonlySharedStyleMap, SharedStyle, SharedStyleMap } from '../utils/SharedStylesMap'
import { WeakMapCache } from '../utils/WeakMapCache'
import { dataUrlToFile } from '../utils/assets'
import { debugFlags } from '../utils/debug-flags'
import { getIncrementedName } from '../utils/getIncrementedName'
@ -135,10 +133,6 @@ import { TextManager } from './managers/TextManager'
import { TickManager } from './managers/TickManager'
import { UserPreferencesManager } from './managers/UserPreferencesManager'
import { ShapeUtil, TLResizeMode, TLShapeUtilConstructor } from './shapes/ShapeUtil'
import { TLArrowInfo } from './shapes/shared/arrow/arrow-types'
import { getCurvedArrowInfo } from './shapes/shared/arrow/curved-arrow'
import { getArrowBindings, getIsArrowStraight } from './shapes/shared/arrow/shared'
import { getStraightArrowInfo } from './shapes/shared/arrow/straight-arrow'
import { RootState } from './tools/RootState'
import { StateNode, TLStateNodeConstructor } from './tools/StateNode'
import { TLContent } from './types/clipboard-types'
@ -928,50 +922,6 @@ export class Editor extends EventEmitter<TLEventMap> {
return this
}
/* --------------------- Arrows --------------------- */
// todo: move these to tldraw or replace with a bindings API
/**
* Get all arrows bound to a shape.
*
* @param shapeId - The id of the shape.
*
* @public
*/
getArrowsBoundTo(shapeId: TLShapeId) {
const ids = new Set(
this.getBindingsToShape<TLArrowBinding>(shapeId, 'arrow').map((b) => b.fromId)
)
return compact(Array.from(ids, (id) => this.getShape<TLArrowShape>(id)))
}
@computed
private getArrowInfoCache() {
return this.store.createComputedCache<TLArrowInfo, TLArrowShape>('arrow infoCache', (shape) => {
const bindings = getArrowBindings(this, shape)
return getIsArrowStraight(shape)
? getStraightArrowInfo(this, shape, bindings)
: getCurvedArrowInfo(this, shape, bindings)
})
}
/**
* Get cached info about an arrow.
*
* @example
* ```ts
* const arrowInfo = editor.getArrowInfo(myArrow)
* ```
*
* @param shape - The shape (or shape id) of the arrow to get the info for.
*
* @public
*/
getArrowInfo(shape: TLArrowShape | TLShapeId): TLArrowInfo | undefined {
const id = typeof shape === 'string' ? shape : shape.id
return this.getArrowInfoCache().get(id)
}
/* --------------------- Errors --------------------- */
/** @internal */
@ -4884,13 +4834,6 @@ export class Editor extends EventEmitter<TLEventMap> {
return getIndexAbove(shape.index)
}
/**
* A cache of children for each parent.
*
* @internal
*/
private _childIdsCache = new WeakMapCache<any[], TLShapeId[]>()
/**
* Get an array of all the children of a shape.
*
@ -4907,7 +4850,7 @@ export class Editor extends EventEmitter<TLEventMap> {
const parentId = typeof parent === 'string' ? parent : parent.id
const ids = this._parentIdsToChildIds.get()[parentId]
if (!ids) return EMPTY_ARRAY
return this._childIdsCache.get(ids, () => ids)
return ids
}
/**

View file

@ -1,31 +0,0 @@
/** @public */
export class WeakMapCache<T extends object, K> {
items = new WeakMap<T, K>()
get<P extends T>(item: P, cb: (item: P) => K) {
if (!this.items.has(item)) {
this.items.set(item, cb(item))
}
return this.items.get(item)!
}
access(item: T) {
return this.items.get(item)
}
set(item: T, value: K) {
this.items.set(item, value)
}
has(item: T) {
return this.items.has(item)
}
invalidate(item: T) {
this.items.delete(item)
}
bust() {
this.items = new WeakMap()
}
}

View file

@ -33,6 +33,11 @@ export type ComputedCache<Data, R extends UnknownRecord> = {
get(id: IdOf<R>): Data | undefined;
};
// @public
export function createComputedCache<Context extends StoreContext<any>, Result, Record extends ContextRecordType<Context> = ContextRecordType<Context>>(name: string, derive: (context: Context, record: Record) => Result | undefined, isEqual?: (a: Record, b: Record) => boolean): {
get(context: Context, id: IdOf<Record>): Result | undefined;
};
// @internal (undocumented)
export function createEmptyRecordsDiff<R extends UnknownRecord>(): RecordsDiff<R>;
@ -289,8 +294,8 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
// @internal (undocumented)
atomic<T>(fn: () => T, runCallbacks?: boolean): T;
clear: () => void;
createComputedCache: <T, V extends R = R>(name: string, derive: (record: V) => T | undefined, isEqual?: ((a: V, b: V) => boolean) | undefined) => ComputedCache<T, V>;
createSelectedComputedCache: <T, J, V extends R = R>(name: string, selector: (record: V) => T | undefined, derive: (input: T) => J | undefined) => ComputedCache<J, V>;
createComputedCache: <Result, Record extends R = R>(name: string, derive: (record: Record) => Result | undefined, isEqual?: ((a: Record, b: Record) => boolean) | undefined) => ComputedCache<Result, Record>;
createSelectedComputedCache: <Selection, Result, Record extends R = R>(name: string, selector: (record: Record) => Selection | undefined, derive: (input: Selection) => Result | undefined) => ComputedCache<Result, Record>;
// @internal (undocumented)
ensureStoreIsUsable(): void;
extractingChanges(fn: () => void): RecordsDiff<R>;

View file

@ -9,7 +9,7 @@ export {
squashRecordDiffsMutable,
type RecordsDiff,
} from './lib/RecordsDiff'
export { Store } from './lib/Store'
export { Store, createComputedCache } from './lib/Store'
export type {
CollectionDiff,
ComputedCache,

View file

@ -1,5 +1,6 @@
import { Atom, Computed, Reactor, atom, computed, reactor, transact } from '@tldraw/state'
import {
WeakCache,
assert,
filterEntries,
getOwnProperty,
@ -11,7 +12,6 @@ import {
} from '@tldraw/utils'
import { nanoid } from 'nanoid'
import { IdOf, RecordId, UnknownRecord } from './BaseRecord'
import { Cache } from './Cache'
import { RecordScope } from './RecordType'
import { RecordsDiff, squashRecordDiffs } from './RecordsDiff'
import { StoreQueries } from './StoreQueries'
@ -765,14 +765,14 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
* @param derive - A function used to derive the value of the cache.
* @public
*/
createComputedCache = <T, V extends R = R>(
createComputedCache = <Result, Record extends R = R>(
name: string,
derive: (record: V) => T | undefined,
isEqual?: (a: V, b: V) => boolean
): ComputedCache<T, V> => {
const cache = new Cache<Atom<any>, Computed<T | undefined>>()
derive: (record: Record) => Result | undefined,
isEqual?: (a: Record, b: Record) => boolean
): ComputedCache<Result, Record> => {
const cache = new WeakCache<Atom<any>, Computed<Result | undefined>>()
return {
get: (id: IdOf<V>) => {
get: (id: IdOf<Record>) => {
const atom = this.atoms.get()[id]
if (!atom) {
return undefined
@ -782,8 +782,8 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
const recordSignal = isEqual
? computed(atom.name + ':equals', () => atom.get(), { isEqual })
: atom
return computed<T | undefined>(name + ':' + id, () => {
return derive(recordSignal.get() as V)
return computed<Result | undefined>(name + ':' + id, () => {
return derive(recordSignal.get() as Record)
})
})
.get()
@ -799,24 +799,26 @@ export class Store<R extends UnknownRecord = UnknownRecord, Props = unknown> {
* @param derive - A function used to derive the value of the cache.
* @public
*/
createSelectedComputedCache = <T, J, V extends R = R>(
createSelectedComputedCache = <Selection, Result, Record extends R = R>(
name: string,
selector: (record: V) => T | undefined,
derive: (input: T) => J | undefined
): ComputedCache<J, V> => {
const cache = new Cache<Atom<any>, Computed<J | undefined>>()
selector: (record: Record) => Selection | undefined,
derive: (input: Selection) => Result | undefined
): ComputedCache<Result, Record> => {
const cache = new WeakCache<Atom<any>, Computed<Result | undefined>>()
return {
get: (id: IdOf<V>) => {
get: (id: IdOf<Record>) => {
const atom = this.atoms.get()[id]
if (!atom) {
return undefined
}
const d = computed<T | undefined>(name + ':' + id + ':selector', () =>
selector(atom.get() as V)
)
return cache
.get(atom, () => computed<J | undefined>(name + ':' + id, () => derive(d.get() as T)))
.get(atom, () => {
const d = computed<Selection | undefined>(name + ':' + id + ':selector', () =>
selector(atom.get() as Record)
)
return computed<Result | undefined>(name + ':' + id, () => derive(d.get() as Selection))
})
.get()
},
}
@ -989,3 +991,42 @@ class HistoryAccumulator<T extends UnknownRecord> {
return this._history.length > 0
}
}
type StoreContext<R extends UnknownRecord> = Store<R> | { store: Store<R> }
type ContextRecordType<Context extends StoreContext<any>> =
Context extends Store<infer R> ? R : Context extends { store: Store<infer R> } ? R : never
/**
* Free version of {@link Store.createComputedCache}.
*
* @example
* ```ts
* const myCache = createComputedCache('myCache', (editor: Editor, shape: TLShape) => {
* return editor.getSomethingExpensive(shape)
* })
*
* myCache.get(editor, shape.id)
* ```
*
* @public
*/
export function createComputedCache<
Context extends StoreContext<any>,
Result,
Record extends ContextRecordType<Context> = ContextRecordType<Context>,
>(
name: string,
derive: (context: Context, record: Record) => Result | undefined,
isEqual?: (a: Record, b: Record) => boolean
) {
const cache = new WeakCache<Context, ComputedCache<Result, Record>>()
return {
get(context: Context, id: IdOf<Record>) {
const computedCache = cache.get(context, () => {
const store = (context instanceof Store ? context : context.store) as Store<Record>
return store.createComputedCache(name, (record) => derive(context, record), isEqual)
})
return computedCache.get(id)
},
}
}

View file

@ -55,7 +55,9 @@ import { SvgExportContext } from '@tldraw/editor';
import { T } from '@tldraw/editor';
import { TLAnyBindingUtilConstructor } from '@tldraw/editor';
import { TLAnyShapeUtilConstructor } from '@tldraw/editor';
import { TLArrowBinding } from '@tldraw/editor';
import { TLArrowShape } from '@tldraw/editor';
import { TLArrowShapeArrowheadStyle } from '@tldraw/editor';
import { TLAssetId } from '@tldraw/editor';
import { TLBaseEventInfo } from '@tldraw/editor';
import { TLBookmarkShape } from '@tldraw/editor';
@ -815,6 +817,15 @@ export function GeoStylePickerSet({ styles }: {
styles: ReadonlySharedStyleMap;
}): JSX_2.Element | null;
// @public (undocumented)
export function getArrowBindings(editor: Editor, shape: TLArrowShape): TLArrowBindings;
// @public (undocumented)
export function getArrowTerminalsInArrowSpace(editor: Editor, shape: TLArrowShape, bindings: TLArrowBindings): {
end: Vec;
start: Vec;
};
// @public
export function getEmbedInfo(inputUrl: string): TLEmbedResult;
@ -1452,6 +1463,57 @@ export function TextStylePickerSet({ theme, styles, }: {
// @public (undocumented)
export function TextToolbarItem(): JSX_2.Element;
// @public (undocumented)
export interface TLArcInfo {
// (undocumented)
center: VecLike;
// (undocumented)
largeArcFlag: number;
// (undocumented)
length: number;
// (undocumented)
radius: number;
// (undocumented)
size: number;
// (undocumented)
sweepFlag: number;
}
// @public (undocumented)
export interface TLArrowBindings {
// (undocumented)
end: TLArrowBinding | undefined;
// (undocumented)
start: TLArrowBinding | undefined;
}
// @public (undocumented)
export type TLArrowInfo = {
bindings: TLArrowBindings;
bodyArc: TLArcInfo;
end: TLArrowPoint;
handleArc: TLArcInfo;
isStraight: false;
isValid: boolean;
middle: VecLike;
start: TLArrowPoint;
} | {
bindings: TLArrowBindings;
end: TLArrowPoint;
isStraight: true;
isValid: boolean;
length: number;
middle: VecLike;
start: TLArrowPoint;
};
// @public (undocumented)
export type TLArrowPoint = {
arrowhead: TLArrowShapeArrowheadStyle;
handle: VecLike;
point: VecLike;
};
// @public (undocumented)
export type TLComponents = Expand<TLEditorComponents & TLUiComponents>;

View file

@ -18,6 +18,12 @@ export { defaultShapeUtils } from './lib/defaultShapeUtils'
export { defaultTools } from './lib/defaultTools'
export { ArrowShapeTool } from './lib/shapes/arrow/ArrowShapeTool'
export { ArrowShapeUtil } from './lib/shapes/arrow/ArrowShapeUtil'
export { type TLArcInfo, type TLArrowInfo, type TLArrowPoint } from './lib/shapes/arrow/arrow-types'
export {
getArrowBindings,
getArrowTerminalsInArrowSpace,
type TLArrowBindings,
} from './lib/shapes/arrow/shared'
export { BookmarkShapeUtil } from './lib/shapes/bookmark/BookmarkShapeUtil'
export { DrawShapeTool } from './lib/shapes/draw/DrawShapeTool'
export { DrawShapeUtil } from './lib/shapes/draw/DrawShapeUtil'

View file

@ -17,12 +17,11 @@ import {
arrowBindingMigrations,
arrowBindingProps,
assert,
getArrowBindings,
getIndexAbove,
getIndexBetween,
intersectLineSegmentCircle,
removeArrowBinding,
} from '@tldraw/editor'
import { getArrowBindings, getArrowInfo, removeArrowBinding } from '../../shapes/arrow/shared'
export class ArrowBindingUtil extends BindingUtil<TLArrowBinding> {
static override type = 'arrow'
@ -178,7 +177,7 @@ function arrowDidUpdate(editor: Editor, arrow: TLArrowShape) {
}
function unbindArrowTerminal(editor: Editor, arrow: TLArrowShape, terminal: 'start' | 'end') {
const info = editor.getArrowInfo(arrow)!
const info = getArrowInfo(editor, arrow)!
if (!info) {
throw new Error('expected arrow info')
}

View file

@ -1,12 +1,6 @@
import {
IndexKey,
TLArrowShape,
TLShapeId,
Vec,
createShapeId,
getArrowBindings,
} from '@tldraw/editor'
import { IndexKey, TLArrowShape, TLShapeId, Vec, createShapeId } from '@tldraw/editor'
import { TestEditor } from '../../../test/TestEditor'
import { getArrowBindings } from './shared'
let editor: TestEditor

View file

@ -1,12 +1,6 @@
import {
HALF_PI,
TLArrowShape,
TLShapeId,
createOrUpdateArrowBinding,
createShapeId,
getArrowBindings,
} from '@tldraw/editor'
import { HALF_PI, TLArrowShape, TLShapeId, createShapeId } from '@tldraw/editor'
import { TestEditor } from '../../../test/TestEditor'
import { createOrUpdateArrowBinding, getArrowBindings } from './shared'
let editor: TestEditor

View file

@ -10,7 +10,6 @@ import {
ShapeUtil,
SvgExportContext,
TLArrowBinding,
TLArrowBindings,
TLArrowShape,
TLHandle,
TLOnEditEndHandler,
@ -25,12 +24,8 @@ import {
Vec,
arrowShapeMigrations,
arrowShapeProps,
createOrUpdateArrowBinding,
getArrowBindings,
getArrowTerminalsInArrowSpace,
getDefaultColorTheme,
mapObjectMapValues,
removeArrowBinding,
structuredClone,
toDomPrecision,
track,
@ -56,6 +51,14 @@ import {
getStraightArrowHandlePath,
} from './arrowpaths'
import { ArrowTextLabel } from './components/ArrowTextLabel'
import {
TLArrowBindings,
createOrUpdateArrowBinding,
getArrowBindings,
getArrowInfo,
getArrowTerminalsInArrowSpace,
removeArrowBinding,
} from './shared'
let globalRenderIndex = 0
@ -103,7 +106,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
}
getGeometry(shape: TLArrowShape) {
const info = this.editor.getArrowInfo(shape)!
const info = getArrowInfo(this.editor, shape)!
const debugGeom: Geometry2d[] = []
@ -141,7 +144,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
}
override getHandles(shape: TLArrowShape): TLHandle[] {
const info = this.editor.getArrowInfo(shape)!
const info = getArrowInfo(this.editor, shape)!
return [
{
@ -549,7 +552,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
'arrow.dragging'
) && !this.editor.getInstanceState().isReadonly
const info = this.editor.getArrowInfo(shape)
const info = getArrowInfo(this.editor, shape)
if (!info?.isValid) return null
const labelPosition = getArrowLabelPosition(this.editor, shape)
@ -585,7 +588,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
// eslint-disable-next-line react-hooks/rules-of-hooks
const isEditing = useIsEditing(shape.id)
const info = this.editor.getArrowInfo(shape)
const info = getArrowInfo(this.editor, shape)
if (!info) return null
const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape, info?.bindings)
@ -752,7 +755,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
}
function getLength(editor: Editor, shape: TLArrowShape): number {
const info = editor.getArrowInfo(shape)!
const info = getArrowInfo(editor, shape)!
return info.isStraight
? Vec.Dist(info.start.handle, info.end.handle)
@ -768,7 +771,7 @@ const ArrowSvg = track(function ArrowSvg({
}) {
const editor = useEditor()
const theme = useDefaultColorTheme()
const info = editor.getArrowInfo(shape)
const info = getArrowInfo(editor, shape)
const bounds = Box.ZeroFix(editor.getShapeGeometry(shape).bounds)
const bindings = getArrowBindings(editor, shape)

View file

@ -1,5 +1,4 @@
import { TLArrowShapeArrowheadStyle } from '@tldraw/tlschema'
import { VecLike } from '../../../../primitives/Vec'
import { TLArrowShapeArrowheadStyle, VecLike } from '@tldraw/editor'
import { TLArrowBindings } from './shared'
/** @public */

View file

@ -6,7 +6,6 @@ import {
Editor,
Geometry2d,
Polygon2d,
TLArrowInfo,
TLArrowShape,
Vec,
VecLike,
@ -24,6 +23,8 @@ import {
STROKE_SIZES,
TEXT_PROPS,
} from '../shared/default-shape-constants'
import { TLArrowInfo } from './arrow-types'
import { getArrowInfo } from './shared'
const labelSizeCache = new WeakMap<TLArrowShape, Vec>()
@ -31,7 +32,7 @@ function getArrowLabelSize(editor: Editor, shape: TLArrowShape) {
const cachedSize = labelSizeCache.get(shape)
if (cachedSize) return cachedSize
const info = editor.getArrowInfo(shape)!
const info = getArrowInfo(editor, shape)!
let width = 0
let height = 0
@ -266,7 +267,7 @@ function getCurvedArrowLabelRange(
export function getArrowLabelPosition(editor: Editor, shape: TLArrowShape) {
let labelCenter
const debugGeom: Geometry2d[] = []
const info = editor.getArrowInfo(shape)!
const info = getArrowInfo(editor, shape)!
const hasStartBinding = !!info.bindings.start
const hasEndBinding = !!info.bindings.end

View file

@ -1,4 +1,5 @@
import { HALF_PI, PI, TLArrowInfo, Vec, VecLike, intersectCircleCircle } from '@tldraw/editor'
import { HALF_PI, PI, Vec, VecLike, intersectCircleCircle } from '@tldraw/editor'
import { TLArrowInfo } from './arrow-types'
type TLArrowPointsInfo = {
point: VecLike

View file

@ -1,4 +1,5 @@
import { TLArrowInfo, VecLike } from '@tldraw/editor'
import { VecLike } from '@tldraw/editor'
import { TLArrowInfo } from './arrow-types'
/* --------------------- Curved --------------------- */

View file

@ -1,15 +1,17 @@
import { TLArrowShape } from '@tldraw/tlschema'
import { Mat } from '../../../../primitives/Mat'
import { Vec, VecLike } from '../../../../primitives/Vec'
import { intersectCirclePolygon, intersectCirclePolyline } from '../../../../primitives/intersect'
import {
Editor,
Mat,
PI,
PI2,
TLArrowShape,
Vec,
VecLike,
clockwiseAngleDist,
counterClockwiseAngleDist,
intersectCirclePolygon,
intersectCirclePolyline,
isSafeFloat,
} from '../../../../primitives/utils'
import type { Editor } from '../../../Editor'
} from '@tldraw/editor'
import { TLArcInfo, TLArrowInfo } from './arrow-types'
import {
BOUND_ARROW_OFFSET,

View file

@ -1,14 +1,17 @@
import {
Editor,
Group2d,
Mat,
TLArrowBinding,
TLArrowBindingProps,
TLArrowShape,
TLShape,
TLShapeId,
} from '@tldraw/tlschema'
import { Mat } from '../../../../primitives/Mat'
import { Vec } from '../../../../primitives/Vec'
import { Group2d } from '../../../../primitives/geometry/Group2d'
import { Editor } from '../../../Editor'
Vec,
} from '@tldraw/editor'
import { createComputedCache } from '@tldraw/store'
import { getCurvedArrowInfo } from './curved-arrow'
import { getStraightArrowInfo } from './straight-arrow'
export function getIsArrowStraight(shape: TLArrowShape) {
return Math.abs(shape.props.bend) < 8 // snap to +-8px
@ -101,6 +104,19 @@ export function getArrowBindings(editor: Editor, shape: TLArrowShape): TLArrowBi
}
}
const arrowInfoCache = createComputedCache('arrow info', (editor: Editor, shape: TLArrowShape) => {
const bindings = getArrowBindings(editor, shape)
return getIsArrowStraight(shape)
? getStraightArrowInfo(editor, shape, bindings)
: getCurvedArrowInfo(editor, shape, bindings)
})
/** @public */
export function getArrowInfo(editor: Editor, shape: TLArrowShape | TLShapeId) {
const id = typeof shape === 'string' ? shape : shape.id
return arrowInfoCache.get(editor, id)
}
/** @public */
export function getArrowTerminalsInArrowSpace(
editor: Editor,

View file

@ -1,11 +1,13 @@
import { TLArrowShape } from '@tldraw/tlschema'
import { Mat, MatModel } from '../../../../primitives/Mat'
import { Vec, VecLike } from '../../../../primitives/Vec'
import {
Editor,
Mat,
MatModel,
TLArrowShape,
Vec,
VecLike,
intersectLineSegmentPolygon,
intersectLineSegmentPolyline,
} from '../../../../primitives/intersect'
import { Editor } from '../../../Editor'
} from '@tldraw/editor'
import { TLArrowInfo } from './arrow-types'
import {
BOUND_ARROW_OFFSET,

View file

@ -10,7 +10,7 @@ import {
TLOnHandleDragHandler,
TLOnResizeHandler,
Vec,
WeakMapCache,
WeakCache,
getIndexBetween,
getIndices,
lineShapeMigrations,
@ -30,7 +30,7 @@ import {
getSvgPathForLineGeometry,
} from './components/svg'
const handlesCache = new WeakMapCache<TLLineShape['props'], TLHandle[]>()
const handlesCache = new WeakCache<TLLineShape['props'], TLHandle[]>()
/** @public */
export class LineShapeUtil extends ShapeUtil<TLLineShape> {

View file

@ -11,7 +11,7 @@ import {
TLShape,
TLShapeId,
Vec,
WeakMapCache,
WeakCache,
getDefaultColorTheme,
noteShapeMigrations,
noteShapeProps,
@ -372,7 +372,7 @@ function getNoteLabelSize(editor: Editor, shape: TLNoteShape) {
}
}
const labelSizesForNote = new WeakMapCache<TLShape, ReturnType<typeof getNoteLabelSize>>()
const labelSizesForNote = new WeakCache<TLShape, ReturnType<typeof getNoteLabelSize>>()
function getLabelSize(editor: Editor, shape: TLNoteShape) {
return labelSizesForNote.get(shape, () => getNoteLabelSize(editor, shape))

View file

@ -11,7 +11,7 @@ import {
TLShapeUtilFlag,
TLTextShape,
Vec,
WeakMapCache,
WeakCache,
getDefaultColorTheme,
preventDefault,
textShapeMigrations,
@ -28,7 +28,7 @@ import { FONT_FAMILIES, FONT_SIZES, TEXT_PROPS } from '../shared/default-shape-c
import { getFontDefForExport } from '../shared/defaultStyleDefs'
import { resizeScaled } from '../shared/resizeScaled'
const sizeCache = new WeakMapCache<TLTextShape['props'], { height: number; width: number }>()
const sizeCache = new WeakCache<TLTextShape['props'], { height: number; width: number }>()
/** @public */
export class TextShapeUtil extends ShapeUtil<TLTextShape> {

View file

@ -11,11 +11,11 @@ import {
TLShapeId,
TLShapePartial,
Vec,
getArrowBindings,
snapAngle,
sortByIndex,
structuredClone,
} from '@tldraw/editor'
import { getArrowBindings } from '../../../shapes/arrow/shared'
import { kickoutOccludedShapes } from '../selectHelpers'
export class DraggingHandle extends StateNode {

View file

@ -10,6 +10,7 @@ import {
Vec,
getPointInArcT,
} from '@tldraw/editor'
import { getArrowInfo } from '../../../shapes/arrow/shared'
export class PointingArrowLabel extends StateNode {
static override id = 'pointing_arrow_label'
@ -74,7 +75,7 @@ export class PointingArrowLabel extends StateNode {
const shape = this.editor.getShape<TLArrowShape>(this.shapeId)
if (!shape) return
const info = this.editor.getArrowInfo(shape)!
const info = getArrowInfo(this.editor, shape)!
const groupGeometry = this.editor.getShapeGeometry<Group2d>(shape)
const bodyGeometry = groupGeometry.children[0] as Geometry2d

View file

@ -7,8 +7,8 @@ import {
TLNoteShape,
TLPointerEventInfo,
Vec,
getArrowBindings,
} from '@tldraw/editor'
import { getArrowBindings } from '../../../shapes/arrow/shared'
import {
NOTE_CENTER_OFFSET,
getNoteAdjacentPositions,

View file

@ -5,10 +5,10 @@ import {
TLGroupShape,
TLLineShape,
TLTextShape,
getArrowBindings,
useEditor,
useValue,
} from '@tldraw/editor'
import { getArrowBindings } from '../../shapes/arrow/shared'
function shapesWithUnboundArrows(editor: Editor) {
const selectedShapeIds = editor.getSelectedShapeIds()

View file

@ -25,9 +25,9 @@ import {
VecModel,
clamp,
createShapeId,
getArrowBindings,
structuredClone,
} from '@tldraw/editor'
import { getArrowBindings } from '../../shapes/arrow/shared'
const TLDRAW_V1_VERSION = 15.5

View file

@ -10,6 +10,8 @@ import {
RequiredKeys,
RotateCorner,
SelectionHandle,
TLArrowBinding,
TLArrowShape,
TLContent,
TLEditorOptions,
TLEventInfo,
@ -22,6 +24,7 @@ import {
TLWheelEventInfo,
Vec,
VecLike,
compact,
computed,
createShapeId,
createTLStore,
@ -708,6 +711,13 @@ export class TestEditor extends Editor {
getPageRotation(shape: TLShape) {
return this.getPageRotationById(shape.id)
}
getArrowsBoundTo(shapeId: TLShapeId) {
const ids = new Set(
this.getBindingsToShape<TLArrowBinding>(shapeId, 'arrow').map((b) => b.fromId)
)
return compact(Array.from(ids, (id) => this.getShape<TLArrowShape>(id)))
}
}
export const defaultShapesIds = {

View file

@ -1,4 +1,5 @@
import { TLArrowShape, TLShapeId, Vec, createShapeId, getArrowBindings } from '@tldraw/editor'
import { TLArrowShape, TLShapeId, Vec, createShapeId } from '@tldraw/editor'
import { getArrowBindings } from '../lib/shapes/arrow/shared'
import { TestEditor } from './TestEditor'
import { TL } from './test-jsx'

View file

@ -1,4 +1,5 @@
import { TLArrowShape, createShapeId, getArrowBindings } from '@tldraw/editor'
import { TLArrowShape, createShapeId } from '@tldraw/editor'
import { getArrowBindings } from '../lib/shapes/arrow/shared'
import { TestEditor } from './TestEditor'
let editor: TestEditor

View file

@ -1,4 +1,5 @@
import { createShapeId, getArrowBindings, TLArrowShape } from '@tldraw/editor'
import { createShapeId, TLArrowShape } from '@tldraw/editor'
import { getArrowBindings } from '../../lib/shapes/arrow/shared'
import { TestEditor } from '../TestEditor'
let editor: TestEditor

View file

@ -1,4 +1,5 @@
import { createBindingId, createShapeId, getArrowBindings } from '@tldraw/editor'
import { createBindingId, createShapeId } from '@tldraw/editor'
import { getArrowBindings } from '../../lib/shapes/arrow/shared'
import { TestEditor } from '../TestEditor'
let editor: TestEditor

View file

@ -1,11 +1,11 @@
import {
createBindingId,
createShapeId,
getArrowBindings,
TLArrowShape,
TLBindingPartial,
TLShapePartial,
} from '@tldraw/editor'
import { getArrowBindings } from '../lib/shapes/arrow/shared'
import { TestEditor } from './TestEditor'
let editor: TestEditor

View file

@ -8,8 +8,8 @@ import {
TLShapePartial,
createBindingId,
createShapeId,
getArrowBindings,
} from '@tldraw/editor'
import { getArrowBindings } from '../lib/shapes/arrow/shared'
import { TestEditor } from './TestEditor'
let editor: TestEditor

View file

@ -5,8 +5,8 @@ import {
TLFrameShape,
TLShapeId,
createShapeId,
getArrowBindings,
} from '@tldraw/editor'
import { getArrowBindings } from '../lib/shapes/arrow/shared'
import { DEFAULT_FRAME_PADDING, fitFrameToContent, removeFrame } from '../lib/utils/frames/frames'
import { TestEditor } from './TestEditor'

View file

@ -12,9 +12,9 @@ import {
assert,
compact,
createShapeId,
getArrowBindings,
sortByIndex,
} from '@tldraw/editor'
import { getArrowBindings } from '../lib/shapes/arrow/shared'
import { TestEditor } from './TestEditor'
jest.mock('nanoid', () => {

View file

@ -9,8 +9,8 @@ import {
TLShapePartial,
Vec,
createShapeId,
getArrowBindings,
} from '@tldraw/editor'
import { getArrowBindings } from '../lib/shapes/arrow/shared'
import { TestEditor } from './TestEditor'
import { getSnapLines } from './getSnapLines'

View file

@ -347,6 +347,12 @@ export function validateIndexKey(key: string): asserts key is IndexKey;
// @internal (undocumented)
export function warnDeprecatedGetter(name: string): void;
// @public
export class WeakCache<K extends object, V> {
get<P extends K>(item: P, cb: (item: P) => V): NonNullable<V>;
items: WeakMap<K, V>;
}
// @public
export const ZERO_INDEX_KEY: IndexKey;

View file

@ -8,6 +8,7 @@ export {
partition,
rotateArray,
} from './lib/array'
export { WeakCache } from './lib/cache'
export {
Result,
assert,

View file

@ -1,7 +1,10 @@
/** A micro cache used when storing records in memory (using a WeakMap). */
export class Cache<T extends object, K> {
/**
* A micro cache used when storing records in memory (using a WeakMap).
* @public
*/
export class WeakCache<K extends object, V> {
/** The map of items to their cached values. */
items = new WeakMap<T, K>()
items = new WeakMap<K, V>()
/**
* Get the cached value for a given record. If the record is not present in the map, the callback
@ -10,7 +13,7 @@ export class Cache<T extends object, K> {
* @param item - The item to get.
* @param cb - The callback to use to create the value when a cached value is not found.
*/
get<P extends T>(item: P, cb: (item: P) => K) {
get<P extends K>(item: P, cb: (item: P) => V) {
if (!this.items.has(item)) {
this.items.set(item, cb(item))
}