[refactor] reduce dependencies on shape utils in editor (#1693)

We'd like to make the @tldraw/editor layer more independent of specific
shapes. Unfortunately there are many places where shape types and
certain shape behavior is deeply embedded in the Editor. This PR begins
to refactor out dependencies between the editor library and shape utils.

It does this in two ways:
- removing shape utils from the arguments of `isShapeOfType`, replacing
with a generic
- removing shape utils from the arguments of `getShapeUtil`, replacing
with a generic
- moving custom arrow info cache out of the util and into the editor
class
- changing the a tool's `shapeType` to be a string instead of a shape
util

We're here trading type safety based on inferred types—"hey editor, give
me your instance of this shape util class"—for knowledge at the point of
call—"hey editor, give me a shape util class of this type; and trust me
it'll be an instance this shape util class". Likewise for shapes.

### A note on style 

We haven't really established our conventions or style when it comes to
types, but I'm increasingly of the opinion that we should defer to the
point of call to narrow a type based on generics (keeping the types in
typescript land) rather than using arguments, which blur into JavaScript
land.

### Change Type

- [x] `major` — Breaking change

### Test Plan

- [x] Unit Tests

### Release Notes

- removes shape utils from the arguments of `isShapeOfType`, replacing
with a generic
- removes shape utils from the arguments of `getShapeUtil`, replacing
with a generic
- moves custom arrow info cache out of the util and into the editor
class
- changes the a tool's `shapeType` to be a string instead of a shape
util
This commit is contained in:
Steve Ruiz 2023-07-07 14:56:31 +01:00 committed by GitHub
parent d99c4a0e9c
commit 910be6073f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 281 additions and 361 deletions

View file

@ -86,8 +86,7 @@ export class CardShapeUtil extends BaseBoxShapeUtil<CardShape> {
export class CardShapeTool extends BaseBoxShapeTool { export class CardShapeTool extends BaseBoxShapeTool {
static override id = 'card' static override id = 'card'
static override initial = 'idle' static override initial = 'idle'
override shapeType = 'card'
override shapeType = CardShapeUtil
} }
export const CardShape = defineShape('card', { export const CardShape = defineShape('card', {

View file

@ -1,13 +1,11 @@
import { BaseBoxShapeTool, TLClickEvent } from '@tldraw/tldraw' import { BaseBoxShapeTool, TLClickEvent } from '@tldraw/tldraw'
import { CardShapeUtil } from './CardShapeUtil'
// A tool used to create our custom card shapes. Extending the base // A tool used to create our custom card shapes. Extending the base
// box shape tool gives us a lot of functionality for free. // box shape tool gives us a lot of functionality for free.
export class CardShapeTool extends BaseBoxShapeTool { export class CardShapeTool extends BaseBoxShapeTool {
static override id = 'card' static override id = 'card'
static override initial = 'idle' static override initial = 'idle'
override shapeType = 'card'
override shapeType = CardShapeUtil
override onDoubleClick: TLClickEvent = (_info) => { override onDoubleClick: TLClickEvent = (_info) => {
// you can handle events in handlers like this one; // you can handle events in handlers like this one;

View file

@ -121,8 +121,6 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
// (undocumented) // (undocumented)
component(shape: TLArrowShape): JSX.Element | null; component(shape: TLArrowShape): JSX.Element | null;
// (undocumented) // (undocumented)
getArrowInfo(shape: TLArrowShape): ArrowInfo | undefined;
// (undocumented)
getBounds(shape: TLArrowShape): Box2d; getBounds(shape: TLArrowShape): Box2d;
// (undocumented) // (undocumented)
getCanvasSvgDefs(): TLShapeUtilCanvasSvgDef[]; getCanvasSvgDefs(): TLShapeUtilCanvasSvgDef[];
@ -185,7 +183,7 @@ export abstract class BaseBoxShapeTool extends StateNode {
// (undocumented) // (undocumented)
static initial: string; static initial: string;
// (undocumented) // (undocumented)
abstract shapeType: TLShapeUtilConstructor<any>; abstract shapeType: string;
} }
// @public (undocumented) // @public (undocumented)
@ -462,6 +460,8 @@ export class Editor extends EventEmitter<TLEventMap> {
getAncestorPageId(shape?: TLShape): TLPageId | undefined; getAncestorPageId(shape?: TLShape): TLPageId | undefined;
getAncestors(shape: TLShape, acc?: TLShape[]): TLShape[]; getAncestors(shape: TLShape, acc?: TLShape[]): TLShape[];
getAncestorsById(id: TLShapeId, acc?: TLShape[]): TLShape[]; getAncestorsById(id: TLShapeId, acc?: TLShape[]): TLShape[];
// (undocumented)
getArrowInfo(shape: TLArrowShape): ArrowInfo | undefined;
getArrowsBoundTo(shapeId: TLShapeId): { getArrowsBoundTo(shapeId: TLShapeId): {
arrowId: TLShapeId; arrowId: TLShapeId;
handleId: "end" | "start"; handleId: "end" | "start";
@ -512,11 +512,11 @@ export class Editor extends EventEmitter<TLEventMap> {
getShapesAtPoint(point: VecLike): TLShape[]; getShapesAtPoint(point: VecLike): TLShape[];
// (undocumented) // (undocumented)
getShapeStyleIfExists<T>(shape: TLShape, style: StyleProp<T>): T | undefined; getShapeStyleIfExists<T>(shape: TLShape, style: StyleProp<T>): T | undefined;
getShapeUtil<C extends {
new (...args: any[]): ShapeUtil<any>;
type: string;
}>(util: C): InstanceType<C>;
getShapeUtil<S extends TLUnknownShape>(shape: S | TLShapePartial<S>): ShapeUtil<S>; getShapeUtil<S extends TLUnknownShape>(shape: S | TLShapePartial<S>): ShapeUtil<S>;
// (undocumented)
getShapeUtil<S extends TLUnknownShape>(type: S['type']): ShapeUtil<S>;
// (undocumented)
getShapeUtil<T extends ShapeUtil>(type: T extends ShapeUtil<infer R> ? R['type'] : string): T;
getSortedChildIds(parentId: TLParentId): TLShapeId[]; getSortedChildIds(parentId: TLParentId): TLShapeId[];
getStateDescendant(path: string): StateNode | undefined; getStateDescendant(path: string): StateNode | undefined;
// @internal (undocumented) // @internal (undocumented)
@ -579,10 +579,7 @@ export class Editor extends EventEmitter<TLEventMap> {
readonly isSafari: boolean; readonly isSafari: boolean;
isSelected(id: TLShapeId): boolean; isSelected(id: TLShapeId): boolean;
isShapeInPage(shape: TLShape, pageId?: TLPageId): boolean; isShapeInPage(shape: TLShape, pageId?: TLPageId): boolean;
isShapeOfType<T extends TLUnknownShape>(shape: TLUnknownShape, util: { isShapeOfType<T extends TLUnknownShape>(shape: TLUnknownShape, type: T['type']): shape is T;
new (...args: any): ShapeUtil<T>;
type: string;
}): shape is T;
isShapeOrAncestorLocked(shape?: TLShape): boolean; isShapeOrAncestorLocked(shape?: TLShape): boolean;
get isSnapMode(): boolean; get isSnapMode(): boolean;
get isToolLocked(): boolean; get isToolLocked(): boolean;
@ -2083,7 +2080,7 @@ export abstract class StateNode implements Partial<TLEventHandlers> {
// (undocumented) // (undocumented)
path: Computed<string>; path: Computed<string>;
// (undocumented) // (undocumented)
shapeType?: TLShapeUtilConstructor<TLBaseShape<any, any>>; shapeType?: string;
// (undocumented) // (undocumented)
transition(id: string, info: any): this; transition(id: string, info: any): this;
// (undocumented) // (undocumented)

View file

@ -1,9 +1,8 @@
import { RotateCorner, toDomPrecision } from '@tldraw/primitives' import { RotateCorner, toDomPrecision } from '@tldraw/primitives'
import { track } from '@tldraw/state' import { track } from '@tldraw/state'
import { TLEmbedShape, TLTextShape } from '@tldraw/tlschema'
import classNames from 'classnames' import classNames from 'classnames'
import { useRef } from 'react' import { useRef } from 'react'
import { EmbedShapeUtil } from '../editor/shapes/embed/EmbedShapeUtil'
import { TextShapeUtil } from '../editor/shapes/text/TextShapeUtil'
import { getCursor } from '../hooks/useCursor' import { getCursor } from '../hooks/useCursor'
import { useEditor } from '../hooks/useEditor' import { useEditor } from '../hooks/useEditor'
import { useSelectionEvents } from '../hooks/useSelectionEvents' import { useSelectionEvents } from '../hooks/useSelectionEvents'
@ -94,11 +93,11 @@ export const SelectionFg = track(function SelectionFg() {
(showSelectionBounds && (showSelectionBounds &&
editor.isIn('select.resizing') && editor.isIn('select.resizing') &&
onlyShape && onlyShape &&
editor.isShapeOfType(onlyShape, TextShapeUtil)) editor.isShapeOfType<TLTextShape>(onlyShape, 'text'))
if ( if (
onlyShape && onlyShape &&
editor.isShapeOfType(onlyShape, EmbedShapeUtil) && editor.isShapeOfType<TLEmbedShape>(onlyShape, 'embed') &&
shouldDisplayBox && shouldDisplayBox &&
IS_FIREFOX IS_FIREFOX
) { ) {
@ -186,7 +185,7 @@ export const SelectionFg = track(function SelectionFg() {
shouldDisplayControls && shouldDisplayControls &&
isCoarsePointer && isCoarsePointer &&
onlyShape && onlyShape &&
editor.isShapeOfType(onlyShape, TextShapeUtil) && editor.isShapeOfType<TLTextShape>(onlyShape, 'text') &&
textHandleHeight * zoom >= 4 textHandleHeight * zoom >= 4
return ( return (

View file

@ -125,12 +125,10 @@ import { TextManager } from './managers/TextManager'
import { TickManager } from './managers/TickManager' import { TickManager } from './managers/TickManager'
import { UserPreferencesManager } from './managers/UserPreferencesManager' import { UserPreferencesManager } from './managers/UserPreferencesManager'
import { ShapeUtil, TLResizeMode } from './shapes/ShapeUtil' import { ShapeUtil, TLResizeMode } from './shapes/ShapeUtil'
import { ArrowShapeUtil } from './shapes/arrow/ArrowShapeUtil' import { ArrowInfo } from './shapes/arrow/arrow/arrow-types'
import { getCurvedArrowInfo } from './shapes/arrow/arrow/curved-arrow' import { getCurvedArrowInfo } from './shapes/arrow/arrow/curved-arrow'
import { getArrowTerminalsInArrowSpace, getIsArrowStraight } from './shapes/arrow/arrow/shared' import { getArrowTerminalsInArrowSpace, getIsArrowStraight } from './shapes/arrow/arrow/shared'
import { getStraightArrowInfo } from './shapes/arrow/arrow/straight-arrow' import { getStraightArrowInfo } from './shapes/arrow/arrow/straight-arrow'
import { FrameShapeUtil } from './shapes/frame/FrameShapeUtil'
import { GroupShapeUtil } from './shapes/group/GroupShapeUtil'
import { SvgExportContext, SvgExportDef } from './shapes/shared/SvgExportContext' import { SvgExportContext, SvgExportDef } from './shapes/shared/SvgExportContext'
import { RootState } from './tools/RootState' import { RootState } from './tools/RootState'
import { StateNode, TLStateNodeConstructor } from './tools/StateNode' import { StateNode, TLStateNodeConstructor } from './tools/StateNode'
@ -283,7 +281,7 @@ export class Editor extends EventEmitter<TLEventMap> {
this._updateDepth-- this._updateDepth--
} }
this.store.onAfterCreate = (record) => { this.store.onAfterCreate = (record) => {
if (record.typeName === 'shape' && this.isShapeOfType(record, ArrowShapeUtil)) { if (record.typeName === 'shape' && this.isShapeOfType<TLArrowShape>(record, 'arrow')) {
this._arrowDidUpdate(record) this._arrowDidUpdate(record)
} }
if (record.typeName === 'page') { if (record.typeName === 'page') {
@ -603,54 +601,29 @@ export class Editor extends EventEmitter<TLEventMap> {
*/ */
shapeUtils: { readonly [K in string]?: ShapeUtil<TLUnknownShape> } shapeUtils: { readonly [K in string]?: ShapeUtil<TLUnknownShape> }
/**
* Get a shape util by its definition.
*
* @example
* ```ts
* editor.getShapeUtil(ArrowShapeUtil)
* ```
*
* @param util - The shape util.
*
* @public
*/
getShapeUtil<C extends { new (...args: any[]): ShapeUtil<any>; type: string }>(
util: C
): InstanceType<C>
/** /**
* Get a shape util from a shape itself. * Get a shape util from a shape itself.
* *
* @example * @example
* ```ts * ```ts
* const util = editor.getShapeUtil(myShape) * const util = editor.getShapeUtil(myArrowShape)
* const util = editor.getShapeUtil<ArrowShapeUtil>(myShape) * const util = editor.getShapeUtil('arrow')
* const util = editor.getShapeUtil(ArrowShapeUtil) * const util = editor.getShapeUtil<TLArrowShape>(myArrowShape)
* const util = editor.getShapeUtil(TLArrowShape)('arrow')
* ``` * ```
* *
* @param shape - A shape or shape partial. * @param shape - A shape, shape partial, or shape type.
* *
* @public * @public
*/ */
getShapeUtil<S extends TLUnknownShape>(shape: S | TLShapePartial<S>): ShapeUtil<S> getShapeUtil<S extends TLUnknownShape>(shape: S | TLShapePartial<S>): ShapeUtil<S>
getShapeUtil<T extends ShapeUtil>(shapeUtilConstructor: { getShapeUtil<S extends TLUnknownShape>(type: S['type']): ShapeUtil<S>
type: T extends ShapeUtil<infer R> ? R['type'] : string getShapeUtil<T extends ShapeUtil>(type: T extends ShapeUtil<infer R> ? R['type'] : string): T
}): T { getShapeUtil(arg: string | { type: string }) {
const shapeUtil = getOwnProperty(this.shapeUtils, shapeUtilConstructor.type) as T | undefined const type = typeof arg === 'string' ? arg : arg.type
assert(shapeUtil, `No shape util found for type "${shapeUtilConstructor.type}"`) const shapeUtil = getOwnProperty(this.shapeUtils, type)
assert(shapeUtil, `No shape util found for type "${type}"`)
// does shapeUtilConstructor extends ShapeUtil? return shapeUtil
if (
'prototype' in shapeUtilConstructor &&
shapeUtilConstructor.prototype instanceof ShapeUtil
) {
assert(
shapeUtil instanceof (shapeUtilConstructor as any),
`Shape util found for type "${shapeUtilConstructor.type}" is not an instance of the provided constructor`
)
}
return shapeUtil as T
} }
/** @internal */ /** @internal */
@ -759,6 +732,19 @@ export class Editor extends EventEmitter<TLEventMap> {
this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }]) this.store.put([{ ...arrow, props: { ...arrow.props, [handleId]: { type: 'point', x, y } } }])
} }
@computed
private get arrowInfoCache() {
return this.store.createComputedCache<ArrowInfo, TLArrowShape>('arrow infoCache', (shape) => {
return getIsArrowStraight(shape)
? getStraightArrowInfo(this, shape)
: getCurvedArrowInfo(this, shape)
})
}
getArrowInfo(shape: TLArrowShape) {
return this.arrowInfoCache.get(shape.id)
}
// private _shapeWillUpdate = (prev: TLShape, next: TLShape) => { // private _shapeWillUpdate = (prev: TLShape, next: TLShape) => {
// const update = this.getShapeUtil(next).onUpdate?.(prev, next) // const update = this.getShapeUtil(next).onUpdate?.(prev, next)
// return update ?? next // return update ?? next
@ -844,7 +830,7 @@ export class Editor extends EventEmitter<TLEventMap> {
/** @internal */ /** @internal */
private _shapeDidChange(prev: TLShape, next: TLShape) { private _shapeDidChange(prev: TLShape, next: TLShape) {
if (this.isShapeOfType(next, ArrowShapeUtil)) { if (this.isShapeOfType<TLArrowShape>(next, 'arrow')) {
this._arrowDidUpdate(next) this._arrowDidUpdate(next)
} }
@ -907,7 +893,7 @@ export class Editor extends EventEmitter<TLEventMap> {
filtered.length === 0 filtered.length === 0
? next?.focusLayerId ? next?.focusLayerId
: this.findCommonAncestor(compact(filtered.map((id) => this.getShapeById(id))), (shape) => : this.findCommonAncestor(compact(filtered.map((id) => this.getShapeById(id))), (shape) =>
this.isShapeOfType(shape, GroupShapeUtil) this.isShapeOfType<TLGroupShape>(shape, 'group')
) )
if (filtered.length !== next.selectedIds.length || nextFocusLayerId != next.focusLayerId) { if (filtered.length !== next.selectedIds.length || nextFocusLayerId != next.focusLayerId) {
@ -2181,7 +2167,7 @@ export class Editor extends EventEmitter<TLEventMap> {
if (focusedShape) { if (focusedShape) {
// If we have a focused layer, look for an ancestor of the focused shape that is a group // If we have a focused layer, look for an ancestor of the focused shape that is a group
const match = this.findAncestor(focusedShape, (shape) => const match = this.findAncestor(focusedShape, (shape) =>
this.isShapeOfType(shape, GroupShapeUtil) this.isShapeOfType<TLGroupShape>(shape, 'group')
) )
// If we have an ancestor that can become a focused layer, set it as the focused layer // If we have an ancestor that can become a focused layer, set it as the focused layer
this.setFocusLayer(match?.id ?? null) this.setFocusLayer(match?.id ?? null)
@ -4740,7 +4726,7 @@ export class Editor extends EventEmitter<TLEventMap> {
} }
const frameAncestors = this.getAncestorsById(shape.id).filter((shape) => const frameAncestors = this.getAncestorsById(shape.id).filter((shape) =>
this.isShapeOfType(shape, FrameShapeUtil) this.isShapeOfType<TLFrameShape>(shape, 'frame')
) )
if (frameAncestors.length === 0) return undefined if (frameAncestors.length === 0) return undefined
@ -5203,7 +5189,7 @@ export class Editor extends EventEmitter<TLEventMap> {
* *
* @example * @example
* ```ts * ```ts
* const isArrowShape = isShapeOfType(someShape, ArrowShapeUtil) * const isArrowShape = isShapeOfType<TLArrowShape>(someShape, 'arrow')
* ``` * ```
* *
* @param util - the TLShapeUtil constructor to test against * @param util - the TLShapeUtil constructor to test against
@ -5211,11 +5197,8 @@ export class Editor extends EventEmitter<TLEventMap> {
* *
* @public * @public
*/ */
isShapeOfType<T extends TLUnknownShape>( isShapeOfType<T extends TLUnknownShape>(shape: TLUnknownShape, type: T['type']): shape is T {
shape: TLUnknownShape, return shape.type === type
util: { new (...args: any): ShapeUtil<T>; type: string }
): shape is T {
return shape.type === util.type
} }
/** /**
@ -5595,7 +5578,7 @@ export class Editor extends EventEmitter<TLEventMap> {
let node = shape as TLShape | undefined let node = shape as TLShape | undefined
while (node) { while (node) {
if ( if (
this.isShapeOfType(node, GroupShapeUtil) && this.isShapeOfType<TLGroupShape>(node, 'group') &&
this.focusLayerId !== node.id && this.focusLayerId !== node.id &&
!this.hasAncestor(this.focusLayerShape, node.id) && !this.hasAncestor(this.focusLayerShape, node.id) &&
(filter?.(node) ?? true) (filter?.(node) ?? true)
@ -5761,10 +5744,10 @@ export class Editor extends EventEmitter<TLEventMap> {
let newShape: TLShape = deepCopy(shape) let newShape: TLShape = deepCopy(shape)
if ( if (
this.isShapeOfType(shape, ArrowShapeUtil) && this.isShapeOfType<TLArrowShape>(shape, 'arrow') &&
this.isShapeOfType(newShape, ArrowShapeUtil) this.isShapeOfType<TLArrowShape>(newShape, 'arrow')
) { ) {
const info = this.getShapeUtil(ArrowShapeUtil).getArrowInfo(shape) const info = this.getArrowInfo(shape)
let newStartShapeId: TLShapeId | undefined = undefined let newStartShapeId: TLShapeId | undefined = undefined
let newEndShapeId: TLShapeId | undefined = undefined let newEndShapeId: TLShapeId | undefined = undefined
@ -6084,7 +6067,7 @@ export class Editor extends EventEmitter<TLEventMap> {
shapes = compact( shapes = compact(
shapes shapes
.map((shape) => { .map((shape) => {
if (this.isShapeOfType(shape, GroupShapeUtil)) { if (this.isShapeOfType<TLGroupShape>(shape, 'group')) {
return this.getSortedChildIds(shape.id).map((id) => this.getShapeById(id)) return this.getSortedChildIds(shape.id).map((id) => this.getShapeById(id))
} }
@ -6144,7 +6127,7 @@ export class Editor extends EventEmitter<TLEventMap> {
const shapes = compact(ids.map((id) => this.getShapeById(id))).filter((shape) => { const shapes = compact(ids.map((id) => this.getShapeById(id))).filter((shape) => {
if (!shape) return false if (!shape) return false
if (this.isShapeOfType(shape, ArrowShapeUtil)) { if (this.isShapeOfType<TLArrowShape>(shape, 'arrow')) {
if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') { if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') {
return false return false
} }
@ -6275,7 +6258,7 @@ export class Editor extends EventEmitter<TLEventMap> {
.filter((shape) => { .filter((shape) => {
if (!shape) return false if (!shape) return false
if (this.isShapeOfType(shape, ArrowShapeUtil)) { if (this.isShapeOfType<TLArrowShape>(shape, 'arrow')) {
if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') { if (shape.props.start.type === 'binding' || shape.props.end.type === 'binding') {
return false return false
} }
@ -7319,7 +7302,7 @@ export class Editor extends EventEmitter<TLEventMap> {
const groups: TLGroupShape[] = [] const groups: TLGroupShape[] = []
shapes.forEach((shape) => { shapes.forEach((shape) => {
if (this.isShapeOfType(shape, GroupShapeUtil)) { if (this.isShapeOfType<TLGroupShape>(shape, 'group')) {
groups.push(shape) groups.push(shape)
} else { } else {
idsToSelect.add(shape.id) idsToSelect.add(shape.id)
@ -7568,7 +7551,7 @@ export class Editor extends EventEmitter<TLEventMap> {
* @internal * @internal
*/ */
private _extractSharedStyles(shape: TLShape, sharedStyleMap: SharedStyleMap) { private _extractSharedStyles(shape: TLShape, sharedStyleMap: SharedStyleMap) {
if (this.isShapeOfType(shape, GroupShapeUtil)) { if (this.isShapeOfType<TLGroupShape>(shape, 'group')) {
// For groups, ignore the styles of the group shape and instead include the styles of the // For groups, ignore the styles of the group shape and instead include the styles of the
// group's children. These are the shapes that would have their styles changed if the // group's children. These are the shapes that would have their styles changed if the
// user called `setStyle` on the current selection. // user called `setStyle` on the current selection.
@ -7673,7 +7656,7 @@ export class Editor extends EventEmitter<TLEventMap> {
// For groups, ignore the opacity of the group shape and instead include // For groups, ignore the opacity of the group shape and instead include
// the opacity of the group's children. These are the shapes that would have // the opacity of the group's children. These are the shapes that would have
// their opacity changed if the user called `setOpacity` on the current selection. // their opacity changed if the user called `setOpacity` on the current selection.
if (this.isShapeOfType(shape, GroupShapeUtil)) { if (this.isShapeOfType<TLGroupShape>(shape, 'group')) {
for (const childId of this.getSortedChildIds(shape.id)) { for (const childId of this.getSortedChildIds(shape.id)) {
addShape(childId) addShape(childId)
} }
@ -7727,7 +7710,7 @@ export class Editor extends EventEmitter<TLEventMap> {
const addShapeById = (id: TLShape['id']) => { const addShapeById = (id: TLShape['id']) => {
const shape = this.getShapeById(id) const shape = this.getShapeById(id)
if (!shape) return if (!shape) return
if (this.isShapeOfType(shape, GroupShapeUtil)) { if (this.isShapeOfType<TLGroupShape>(shape, 'group')) {
const childIds = this.getSortedChildIds(id) const childIds = this.getSortedChildIds(id)
for (const childId of childIds) { for (const childId of childIds) {
addShapeById(childId) addShapeById(childId)
@ -7799,7 +7782,7 @@ export class Editor extends EventEmitter<TLEventMap> {
const addShapeById = (id: TLShape['id']) => { const addShapeById = (id: TLShape['id']) => {
const shape = this.getShapeById(id) const shape = this.getShapeById(id)
if (!shape) return if (!shape) return
if (this.isShapeOfType(shape, GroupShapeUtil)) { if (this.isShapeOfType<TLGroupShape>(shape, 'group')) {
const childIds = this.getSortedChildIds(id) const childIds = this.getSortedChildIds(id)
for (const childId of childIds) { for (const childId of childIds) {
addShapeById(childId) addShapeById(childId)
@ -7883,14 +7866,14 @@ export class Editor extends EventEmitter<TLEventMap> {
shape = structuredClone(shape) as typeof shape shape = structuredClone(shape) as typeof shape
if (this.isShapeOfType(shape, ArrowShapeUtil)) { if (this.isShapeOfType<TLArrowShape>(shape, 'arrow')) {
const startBindingId = const startBindingId =
shape.props.start.type === 'binding' ? shape.props.start.boundShapeId : undefined shape.props.start.type === 'binding' ? shape.props.start.boundShapeId : undefined
const endBindingId = const endBindingId =
shape.props.end.type === 'binding' ? shape.props.end.boundShapeId : undefined shape.props.end.type === 'binding' ? shape.props.end.boundShapeId : undefined
const info = this.getShapeUtil(ArrowShapeUtil).getArrowInfo(shape) const info = this.getArrowInfo(shape)
if (shape.props.start.type === 'binding') { if (shape.props.start.type === 'binding') {
if (!shapes.some((s) => s.id === startBindingId)) { if (!shapes.some((s) => s.id === startBindingId)) {
@ -8034,7 +8017,7 @@ export class Editor extends EventEmitter<TLEventMap> {
for (const shape of this.selectedShapes) { for (const shape of this.selectedShapes) {
if (lowestDepth === 0) break if (lowestDepth === 0) break
const isFrame = this.isShapeOfType(shape, FrameShapeUtil) const isFrame = this.isShapeOfType<TLFrameShape>(shape, 'frame')
const ancestors = this.getAncestors(shape) const ancestors = this.getAncestors(shape)
if (isFrame) ancestors.push(shape) if (isFrame) ancestors.push(shape)
@ -8073,8 +8056,8 @@ export class Editor extends EventEmitter<TLEventMap> {
if (rootShapeIds.length === 1) { if (rootShapeIds.length === 1) {
const rootShape = shapes.find((s) => s.id === rootShapeIds[0])! const rootShape = shapes.find((s) => s.id === rootShapeIds[0])!
if ( if (
this.isShapeOfType(parent, FrameShapeUtil) && this.isShapeOfType<TLFrameShape>(parent, 'frame') &&
this.isShapeOfType(rootShape, FrameShapeUtil) && this.isShapeOfType<TLFrameShape>(rootShape, 'frame') &&
rootShape.props.w === parent?.props.w && rootShape.props.w === parent?.props.w &&
rootShape.props.h === parent?.props.h rootShape.props.h === parent?.props.h
) { ) {
@ -8130,7 +8113,7 @@ export class Editor extends EventEmitter<TLEventMap> {
index = getIndexAbove(index) index = getIndexAbove(index)
} }
if (this.isShapeOfType(newShape, ArrowShapeUtil)) { if (this.isShapeOfType<TLArrowShape>(newShape, 'arrow')) {
if (newShape.props.start.type === 'binding') { if (newShape.props.start.type === 'binding') {
const mappedId = idMap.get(newShape.props.start.boundShapeId) const mappedId = idMap.get(newShape.props.start.boundShapeId)
newShape.props.start = mappedId newShape.props.start = mappedId
@ -8281,11 +8264,11 @@ export class Editor extends EventEmitter<TLEventMap> {
if (rootShapes.length === 1) { if (rootShapes.length === 1) {
const onlyRoot = rootShapes[0] as TLFrameShape const onlyRoot = rootShapes[0] as TLFrameShape
// If the old bounds are in the viewport... // If the old bounds are in the viewport...
if (this.isShapeOfType(onlyRoot, FrameShapeUtil)) { if (this.isShapeOfType<TLFrameShape>(onlyRoot, 'frame')) {
while ( while (
this.getShapesAtPoint(point).some( this.getShapesAtPoint(point).some(
(shape) => (shape) =>
this.isShapeOfType(shape, FrameShapeUtil) && this.isShapeOfType<TLFrameShape>(shape, 'frame') &&
shape.props.w === onlyRoot.props.w && shape.props.w === onlyRoot.props.w &&
shape.props.h === onlyRoot.props.h shape.props.h === onlyRoot.props.h
) )
@ -8398,7 +8381,7 @@ export class Editor extends EventEmitter<TLEventMap> {
if (!bbox) return if (!bbox) return
const singleFrameShapeId = const singleFrameShapeId =
ids.length === 1 && this.isShapeOfType(this.getShapeById(ids[0])!, FrameShapeUtil) ids.length === 1 && this.isShapeOfType<TLFrameShape>(this.getShapeById(ids[0])!, 'frame')
? ids[0] ? ids[0]
: null : null
if (!singleFrameShapeId) { if (!singleFrameShapeId) {
@ -8474,7 +8457,7 @@ export class Editor extends EventEmitter<TLEventMap> {
const shape = this.getShapeById(id)! const shape = this.getShapeById(id)!
if (this.isShapeOfType(shape, GroupShapeUtil)) return [] if (this.isShapeOfType<TLGroupShape>(shape, 'group')) return []
const util = this.getShapeUtil(shape) const util = this.getShapeUtil(shape)

View file

@ -1,7 +1,6 @@
import { TLArrowShape, TLShapeId } from '@tldraw/tlschema' import { TLArrowShape, TLGeoShape, TLShapeId } from '@tldraw/tlschema'
import { TestEditor } from '../../test/TestEditor' import { TestEditor } from '../../test/TestEditor'
import { TL } from '../../test/jsx' import { TL } from '../../test/jsx'
import { GeoShapeUtil } from '../shapes/geo/GeoShapeUtil'
let editor: TestEditor let editor: TestEditor
@ -185,7 +184,7 @@ describe('arrowBindingsIndex', () => {
editor.duplicateShapes() editor.duplicateShapes()
const [box1Clone, box2Clone] = editor.selectedShapes const [box1Clone, box2Clone] = editor.selectedShapes
.filter((shape) => editor.isShapeOfType(shape, GeoShapeUtil)) .filter((shape) => editor.isShapeOfType<TLGeoShape>(shape, 'geo'))
.sort((a, b) => a.x - b.x) .sort((a, b) => a.x - b.x)
expect(editor.getArrowsBoundTo(box2Clone.id)).toHaveLength(3) expect(editor.getArrowsBoundTo(box2Clone.id)).toHaveLength(3)

View file

@ -1,7 +1,6 @@
import { Computed, RESET_VALUE, computed, isUninitialized } from '@tldraw/state' import { Computed, RESET_VALUE, computed, isUninitialized } from '@tldraw/state'
import { TLArrowShape, TLShape, TLShapeId } from '@tldraw/tlschema' import { TLArrowShape, TLShape, TLShapeId } from '@tldraw/tlschema'
import { Editor } from '../Editor' import { Editor } from '../Editor'
import { ArrowShapeUtil } from '../shapes/arrow/ArrowShapeUtil'
export type TLArrowBindingsIndex = Record< export type TLArrowBindingsIndex = Record<
TLShapeId, TLShapeId,
@ -83,7 +82,7 @@ export const arrowBindingsIndex = (editor: Editor): Computed<TLArrowBindingsInde
for (const changes of diff) { for (const changes of diff) {
for (const newShape of Object.values(changes.added)) { for (const newShape of Object.values(changes.added)) {
if (editor.isShapeOfType(newShape, ArrowShapeUtil)) { if (editor.isShapeOfType<TLArrowShape>(newShape, 'arrow')) {
const { start, end } = newShape.props const { start, end } = newShape.props
if (start.type === 'binding') { if (start.type === 'binding') {
addBinding(start.boundShapeId, newShape.id, 'start') addBinding(start.boundShapeId, newShape.id, 'start')
@ -96,8 +95,8 @@ export const arrowBindingsIndex = (editor: Editor): Computed<TLArrowBindingsInde
for (const [prev, next] of Object.values(changes.updated) as [TLShape, TLShape][]) { for (const [prev, next] of Object.values(changes.updated) as [TLShape, TLShape][]) {
if ( if (
!editor.isShapeOfType(prev, ArrowShapeUtil) || !editor.isShapeOfType<TLArrowShape>(prev, 'arrow') ||
!editor.isShapeOfType(next, ArrowShapeUtil) !editor.isShapeOfType<TLArrowShape>(next, 'arrow')
) )
continue continue
@ -124,7 +123,7 @@ export const arrowBindingsIndex = (editor: Editor): Computed<TLArrowBindingsInde
} }
for (const prev of Object.values(changes.removed)) { for (const prev of Object.values(changes.removed)) {
if (editor.isShapeOfType(prev, ArrowShapeUtil)) { if (editor.isShapeOfType<TLArrowShape>(prev, 'arrow')) {
const { start, end } = prev.props const { start, end } = prev.props
if (start.type === 'binding') { if (start.type === 'binding') {
removingBinding(start.boundShapeId, prev.id, 'start') removingBinding(start.boundShapeId, prev.id, 'start')

View file

@ -27,7 +27,6 @@ import { getEmbedInfo } from '../../utils/embeds'
import { Editor } from '../Editor' import { Editor } from '../Editor'
import { FONT_FAMILIES, FONT_SIZES, TEXT_PROPS } from '../shapes/shared/default-shape-constants' import { FONT_FAMILIES, FONT_SIZES, TEXT_PROPS } from '../shapes/shared/default-shape-constants'
import { INDENT } from '../shapes/text/TextHelpers' import { INDENT } from '../shapes/text/TextHelpers'
import { TextShapeUtil } from '../shapes/text/TextShapeUtil'
/** @public */ /** @public */
export type TLExternalContent = export type TLExternalContent =
@ -235,7 +234,7 @@ export class ExternalContentManager {
const p = const p =
point ?? (editor.inputs.shiftKey ? editor.inputs.currentPagePoint : editor.viewportPageCenter) point ?? (editor.inputs.shiftKey ? editor.inputs.currentPagePoint : editor.viewportPageCenter)
const defaultProps = editor.getShapeUtil(TextShapeUtil).getDefaultProps() const defaultProps = editor.getShapeUtil<TLTextShape>('text').getDefaultProps()
const textToPaste = stripTrailingWhitespace( const textToPaste = stripTrailingWhitespace(
stripCommonMinimumIndentation(replaceTabsWithSpaces(text)) stripCommonMinimumIndentation(replaceTabsWithSpaces(text))

View file

@ -12,11 +12,10 @@ import {
VecLike, VecLike,
} from '@tldraw/primitives' } from '@tldraw/primitives'
import { atom, computed, EMPTY_ARRAY } from '@tldraw/state' import { atom, computed, EMPTY_ARRAY } from '@tldraw/state'
import { TLParentId, TLShape, TLShapeId, Vec2dModel } from '@tldraw/tlschema' import { TLGroupShape, TLParentId, TLShape, TLShapeId, Vec2dModel } from '@tldraw/tlschema'
import { dedupe, deepCopy } from '@tldraw/utils' import { dedupe, deepCopy } from '@tldraw/utils'
import { uniqueId } from '../../utils/data' import { uniqueId } from '../../utils/data'
import type { Editor } from '../Editor' import type { Editor } from '../Editor'
import { GroupShapeUtil } from '../shapes/group/GroupShapeUtil'
export type PointsSnapLine = { export type PointsSnapLine = {
id: string id: string
@ -266,7 +265,7 @@ export class SnapManager {
const pageBounds = editor.getPageBoundsById(childId) const pageBounds = editor.getPageBoundsById(childId)
if (!(pageBounds && renderingBounds.includes(pageBounds))) continue if (!(pageBounds && renderingBounds.includes(pageBounds))) continue
// Snap to children of groups but not group itself // Snap to children of groups but not group itself
if (editor.isShapeOfType(childShape, GroupShapeUtil)) { if (editor.isShapeOfType<TLGroupShape>(childShape, 'group')) {
collectSnappableShapesFromParent(childId) collectSnappableShapesFromParent(childId)
continue continue
} }

View file

@ -1,5 +1,4 @@
import { StateNode } from '../../tools/StateNode' import { StateNode } from '../../tools/StateNode'
import { ArrowShapeUtil } from './ArrowShapeUtil'
import { Idle } from './toolStates/Idle' import { Idle } from './toolStates/Idle'
import { Pointing } from './toolStates/Pointing' import { Pointing } from './toolStates/Pointing'
@ -8,5 +7,5 @@ export class ArrowShapeTool extends StateNode {
static initial = 'idle' static initial = 'idle'
static children = () => [Idle, Pointing] static children = () => [Idle, Pointing]
shapeType = ArrowShapeUtil shapeType = 'arrow'
} }

View file

@ -2,7 +2,6 @@ import { TAU } from '@tldraw/primitives'
import { TLArrowShape, TLArrowShapeTerminal, TLShapeId, createShapeId } from '@tldraw/tlschema' import { TLArrowShape, TLArrowShapeTerminal, TLShapeId, createShapeId } from '@tldraw/tlschema'
import { assert } from '@tldraw/utils' import { assert } from '@tldraw/utils'
import { TestEditor } from '../../../test/TestEditor' import { TestEditor } from '../../../test/TestEditor'
import { ArrowShapeUtil } from './ArrowShapeUtil'
let editor: TestEditor let editor: TestEditor
@ -299,7 +298,7 @@ describe('Other cases when arrow are moved', () => {
editor.setSelectedTool('arrow').pointerDown(1000, 1000).pointerMove(50, 350).pointerUp(50, 350) editor.setSelectedTool('arrow').pointerDown(1000, 1000).pointerMove(50, 350).pointerUp(50, 350)
let arrow = editor.shapesArray[editor.shapesArray.length - 1] let arrow = editor.shapesArray[editor.shapesArray.length - 1]
assert(editor.isShapeOfType(arrow, ArrowShapeUtil)) assert(editor.isShapeOfType<TLArrowShape>(arrow, 'arrow'))
assert(arrow.props.end.type === 'binding') assert(arrow.props.end.type === 'binding')
expect(arrow.props.end.boundShapeId).toBe(ids.box3) expect(arrow.props.end.boundShapeId).toBe(ids.box3)
@ -308,7 +307,7 @@ describe('Other cases when arrow are moved', () => {
// arrow should still be bound to box3 // arrow should still be bound to box3
arrow = editor.getShapeById(arrow.id)! arrow = editor.getShapeById(arrow.id)!
assert(editor.isShapeOfType(arrow, ArrowShapeUtil)) assert(editor.isShapeOfType<TLArrowShape>(arrow, 'arrow'))
assert(arrow.props.end.type === 'binding') assert(arrow.props.end.type === 'binding')
expect(arrow.props.end.boundShapeId).toBe(ids.box3) expect(arrow.props.end.boundShapeId).toBe(ids.box3)
}) })

View file

@ -52,19 +52,10 @@ import {
import { getPerfectDashProps } from '../shared/getPerfectDashProps' import { getPerfectDashProps } from '../shared/getPerfectDashProps'
import { getShapeFillSvg, ShapeFill, useDefaultColorTheme } from '../shared/ShapeFill' import { getShapeFillSvg, ShapeFill, useDefaultColorTheme } from '../shared/ShapeFill'
import { SvgExportContext } from '../shared/SvgExportContext' import { SvgExportContext } from '../shared/SvgExportContext'
import { ArrowInfo } from './arrow/arrow-types'
import { getArrowheadPathForType } from './arrow/arrowheads' import { getArrowheadPathForType } from './arrow/arrowheads'
import { import { getCurvedArrowHandlePath, getSolidCurvedArrowPath } from './arrow/curved-arrow'
getCurvedArrowHandlePath, import { getArrowTerminalsInArrowSpace } from './arrow/shared'
getCurvedArrowInfo, import { getSolidStraightArrowPath, getStraightArrowHandlePath } from './arrow/straight-arrow'
getSolidCurvedArrowPath,
} from './arrow/curved-arrow'
import { getArrowTerminalsInArrowSpace, getIsArrowStraight } from './arrow/shared'
import {
getSolidStraightArrowPath,
getStraightArrowHandlePath,
getStraightArrowInfo,
} from './arrow/straight-arrow'
import { ArrowTextLabel } from './components/ArrowTextLabel' import { ArrowTextLabel } from './components/ArrowTextLabel'
let globalRenderIndex = 0 let globalRenderIndex = 0
@ -108,7 +99,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
} }
getOutlineWithoutLabel(shape: TLArrowShape): Vec2d[] { getOutlineWithoutLabel(shape: TLArrowShape): Vec2d[] {
const info = this.getArrowInfo(shape) const info = this.editor.getArrowInfo(shape)
if (!info) { if (!info) {
return [] return []
@ -213,24 +204,8 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
return EMPTY_ARRAY return EMPTY_ARRAY
} }
@computed
private get infoCache() {
return this.editor.store.createComputedCache<ArrowInfo, TLArrowShape>(
'arrow infoCache',
(shape) => {
return getIsArrowStraight(shape)
? getStraightArrowInfo(this.editor, shape)
: getCurvedArrowInfo(this.editor, shape)
}
)
}
getArrowInfo(shape: TLArrowShape) {
return this.infoCache.get(shape.id)
}
getHandles(shape: TLArrowShape): TLHandle[] { getHandles(shape: TLArrowShape): TLHandle[] {
const info = this.infoCache.get(shape.id)! const info = this.editor.getArrowInfo(shape)!
return [ return [
{ {
id: 'start', id: 'start',
@ -581,7 +556,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
'arrow.dragging' 'arrow.dragging'
) && !this.editor.isReadOnly ) && !this.editor.isReadOnly
const info = this.getArrowInfo(shape) const info = this.editor.getArrowInfo(shape)
const bounds = this.editor.getBounds(shape) const bounds = this.editor.getBounds(shape)
const labelSize = this.getLabelBounds(shape) const labelSize = this.getLabelBounds(shape)
@ -760,7 +735,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
indicator(shape: TLArrowShape) { indicator(shape: TLArrowShape) {
const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape) const { start, end } = getArrowTerminalsInArrowSpace(this.editor, shape)
const info = this.getArrowInfo(shape) const info = this.editor.getArrowInfo(shape)
const bounds = this.editor.getBounds(shape) const bounds = this.editor.getBounds(shape)
const labelSize = this.getLabelBounds(shape) const labelSize = this.getLabelBounds(shape)
@ -854,7 +829,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
@computed get labelBoundsCache(): ComputedCache<Box2d | null, TLArrowShape> { @computed get labelBoundsCache(): ComputedCache<Box2d | null, TLArrowShape> {
return this.editor.store.createComputedCache('labelBoundsCache', (shape) => { return this.editor.store.createComputedCache('labelBoundsCache', (shape) => {
const info = this.getArrowInfo(shape) const info = this.editor.getArrowInfo(shape)
const bounds = this.editor.getBounds(shape) const bounds = this.editor.getBounds(shape)
const { text, font, size } = shape.props const { text, font, size } = shape.props
@ -938,7 +913,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
const color = theme[shape.props.color].solid const color = theme[shape.props.color].solid
const info = this.getArrowInfo(shape) const info = this.editor.getArrowInfo(shape)
const strokeWidth = STROKE_SIZES[shape.props.size] const strokeWidth = STROKE_SIZES[shape.props.size]

View file

@ -1,5 +1,4 @@
import { createShapeId, TLArrowShape } from '@tldraw/tlschema' import { createShapeId, TLArrowShape } from '@tldraw/tlschema'
import { ArrowShapeUtil } from '../../../shapes/arrow/ArrowShapeUtil'
import { StateNode } from '../../../tools/StateNode' import { StateNode } from '../../../tools/StateNode'
import { TLEventHandlers } from '../../../types/event-types' import { TLEventHandlers } from '../../../types/event-types'
@ -43,7 +42,7 @@ export class Pointing extends StateNode {
}, },
]) ])
const util = this.editor.getShapeUtil(ArrowShapeUtil) const util = this.editor.getShapeUtil<TLArrowShape>('arrow')
const shape = this.editor.getShapeById<TLArrowShape>(id) const shape = this.editor.getShapeById<TLArrowShape>(id)
if (!shape) return if (!shape) return
@ -90,7 +89,7 @@ export class Pointing extends StateNode {
} }
if (!this.didTimeout) { if (!this.didTimeout) {
const util = this.editor.getShapeUtil(ArrowShapeUtil) const util = this.editor.getShapeUtil<TLArrowShape>('arrow')
const shape = this.editor.getShapeById<TLArrowShape>(this.shape.id) const shape = this.editor.getShapeById<TLArrowShape>(this.shape.id)
if (!shape) return if (!shape) return

View file

@ -1,5 +1,4 @@
import { StateNode } from '../../tools/StateNode' import { StateNode } from '../../tools/StateNode'
import { DrawShapeUtil } from './DrawShapeUtil'
import { Drawing } from './toolStates/Drawing' import { Drawing } from './toolStates/Drawing'
import { Idle } from './toolStates/Idle' import { Idle } from './toolStates/Idle'
@ -8,7 +7,7 @@ export class DrawShapeTool extends StateNode {
static initial = 'idle' static initial = 'idle'
static children = () => [Idle, Drawing] static children = () => [Idle, Drawing]
shapeType = DrawShapeUtil shapeType = 'draw'
onExit = () => { onExit = () => {
const drawingState = this.children!['drawing'] as Drawing const drawingState = this.children!['drawing'] as Drawing

View file

@ -13,9 +13,7 @@ import { DRAG_DISTANCE } from '../../../../constants'
import { uniqueId } from '../../../../utils/data' import { uniqueId } from '../../../../utils/data'
import { StateNode } from '../../../tools/StateNode' import { StateNode } from '../../../tools/StateNode'
import { TLEventHandlers, TLPointerEventInfo } from '../../../types/event-types' import { TLEventHandlers, TLPointerEventInfo } from '../../../types/event-types'
import { HighlightShapeUtil } from '../../highlight/HighlightShapeUtil'
import { STROKE_SIZES } from '../../shared/default-shape-constants' import { STROKE_SIZES } from '../../shared/default-shape-constants'
import { DrawShapeUtil } from '../DrawShapeUtil'
type DrawableShape = TLDrawShape | TLHighlightShape type DrawableShape = TLDrawShape | TLHighlightShape
@ -26,7 +24,7 @@ export class Drawing extends StateNode {
initialShape?: DrawableShape initialShape?: DrawableShape
shapeType = this.parent.id === 'highlight' ? HighlightShapeUtil : DrawShapeUtil shapeType = this.parent.id === 'highlight' ? ('highlight' as const) : ('draw' as const)
util = this.editor.getShapeUtil(this.shapeType) util = this.editor.getShapeUtil(this.shapeType)
@ -138,7 +136,7 @@ export class Drawing extends StateNode {
} }
canClose() { canClose() {
return this.shapeType.type !== 'highlight' return this.shapeType !== 'highlight'
} }
getIsClosed(segments: TLDrawShapeSegment[], size: TLDefaultSizeStyle) { getIsClosed(segments: TLDrawShapeSegment[], size: TLDefaultSizeStyle) {
@ -219,7 +217,7 @@ export class Drawing extends StateNode {
const shapePartial: TLShapePartial<DrawableShape> = { const shapePartial: TLShapePartial<DrawableShape> = {
id: shape.id, id: shape.id,
type: this.shapeType.type, type: this.shapeType,
props: { props: {
segments, segments,
}, },
@ -246,7 +244,7 @@ export class Drawing extends StateNode {
this.editor.createShapes<DrawableShape>([ this.editor.createShapes<DrawableShape>([
{ {
id, id,
type: this.shapeType.type, type: this.shapeType,
x: originPagePoint.x, x: originPagePoint.x,
y: originPagePoint.y, y: originPagePoint.y,
props: { props: {
@ -349,7 +347,7 @@ export class Drawing extends StateNode {
const shapePartial: TLShapePartial<DrawableShape> = { const shapePartial: TLShapePartial<DrawableShape> = {
id, id,
type: this.shapeType.type, type: this.shapeType,
props: { props: {
segments: [...segments, newSegment], segments: [...segments, newSegment],
}, },
@ -409,7 +407,7 @@ export class Drawing extends StateNode {
const shapePartial: TLShapePartial<DrawableShape> = { const shapePartial: TLShapePartial<DrawableShape> = {
id, id,
type: this.shapeType.type, type: this.shapeType,
props: { props: {
segments: finalSegments, segments: finalSegments,
}, },
@ -551,7 +549,7 @@ export class Drawing extends StateNode {
const shapePartial: TLShapePartial<DrawableShape> = { const shapePartial: TLShapePartial<DrawableShape> = {
id, id,
type: this.shapeType.type, type: this.shapeType,
props: { props: {
segments: newSegments, segments: newSegments,
}, },
@ -596,7 +594,7 @@ export class Drawing extends StateNode {
const shapePartial: TLShapePartial<DrawableShape> = { const shapePartial: TLShapePartial<DrawableShape> = {
id, id,
type: this.shapeType.type, type: this.shapeType,
props: { props: {
segments: newSegments, segments: newSegments,
}, },
@ -613,7 +611,7 @@ export class Drawing extends StateNode {
// Set a maximum length for the lines array; after 200 points, complete the line. // Set a maximum length for the lines array; after 200 points, complete the line.
if (newPoints.length > 500) { if (newPoints.length > 500) {
this.editor.updateShapes([{ id, type: this.shapeType.type, props: { isComplete: true } }]) this.editor.updateShapes([{ id, type: this.shapeType, props: { isComplete: true } }])
const { currentPagePoint } = this.editor.inputs const { currentPagePoint } = this.editor.inputs
@ -622,7 +620,7 @@ export class Drawing extends StateNode {
this.editor.createShapes<DrawableShape>([ this.editor.createShapes<DrawableShape>([
{ {
id: newShapeId, id: newShapeId,
type: this.shapeType.type, type: this.shapeType,
x: toFixed(currentPagePoint.x), x: toFixed(currentPagePoint.x),
y: toFixed(currentPagePoint.y), y: toFixed(currentPagePoint.y),
props: { props: {

View file

@ -84,7 +84,7 @@ export class EmbedShapeUtil extends BaseBoxShapeUtil<TLEmbedShape> {
if (editingId && hoveredId !== editingId) { if (editingId && hoveredId !== editingId) {
const editingShape = this.editor.getShapeById(editingId) const editingShape = this.editor.getShapeById(editingId)
if (editingShape && this.editor.isShapeOfType(editingShape, EmbedShapeUtil)) { if (editingShape && this.editor.isShapeOfType<TLEmbedShape>(editingShape, 'embed')) {
return true return true
} }
} }

View file

@ -1,9 +1,8 @@
import { BaseBoxShapeTool } from '../../tools/BaseBoxShapeTool/BaseBoxShapeTool' import { BaseBoxShapeTool } from '../../tools/BaseBoxShapeTool/BaseBoxShapeTool'
import { FrameShapeUtil } from './FrameShapeUtil'
export class FrameShapeTool extends BaseBoxShapeTool { export class FrameShapeTool extends BaseBoxShapeTool {
static override id = 'frame' static override id = 'frame'
static initial = 'idle' static initial = 'idle'
shapeType = FrameShapeUtil shapeType = 'frame'
} }

View file

@ -1,10 +1,15 @@
import { canolicalizeRotation, SelectionEdge, toDomPrecision } from '@tldraw/primitives' import { canolicalizeRotation, SelectionEdge, toDomPrecision } from '@tldraw/primitives'
import { getDefaultColorTheme, TLFrameShape, TLShape, TLShapeId } from '@tldraw/tlschema' import {
getDefaultColorTheme,
TLFrameShape,
TLGroupShape,
TLShape,
TLShapeId,
} from '@tldraw/tlschema'
import { last } from '@tldraw/utils' import { last } from '@tldraw/utils'
import { SVGContainer } from '../../../components/SVGContainer' import { SVGContainer } from '../../../components/SVGContainer'
import { defaultEmptyAs } from '../../../utils/string' import { defaultEmptyAs } from '../../../utils/string'
import { BaseBoxShapeUtil } from '../BaseBoxShapeUtil' import { BaseBoxShapeUtil } from '../BaseBoxShapeUtil'
import { GroupShapeUtil } from '../group/GroupShapeUtil'
import { TLOnResizeEndHandler } from '../ShapeUtil' import { TLOnResizeEndHandler } from '../ShapeUtil'
import { createTextSvgElementFromSpans } from '../shared/createTextSvgElementFromSpans' import { createTextSvgElementFromSpans } from '../shared/createTextSvgElementFromSpans'
import { useDefaultColorTheme } from '../shared/ShapeFill' import { useDefaultColorTheme } from '../shared/ShapeFill'
@ -173,7 +178,7 @@ export class FrameShapeUtil extends BaseBoxShapeUtil<TLFrameShape> {
onDragShapesOut = (_shape: TLFrameShape, shapes: TLShape[]): void => { onDragShapesOut = (_shape: TLFrameShape, shapes: TLShape[]): void => {
const parent = this.editor.getShapeById(_shape.parentId) const parent = this.editor.getShapeById(_shape.parentId)
const isInGroup = parent && this.editor.isShapeOfType(parent, GroupShapeUtil) const isInGroup = parent && this.editor.isShapeOfType<TLGroupShape>(parent, 'group')
// If frame is in a group, keep the shape // If frame is in a group, keep the shape
// moved out in that group // moved out in that group

View file

@ -1,5 +1,4 @@
import { StateNode } from '../../tools/StateNode' import { StateNode } from '../../tools/StateNode'
import { GeoShapeUtil } from './GeoShapeUtil'
import { Idle } from './toolStates/Idle' import { Idle } from './toolStates/Idle'
import { Pointing } from './toolStates/Pointing' import { Pointing } from './toolStates/Pointing'
@ -8,5 +7,5 @@ export class GeoShapeTool extends StateNode {
static initial = 'idle' static initial = 'idle'
static children = () => [Idle, Pointing] static children = () => [Idle, Pointing]
shapeType = GeoShapeUtil shapeType = 'geo'
} }

View file

@ -1,6 +1,6 @@
import { TLGeoShape } from '@tldraw/tlschema'
import { StateNode } from '../../../tools/StateNode' import { StateNode } from '../../../tools/StateNode'
import { TLEventHandlers } from '../../../types/event-types' import { TLEventHandlers } from '../../../types/event-types'
import { GeoShapeUtil } from '../GeoShapeUtil'
export class Idle extends StateNode { export class Idle extends StateNode {
static override id = 'idle' static override id = 'idle'
@ -16,7 +16,7 @@ export class Idle extends StateNode {
onKeyUp: TLEventHandlers['onKeyUp'] = (info) => { onKeyUp: TLEventHandlers['onKeyUp'] = (info) => {
if (info.key === 'Enter') { if (info.key === 'Enter') {
const shape = this.editor.onlySelectedShape const shape = this.editor.onlySelectedShape
if (shape && this.editor.isShapeOfType(shape, GeoShapeUtil)) { if (shape && this.editor.isShapeOfType<TLGeoShape>(shape, 'geo')) {
// todo: ensure that this only works with the most recently created shape, not just any geo shape that happens to be selected at the time // todo: ensure that this only works with the most recently created shape, not just any geo shape that happens to be selected at the time
this.editor.mark('editing shape') this.editor.mark('editing shape')
this.editor.setEditingId(shape.id) this.editor.setEditingId(shape.id)

View file

@ -58,7 +58,7 @@ export class GroupShapeUtil extends ShapeUtil<TLGroupShape> {
hintingIds.some( hintingIds.some(
(id) => (id) =>
id !== shape.id && id !== shape.id &&
this.editor.isShapeOfType(this.editor.getShapeById(id)!, GroupShapeUtil) this.editor.isShapeOfType<TLGroupShape>(this.editor.getShapeById(id)!, 'group')
) )
if ( if (

View file

@ -2,14 +2,13 @@ import { StateNode } from '../../tools/StateNode'
// shared custody // shared custody
import { Drawing } from '../draw/toolStates/Drawing' import { Drawing } from '../draw/toolStates/Drawing'
import { Idle } from '../draw/toolStates/Idle' import { Idle } from '../draw/toolStates/Idle'
import { HighlightShapeUtil } from './HighlightShapeUtil'
export class HighlightShapeTool extends StateNode { export class HighlightShapeTool extends StateNode {
static override id = 'highlight' static override id = 'highlight'
static initial = 'idle' static initial = 'idle'
static children = () => [Idle, Drawing] static children = () => [Idle, Drawing]
shapeType = HighlightShapeUtil shapeType = 'highlight'
onExit = () => { onExit = () => {
const drawingState = this.children!['drawing'] as Drawing const drawingState = this.children!['drawing'] as Drawing

View file

@ -1,6 +1,6 @@
import { TLLineShape } from '@tldraw/tlschema'
import { assert } from '@tldraw/utils' import { assert } from '@tldraw/utils'
import { TestEditor } from '../../../test/TestEditor' import { TestEditor } from '../../../test/TestEditor'
import { LineShapeUtil } from '../../shapes/line/LineShapeUtil'
let editor: TestEditor let editor: TestEditor
@ -128,7 +128,7 @@ describe('When extending the line with the shift-key in tool-lock mode', () => {
.pointerUp(20, 10) .pointerUp(20, 10)
const line = editor.shapesArray[editor.shapesArray.length - 1] const line = editor.shapesArray[editor.shapesArray.length - 1]
assert(editor.isShapeOfType(line, LineShapeUtil)) assert(editor.isShapeOfType<TLLineShape>(line, 'line'))
const handles = Object.values(line.props.handles) const handles = Object.values(line.props.handles)
expect(handles.length).toBe(3) expect(handles.length).toBe(3)
}) })
@ -145,7 +145,7 @@ describe('When extending the line with the shift-key in tool-lock mode', () => {
.pointerUp(30, 10) .pointerUp(30, 10)
const line = editor.shapesArray[editor.shapesArray.length - 1] const line = editor.shapesArray[editor.shapesArray.length - 1]
assert(editor.isShapeOfType(line, LineShapeUtil)) assert(editor.isShapeOfType<TLLineShape>(line, 'line'))
const handles = Object.values(line.props.handles) const handles = Object.values(line.props.handles)
expect(handles.length).toBe(3) expect(handles.length).toBe(3)
}) })
@ -163,7 +163,7 @@ describe('When extending the line with the shift-key in tool-lock mode', () => {
.pointerUp(30, 10) .pointerUp(30, 10)
const line = editor.shapesArray[editor.shapesArray.length - 1] const line = editor.shapesArray[editor.shapesArray.length - 1]
assert(editor.isShapeOfType(line, LineShapeUtil)) assert(editor.isShapeOfType<TLLineShape>(line, 'line'))
const handles = Object.values(line.props.handles) const handles = Object.values(line.props.handles)
expect(handles.length).toBe(3) expect(handles.length).toBe(3)
}) })
@ -183,7 +183,7 @@ describe('When extending the line with the shift-key in tool-lock mode', () => {
.pointerUp(30, 10) .pointerUp(30, 10)
const line = editor.shapesArray[editor.shapesArray.length - 1] const line = editor.shapesArray[editor.shapesArray.length - 1]
assert(editor.isShapeOfType(line, LineShapeUtil)) assert(editor.isShapeOfType<TLLineShape>(line, 'line'))
const handles = Object.values(line.props.handles) const handles = Object.values(line.props.handles)
expect(handles.length).toBe(3) expect(handles.length).toBe(3)
}) })
@ -205,7 +205,7 @@ describe('When extending the line with the shift-key in tool-lock mode', () => {
.pointerUp(40, 10) .pointerUp(40, 10)
const line = editor.shapesArray[editor.shapesArray.length - 1] const line = editor.shapesArray[editor.shapesArray.length - 1]
assert(editor.isShapeOfType(line, LineShapeUtil)) assert(editor.isShapeOfType<TLLineShape>(line, 'line'))
const handles = Object.values(line.props.handles) const handles = Object.values(line.props.handles)
expect(handles.length).toBe(3) expect(handles.length).toBe(3)
}) })

View file

@ -1,5 +1,4 @@
import { StateNode } from '../../tools/StateNode' import { StateNode } from '../../tools/StateNode'
import { LineShapeUtil } from './LineShapeUtil'
import { Idle } from './toolStates/Idle' import { Idle } from './toolStates/Idle'
import { Pointing } from './toolStates/Pointing' import { Pointing } from './toolStates/Pointing'
@ -8,5 +7,5 @@ export class LineShapeTool extends StateNode {
static initial = 'idle' static initial = 'idle'
static children = () => [Idle, Pointing] static children = () => [Idle, Pointing]
shapeType = LineShapeUtil shapeType = 'line'
} }

View file

@ -1,5 +1,4 @@
import { StateNode } from '../../tools/StateNode' import { StateNode } from '../../tools/StateNode'
import { NoteShapeUtil } from './NoteShapeUtil'
import { Idle } from './toolStates/Idle' import { Idle } from './toolStates/Idle'
import { Pointing } from './toolStates/Pointing' import { Pointing } from './toolStates/Pointing'
@ -8,5 +7,5 @@ export class NoteShapeTool extends StateNode {
static initial = 'idle' static initial = 'idle'
static children = () => [Idle, Pointing] static children = () => [Idle, Pointing]
shapeType = NoteShapeUtil shapeType = 'note'
} }

View file

@ -1,5 +1,4 @@
import { StateNode } from '../../tools/StateNode' import { StateNode } from '../../tools/StateNode'
import { TextShapeUtil } from './TextShapeUtil'
import { Idle } from './toolStates/Idle' import { Idle } from './toolStates/Idle'
import { Pointing } from './toolStates/Pointing' import { Pointing } from './toolStates/Pointing'
@ -9,5 +8,5 @@ export class TextShapeTool extends StateNode {
static children = () => [Idle, Pointing] static children = () => [Idle, Pointing]
shapeType = TextShapeUtil shapeType = 'text'
} }

View file

@ -1,7 +1,6 @@
import { TLGeoShape, TLTextShape } from '@tldraw/tlschema'
import { StateNode } from '../../../tools/StateNode' import { StateNode } from '../../../tools/StateNode'
import { TLEventHandlers } from '../../../types/event-types' import { TLEventHandlers } from '../../../types/event-types'
import { GeoShapeUtil } from '../../geo/GeoShapeUtil'
import { TextShapeUtil } from '../TextShapeUtil'
export class Idle extends StateNode { export class Idle extends StateNode {
static override id = 'idle' static override id = 'idle'
@ -19,7 +18,7 @@ export class Idle extends StateNode {
(parent) => !selectedIds.includes(parent.id) (parent) => !selectedIds.includes(parent.id)
) )
if (hoveringShape.id !== focusLayerId) { if (hoveringShape.id !== focusLayerId) {
if (this.editor.isShapeOfType(hoveringShape, TextShapeUtil)) { if (this.editor.isShapeOfType<TLTextShape>(hoveringShape, 'text')) {
this.editor.setHoveredId(hoveringShape.id) this.editor.setHoveredId(hoveringShape.id)
} }
} }
@ -41,7 +40,7 @@ export class Idle extends StateNode {
const { hoveredId } = this.editor const { hoveredId } = this.editor
if (hoveredId) { if (hoveredId) {
const shape = this.editor.getShapeById(hoveredId)! const shape = this.editor.getShapeById(hoveredId)!
if (this.editor.isShapeOfType(shape, TextShapeUtil)) { if (this.editor.isShapeOfType<TLTextShape>(shape, 'text')) {
requestAnimationFrame(() => { requestAnimationFrame(() => {
this.editor.setSelectedIds([shape.id]) this.editor.setSelectedIds([shape.id])
this.editor.setEditingId(shape.id) this.editor.setEditingId(shape.id)
@ -65,7 +64,7 @@ export class Idle extends StateNode {
onKeyDown: TLEventHandlers['onKeyDown'] = (info) => { onKeyDown: TLEventHandlers['onKeyDown'] = (info) => {
if (info.key === 'Enter') { if (info.key === 'Enter') {
const shape = this.editor.selectedShapes[0] const shape = this.editor.selectedShapes[0]
if (shape && this.editor.isShapeOfType(shape, GeoShapeUtil)) { if (shape && this.editor.isShapeOfType<TLGeoShape>(shape, 'geo')) {
this.editor.setSelectedTool('select') this.editor.setSelectedTool('select')
this.editor.setEditingId(shape.id) this.editor.setEditingId(shape.id)
this.editor.root.current.value!.transition('editing_shape', { this.editor.root.current.value!.transition('editing_shape', {

View file

@ -1,4 +1,3 @@
import { TLShapeUtilConstructor } from '../../shapes/ShapeUtil'
import { StateNode } from '../StateNode' import { StateNode } from '../StateNode'
import { Idle } from './children/Idle' import { Idle } from './children/Idle'
import { Pointing } from './children/Pointing' import { Pointing } from './children/Pointing'
@ -9,5 +8,5 @@ export abstract class BaseBoxShapeTool extends StateNode {
static initial = 'idle' static initial = 'idle'
static children = () => [Idle, Pointing] static children = () => [Idle, Pointing]
abstract shapeType: TLShapeUtilConstructor<any> abstract shapeType: string
} }

View file

@ -21,7 +21,7 @@ export class Pointing extends StateNode {
if (this.editor.inputs.isDragging) { if (this.editor.inputs.isDragging) {
const { originPagePoint } = this.editor.inputs const { originPagePoint } = this.editor.inputs
const shapeType = (this.parent as BaseBoxShapeTool)!.shapeType.type as TLBaseBoxShape['type'] const shapeType = (this.parent as BaseBoxShapeTool)!.shapeType
const id = createShapeId() const id = createShapeId()
@ -78,7 +78,7 @@ export class Pointing extends StateNode {
this.editor.mark(this.markId) this.editor.mark(this.markId)
const shapeType = (this.parent as BaseBoxShapeTool)!.shapeType.type as TLBaseBoxShape['type'] const shapeType = (this.parent as BaseBoxShapeTool)!.shapeType as TLBaseBoxShape['type']
const id = createShapeId() const id = createShapeId()

View file

@ -1,8 +1,6 @@
import { pointInPolygon } from '@tldraw/primitives' import { pointInPolygon } from '@tldraw/primitives'
import { TLScribble, TLShapeId } from '@tldraw/tlschema' import { TLFrameShape, TLGroupShape, TLScribble, TLShapeId } from '@tldraw/tlschema'
import { ScribbleManager } from '../../../managers/ScribbleManager' import { ScribbleManager } from '../../../managers/ScribbleManager'
import { FrameShapeUtil } from '../../../shapes/frame/FrameShapeUtil'
import { GroupShapeUtil } from '../../../shapes/group/GroupShapeUtil'
import { TLEventHandlers, TLPointerEventInfo } from '../../../types/event-types' import { TLEventHandlers, TLPointerEventInfo } from '../../../types/event-types'
import { StateNode } from '../../StateNode' import { StateNode } from '../../StateNode'
@ -24,8 +22,8 @@ export class Erasing extends StateNode {
.filter( .filter(
(shape) => (shape) =>
this.editor.isShapeOrAncestorLocked(shape) || this.editor.isShapeOrAncestorLocked(shape) ||
((this.editor.isShapeOfType(shape, GroupShapeUtil) || ((this.editor.isShapeOfType<TLGroupShape>(shape, 'group') ||
this.editor.isShapeOfType(shape, FrameShapeUtil)) && this.editor.isShapeOfType<TLFrameShape>(shape, 'frame')) &&
this.editor.isPointInShape(originPagePoint, shape)) this.editor.isPointInShape(originPagePoint, shape))
) )
.map((shape) => shape.id) .map((shape) => shape.id)
@ -98,7 +96,7 @@ export class Erasing extends StateNode {
const erasing = new Set<TLShapeId>(erasingIdsSet) const erasing = new Set<TLShapeId>(erasingIdsSet)
for (const shape of shapesArray) { for (const shape of shapesArray) {
if (this.editor.isShapeOfType(shape, GroupShapeUtil)) continue if (this.editor.isShapeOfType<TLGroupShape>(shape, 'group')) continue
// Avoid testing masked shapes, unless the pointer is inside the mask // Avoid testing masked shapes, unless the pointer is inside the mask
const pageMask = this.editor.getPageMaskById(shape.id) const pageMask = this.editor.getPageMaskById(shape.id)

View file

@ -1,6 +1,4 @@
import { TLShapeId } from '@tldraw/tlschema' import { TLFrameShape, TLGroupShape, TLShapeId } from '@tldraw/tlschema'
import { FrameShapeUtil } from '../../../shapes/frame/FrameShapeUtil'
import { GroupShapeUtil } from '../../../shapes/group/GroupShapeUtil'
import { TLEventHandlers } from '../../../types/event-types' import { TLEventHandlers } from '../../../types/event-types'
import { StateNode } from '../../StateNode' import { StateNode } from '../../StateNode'
@ -17,12 +15,16 @@ export class Pointing extends StateNode {
for (const shape of [...this.editor.sortedShapesArray].reverse()) { for (const shape of [...this.editor.sortedShapesArray].reverse()) {
if (this.editor.isPointInShape(inputs.currentPagePoint, shape)) { if (this.editor.isPointInShape(inputs.currentPagePoint, shape)) {
// Skip groups // Skip groups
if (this.editor.isShapeOfType(shape, GroupShapeUtil)) continue if (this.editor.isShapeOfType<TLGroupShape>(shape, 'group')) continue
const hitShape = this.editor.getOutermostSelectableShape(shape) const hitShape = this.editor.getOutermostSelectableShape(shape)
// If we've hit a frame after hitting any other shape, stop here // If we've hit a frame after hitting any other shape, stop here
if (this.editor.isShapeOfType(hitShape, FrameShapeUtil) && erasing.size > initialSize) break if (
this.editor.isShapeOfType<TLFrameShape>(hitShape, 'frame') &&
erasing.size > initialSize
)
break
erasing.add(hitShape.id) erasing.add(hitShape.id)
} }

View file

@ -6,9 +6,7 @@ import {
Vec2d, Vec2d,
VecLike, VecLike,
} from '@tldraw/primitives' } from '@tldraw/primitives'
import { TLPageId, TLShape, TLShapeId } from '@tldraw/tlschema' import { TLFrameShape, TLGroupShape, TLPageId, TLShape, TLShapeId } from '@tldraw/tlschema'
import { FrameShapeUtil } from '../../../shapes/frame/FrameShapeUtil'
import { GroupShapeUtil } from '../../../shapes/group/GroupShapeUtil'
import { ShapeUtil } from '../../../shapes/ShapeUtil' import { ShapeUtil } from '../../../shapes/ShapeUtil'
import { import {
TLCancelEvent, TLCancelEvent,
@ -43,7 +41,7 @@ export class Brushing extends StateNode {
this.editor.shapesArray this.editor.shapesArray
.filter( .filter(
(shape) => (shape) =>
this.editor.isShapeOfType(shape, GroupShapeUtil) || this.editor.isShapeOfType<TLGroupShape>(shape, 'group') ||
this.editor.isShapeOrAncestorLocked(shape) this.editor.isShapeOrAncestorLocked(shape)
) )
.map((shape) => shape.id) .map((shape) => shape.id)
@ -136,7 +134,7 @@ export class Brushing extends StateNode {
// Should we even test for a single segment intersections? Only if // Should we even test for a single segment intersections? Only if
// we're not holding the ctrl key for alternate selection mode // we're not holding the ctrl key for alternate selection mode
// (only wraps count!), or if the shape is a frame. // (only wraps count!), or if the shape is a frame.
if (ctrlKey || this.editor.isShapeOfType(shape, FrameShapeUtil)) { if (ctrlKey || this.editor.isShapeOfType<TLFrameShape>(shape, 'frame')) {
continue testAllShapes continue testAllShapes
} }

View file

@ -1,14 +1,7 @@
import { SelectionHandle, Vec2d } from '@tldraw/primitives' import { SelectionHandle, Vec2d } from '@tldraw/primitives'
import { import { TLBaseShape, TLImageShape, TLImageShapeCrop, TLShapePartial } from '@tldraw/tlschema'
TLBaseShape,
TLImageShapeCrop,
TLImageShapeProps,
TLShape,
TLShapePartial,
} from '@tldraw/tlschema'
import { deepCopy } from '@tldraw/utils' import { deepCopy } from '@tldraw/utils'
import { MIN_CROP_SIZE } from '../../../../constants' import { MIN_CROP_SIZE } from '../../../../constants'
import { ImageShapeUtil } from '../../../shapes/image/ImageShapeUtil'
import { import {
TLEnterEventHandler, TLEnterEventHandler,
TLEventHandlers, TLEventHandlers,
@ -81,10 +74,10 @@ export class Cropping extends StateNode {
const { shape, cursorHandleOffset } = this.snapshot const { shape, cursorHandleOffset } = this.snapshot
if (!shape) return if (!shape) return
const util = this.editor.getShapeUtil(ImageShapeUtil) const util = this.editor.getShapeUtil<TLImageShape>('image')
if (!util) return if (!util) return
const props = shape.props as TLImageShapeProps const props = shape.props
const currentPagePoint = this.editor.inputs.currentPagePoint.clone().sub(cursorHandleOffset) const currentPagePoint = this.editor.inputs.currentPagePoint.clone().sub(cursorHandleOffset)
const originPagePoint = this.editor.inputs.originPagePoint.clone().sub(cursorHandleOffset) const originPagePoint = this.editor.inputs.originPagePoint.clone().sub(cursorHandleOffset)
@ -229,7 +222,7 @@ export class Cropping extends StateNode {
inputs: { originPagePoint }, inputs: { originPagePoint },
} = this.editor } = this.editor
const shape = this.editor.onlySelectedShape as TLShape const shape = this.editor.onlySelectedShape as TLImageShape
const selectionBounds = this.editor.selectionBounds! const selectionBounds = this.editor.selectionBounds!

View file

@ -1,7 +1,6 @@
import { Vec2d } from '@tldraw/primitives' import { Vec2d } from '@tldraw/primitives'
import { TLGeoShape, TLShape, TLTextShape, createShapeId } from '@tldraw/tlschema' import { TLGeoShape, TLGroupShape, TLShape, TLTextShape, createShapeId } from '@tldraw/tlschema'
import { debugFlags } from '../../../../utils/debug-flags' import { debugFlags } from '../../../../utils/debug-flags'
import { GroupShapeUtil } from '../../../shapes/group/GroupShapeUtil'
import { import {
TLClickEventInfo, TLClickEventInfo,
TLEventHandlers, TLEventHandlers,
@ -319,7 +318,9 @@ export class Idle extends StateNode {
case 'Enter': { case 'Enter': {
const { selectedShapes } = this.editor const { selectedShapes } = this.editor
if (selectedShapes.every((shape) => this.editor.isShapeOfType(shape, GroupShapeUtil))) { if (
selectedShapes.every((shape) => this.editor.isShapeOfType<TLGroupShape>(shape, 'group'))
) {
this.editor.setSelectedIds( this.editor.setSelectedIds(
selectedShapes.flatMap((shape) => this.editor.getSortedChildIds(shape.id)) selectedShapes.flatMap((shape) => this.editor.getSortedChildIds(shape.id))
) )

View file

@ -1,5 +1,4 @@
import { TLShape } from '@tldraw/tlschema' import { TLGroupShape, TLShape } from '@tldraw/tlschema'
import { GroupShapeUtil } from '../../../shapes/group/GroupShapeUtil'
import { TLEventHandlers, TLPointerEventInfo } from '../../../types/event-types' import { TLEventHandlers, TLPointerEventInfo } from '../../../types/event-types'
import { StateNode } from '../../StateNode' import { StateNode } from '../../StateNode'
@ -36,7 +35,7 @@ export class PointingShape extends StateNode {
const parent = this.editor.getParentShape(info.shape) const parent = this.editor.getParentShape(info.shape)
if (parent && this.editor.isShapeOfType(parent, GroupShapeUtil)) { if (parent && this.editor.isShapeOfType<TLGroupShape>(parent, 'group')) {
this.editor.cancelDoubleClick() this.editor.cancelDoubleClick()
} }

View file

@ -9,8 +9,7 @@ import {
Vec2d, Vec2d,
VecLike, VecLike,
} from '@tldraw/primitives' } from '@tldraw/primitives'
import { TLShape, TLShapeId, TLShapePartial } from '@tldraw/tlschema' import { TLFrameShape, TLShape, TLShapeId, TLShapePartial } from '@tldraw/tlschema'
import { FrameShapeUtil } from '../../../shapes/frame/FrameShapeUtil'
import { import {
TLEnterEventHandler, TLEnterEventHandler,
TLEventHandlers, TLEventHandlers,
@ -371,12 +370,13 @@ export class Resizing extends StateNode {
const shape = this.editor.getShapeById(id) const shape = this.editor.getShapeById(id)
if (shape) { if (shape) {
shapeSnapshots.set(shape.id, this._createShapeSnapshot(shape)) shapeSnapshots.set(shape.id, this._createShapeSnapshot(shape))
if (this.editor.isShapeOfType(shape, FrameShapeUtil) && selectedIds.length === 1) return if (this.editor.isShapeOfType<TLFrameShape>(shape, 'frame') && selectedIds.length === 1)
return
this.editor.visitDescendants(shape.id, (descendantId) => { this.editor.visitDescendants(shape.id, (descendantId) => {
const descendent = this.editor.getShapeById(descendantId) const descendent = this.editor.getShapeById(descendantId)
if (descendent) { if (descendent) {
shapeSnapshots.set(descendent.id, this._createShapeSnapshot(descendent)) shapeSnapshots.set(descendent.id, this._createShapeSnapshot(descendent))
if (this.editor.isShapeOfType(descendent, FrameShapeUtil)) { if (this.editor.isShapeOfType<TLFrameShape>(descendent, 'frame')) {
return false return false
} }
} }

View file

@ -1,9 +1,7 @@
import { intersectLineSegmentPolyline, pointInPolygon } from '@tldraw/primitives' import { intersectLineSegmentPolyline, pointInPolygon } from '@tldraw/primitives'
import { TLScribble, TLShape, TLShapeId } from '@tldraw/tlschema' import { TLFrameShape, TLGroupShape, TLScribble, TLShape, TLShapeId } from '@tldraw/tlschema'
import { ScribbleManager } from '../../../managers/ScribbleManager' import { ScribbleManager } from '../../../managers/ScribbleManager'
import { ShapeUtil } from '../../../shapes/ShapeUtil' import { ShapeUtil } from '../../../shapes/ShapeUtil'
import { FrameShapeUtil } from '../../../shapes/frame/FrameShapeUtil'
import { GroupShapeUtil } from '../../../shapes/group/GroupShapeUtil'
import { TLEventHandlers } from '../../../types/event-types' import { TLEventHandlers } from '../../../types/event-types'
import { StateNode } from '../../StateNode' import { StateNode } from '../../StateNode'
@ -106,9 +104,9 @@ export class ScribbleBrushing extends StateNode {
util = this.editor.getShapeUtil(shape) util = this.editor.getShapeUtil(shape)
if ( if (
this.editor.isShapeOfType(shape, GroupShapeUtil) || this.editor.isShapeOfType<TLGroupShape>(shape, 'group') ||
this.newlySelectedIds.has(shape.id) || this.newlySelectedIds.has(shape.id) ||
(this.editor.isShapeOfType(shape, FrameShapeUtil) && (this.editor.isShapeOfType<TLFrameShape>(shape, 'frame') &&
util.hitTestPoint(shape, this.editor.getPointInShapeSpace(shape, originPagePoint))) || util.hitTestPoint(shape, this.editor.getPointInShapeSpace(shape, originPagePoint))) ||
this.editor.isShapeOrAncestorLocked(shape) this.editor.isShapeOrAncestorLocked(shape)
) { ) {

View file

@ -1,7 +1,5 @@
import { Atom, Computed, atom, computed } from '@tldraw/state' import { Atom, Computed, atom, computed } from '@tldraw/state'
import { TLBaseShape } from '@tldraw/tlschema'
import type { Editor } from '../Editor' import type { Editor } from '../Editor'
import { TLShapeUtilConstructor } from '../shapes/ShapeUtil'
import { import {
EVENT_NAME_MAP, EVENT_NAME_MAP,
TLEnterEventHandler, TLEnterEventHandler,
@ -69,7 +67,7 @@ export abstract class StateNode implements Partial<TLEventHandlers> {
id: string id: string
current: Atom<StateNode | undefined> current: Atom<StateNode | undefined>
type: TLStateNodeType type: TLStateNodeType
shapeType?: TLShapeUtilConstructor<TLBaseShape<any, any>> shapeType?: string
initial?: string initial?: string
children?: Record<string, StateNode> children?: Record<string, StateNode>
parent: StateNode parent: StateNode

View file

@ -1,6 +1,7 @@
import { PageRecordType, TLShape, createShapeId } from '@tldraw/tlschema' import { PageRecordType, TLShape, createShapeId } from '@tldraw/tlschema'
import { defaultShapes } from '../config/defaultShapes'
import { defineShape } from '../config/defineShape'
import { BaseBoxShapeUtil } from '../editor/shapes/BaseBoxShapeUtil' import { BaseBoxShapeUtil } from '../editor/shapes/BaseBoxShapeUtil'
import { GeoShapeUtil } from '../editor/shapes/geo/GeoShapeUtil'
import { TestEditor } from './TestEditor' import { TestEditor } from './TestEditor'
import { TL } from './jsx' import { TL } from './jsx'
@ -435,58 +436,66 @@ describe('isFocused', () => {
}) })
describe('getShapeUtil', () => { describe('getShapeUtil', () => {
it('accepts shapes', () => { let myUtil: any
const geoShape = editor.getShapeById(ids.box1)!
const geoUtil = editor.getShapeUtil(geoShape) beforeEach(() => {
expect(geoUtil).toBeInstanceOf(GeoShapeUtil) class _MyFakeShapeUtil extends BaseBoxShapeUtil<any> {
static type = 'blorg'
type = 'blorg'
getDefaultProps() {
return {
w: 100,
h: 100,
}
}
component() {
throw new Error('Method not implemented.')
}
indicator() {
throw new Error('Method not implemented.')
}
}
myUtil = _MyFakeShapeUtil
const myShapeDef = defineShape('blorg', {
util: _MyFakeShapeUtil,
})
editor = new TestEditor({
shapes: [...defaultShapes, myShapeDef],
})
editor.createShapes([
{ id: ids.box1, type: 'blorg', x: 100, y: 100, props: { w: 100, h: 100 } },
])
const page1 = editor.currentPageId
editor.createPage('page 2', ids.page2)
editor.setCurrentPageId(page1)
}) })
it('accepts shape utils', () => { it('accepts shapes', () => {
const geoUtil = editor.getShapeUtil(GeoShapeUtil) const shape = editor.getShapeById(ids.box1)!
expect(geoUtil).toBeInstanceOf(GeoShapeUtil) const util = editor.getShapeUtil(shape)
expect(util).toBeInstanceOf(myUtil)
})
it('accepts shape types', () => {
const util = editor.getShapeUtil('blorg')
expect(util).toBeInstanceOf(myUtil)
}) })
it('throws if that shape type isnt registered', () => { it('throws if that shape type isnt registered', () => {
const myFakeShape = { type: 'fake' } as TLShape const myMissingShape = { type: 'missing' } as TLShape
expect(() => editor.getShapeUtil(myFakeShape)).toThrowErrorMatchingInlineSnapshot( expect(() => editor.getShapeUtil(myMissingShape)).toThrowErrorMatchingInlineSnapshot(
`"No shape util found for type \\"fake\\""` `"No shape util found for type \\"missing\\""`
)
class MyFakeShapeUtil extends BaseBoxShapeUtil<any> {
static type = 'fake'
getDefaultProps() {
throw new Error('Method not implemented.')
}
component() {
throw new Error('Method not implemented.')
}
indicator() {
throw new Error('Method not implemented.')
}
}
expect(() => editor.getShapeUtil(MyFakeShapeUtil)).toThrowErrorMatchingInlineSnapshot(
`"No shape util found for type \\"fake\\""`
) )
}) })
it("throws if a shape util that isn't the one registered is passed in", () => { it('throws if that type isnt registered', () => {
class MyFakeGeoShapeUtil extends BaseBoxShapeUtil<any> { expect(() => editor.getShapeUtil('missing')).toThrowErrorMatchingInlineSnapshot(
static type = 'geo' `"No shape util found for type \\"missing\\""`
getDefaultProps() {
throw new Error('Method not implemented.')
}
component() {
throw new Error('Method not implemented.')
}
indicator() {
throw new Error('Method not implemented.')
}
}
expect(() => editor.getShapeUtil(MyFakeGeoShapeUtil)).toThrowErrorMatchingInlineSnapshot(
`"Shape util found for type \\"geo\\" is not an instance of the provided constructor"`
) )
}) })
}) })

View file

@ -301,7 +301,7 @@ describe('Custom shapes', () => {
class CardTool extends BaseBoxShapeTool { class CardTool extends BaseBoxShapeTool {
static override id = 'card' static override id = 'card'
static override initial = 'idle' static override initial = 'idle'
override shapeType = CardUtil override shapeType = 'card'
} }
const tools = [CardTool] const tools = [CardTool]

View file

@ -1,5 +1,4 @@
import { createShapeId } from '@tldraw/tlschema' import { createShapeId } from '@tldraw/tlschema'
import { GeoShapeUtil } from '../../editor/shapes/geo/GeoShapeUtil'
import { TestEditor } from '../TestEditor' import { TestEditor } from '../TestEditor'
let editor: TestEditor let editor: TestEditor
@ -43,7 +42,7 @@ beforeEach(() => {
describe('editor.rotateShapes', () => { describe('editor.rotateShapes', () => {
it('Rotates shapes and fires events', () => { it('Rotates shapes and fires events', () => {
// Set start / change / end events on only the geo shape // Set start / change / end events on only the geo shape
const util = editor.getShapeUtil(GeoShapeUtil) const util = editor.getShapeUtil('geo')
// Bad! who did this (did I do this) // Bad! who did this (did I do this)
const fnStart = jest.fn() const fnStart = jest.fn()

View file

@ -1,5 +1,4 @@
import { TLArrowShape, TLShapePartial, createShapeId } from '@tldraw/tlschema' import { TLArrowShape, TLShapePartial, createShapeId } from '@tldraw/tlschema'
import { ArrowShapeUtil } from '../editor/shapes/arrow/ArrowShapeUtil'
import { TestEditor } from './TestEditor' import { TestEditor } from './TestEditor'
let editor: TestEditor let editor: TestEditor
@ -189,7 +188,7 @@ describe('When duplicating shapes that include arrows', () => {
.createShapes(shapes) .createShapes(shapes)
.select( .select(
...editor.shapesArray ...editor.shapesArray
.filter((s) => editor.isShapeOfType(s, ArrowShapeUtil)) .filter((s) => editor.isShapeOfType<TLArrowShape>(s, 'arrow'))
.map((s) => s.id) .map((s) => s.id)
) )

View file

@ -1,6 +1,4 @@
import { createShapeId } from '@tldraw/tlschema' import { TLFrameShape, TLGeoShape, createShapeId } from '@tldraw/tlschema'
import { FrameShapeUtil } from '../editor/shapes/frame/FrameShapeUtil'
import { GeoShapeUtil } from '../editor/shapes/geo/GeoShapeUtil'
import { TestEditor } from './TestEditor' import { TestEditor } from './TestEditor'
let editor: TestEditor let editor: TestEditor
@ -56,7 +54,7 @@ beforeEach(() => {
describe('When interacting with a shape...', () => { describe('When interacting with a shape...', () => {
it('fires rotate events', () => { it('fires rotate events', () => {
// Set start / change / end events on only the geo shape // Set start / change / end events on only the geo shape
const util = editor.getShapeUtil(FrameShapeUtil) const util = editor.getShapeUtil<TLFrameShape>('frame')
const fnStart = jest.fn() const fnStart = jest.fn()
util.onRotateStart = fnStart util.onRotateStart = fnStart
@ -89,12 +87,12 @@ describe('When interacting with a shape...', () => {
}) })
it('cleans up events', () => { it('cleans up events', () => {
const util = editor.getShapeUtil(GeoShapeUtil) const util = editor.getShapeUtil<TLGeoShape>('geo')
expect(util.onRotateStart).toBeUndefined() expect(util.onRotateStart).toBeUndefined()
}) })
it('fires double click handler event', () => { it('fires double click handler event', () => {
const util = editor.getShapeUtil(GeoShapeUtil) const util = editor.getShapeUtil<TLGeoShape>('geo')
const fnStart = jest.fn() const fnStart = jest.fn()
util.onDoubleClick = fnStart util.onDoubleClick = fnStart
@ -105,7 +103,7 @@ describe('When interacting with a shape...', () => {
}) })
it('Fires resisizing events', () => { it('Fires resisizing events', () => {
const util = editor.getShapeUtil(FrameShapeUtil) const util = editor.getShapeUtil<TLFrameShape>('frame')
const fnStart = jest.fn() const fnStart = jest.fn()
util.onResizeStart = fnStart util.onResizeStart = fnStart
@ -142,7 +140,7 @@ describe('When interacting with a shape...', () => {
}) })
it('Fires translating events', () => { it('Fires translating events', () => {
const util = editor.getShapeUtil(FrameShapeUtil) const util = editor.getShapeUtil<TLFrameShape>('frame')
const fnStart = jest.fn() const fnStart = jest.fn()
util.onTranslateStart = fnStart util.onTranslateStart = fnStart
@ -170,7 +168,7 @@ describe('When interacting with a shape...', () => {
}) })
it('Uses the shape utils onClick handler', () => { it('Uses the shape utils onClick handler', () => {
const util = editor.getShapeUtil(FrameShapeUtil) const util = editor.getShapeUtil<TLFrameShape>('frame')
const fnClick = jest.fn() const fnClick = jest.fn()
util.onClick = fnClick util.onClick = fnClick
@ -184,7 +182,7 @@ describe('When interacting with a shape...', () => {
}) })
it('Uses the shape utils onClick handler', () => { it('Uses the shape utils onClick handler', () => {
const util = editor.getShapeUtil(FrameShapeUtil) const util = editor.getShapeUtil<TLFrameShape>('frame')
const fnClick = jest.fn((shape: any) => { const fnClick = jest.fn((shape: any) => {
return { return {

View file

@ -1,5 +1,4 @@
import { DefaultFillStyle, TLArrowShape, createShapeId } from '@tldraw/tlschema' import { DefaultFillStyle, TLArrowShape, TLFrameShape, createShapeId } from '@tldraw/tlschema'
import { FrameShapeUtil } from '../../editor/shapes/frame/FrameShapeUtil'
import { TestEditor } from '../TestEditor' import { TestEditor } from '../TestEditor'
let editor: TestEditor let editor: TestEditor
@ -33,7 +32,7 @@ describe('creating frames', () => {
editor.setSelectedTool('frame') editor.setSelectedTool('frame')
editor.pointerDown(100, 100).pointerUp(100, 100) editor.pointerDown(100, 100).pointerUp(100, 100)
expect(editor.onlySelectedShape?.type).toBe('frame') expect(editor.onlySelectedShape?.type).toBe('frame')
const { w, h } = editor.getShapeUtil(FrameShapeUtil).getDefaultProps() const { w, h } = editor.getShapeUtil<TLFrameShape>('frame').getDefaultProps()
expect(editor.getPageBounds(editor.onlySelectedShape!)).toMatchObject({ expect(editor.getPageBounds(editor.onlySelectedShape!)).toMatchObject({
x: 100 - w / 2, x: 100 - w / 2,
y: 100 - h / 2, y: 100 - h / 2,

View file

@ -1895,7 +1895,7 @@ describe('Group opacity', () => {
editor.setOpacity(0.5) editor.setOpacity(0.5)
editor.groupShapes() editor.groupShapes()
const group = editor.getShapeById(onlySelectedId())! const group = editor.getShapeById(onlySelectedId())!
assert(editor.isShapeOfType(group, GroupShapeUtil)) assert(editor.isShapeOfType<TLGroupShape>(group, 'group'))
expect(group.opacity).toBe(1) expect(group.opacity).toBe(1)
}) })
}) })

View file

@ -1,12 +1,11 @@
import { Box2d, Vec2d } from '@tldraw/primitives' import { Box2d, Vec2d } from '@tldraw/primitives'
import { TLShapeId, TLShapePartial, createShapeId } from '@tldraw/tlschema' import { TLArrowShape, TLShapeId, TLShapePartial, createShapeId } from '@tldraw/tlschema'
import { GapsSnapLine, PointsSnapLine, SnapLine } from '../../editor/managers/SnapManager' import { GapsSnapLine, PointsSnapLine, SnapLine } from '../../editor/managers/SnapManager'
import { ShapeUtil } from '../../editor/shapes/ShapeUtil' import { ShapeUtil } from '../../editor/shapes/ShapeUtil'
import { TestEditor } from '../TestEditor' import { TestEditor } from '../TestEditor'
import { defaultShapes } from '../../config/defaultShapes' import { defaultShapes } from '../../config/defaultShapes'
import { defineShape } from '../../config/defineShape' import { defineShape } from '../../config/defineShape'
import { ArrowShapeUtil } from '../../editor/shapes/arrow/ArrowShapeUtil'
import { getSnapLines } from '../testutils/getSnapLines' import { getSnapLines } from '../testutils/getSnapLines'
type __TopLeftSnapOnlyShape = any type __TopLeftSnapOnlyShape = any
@ -1950,7 +1949,7 @@ describe('translating a shape with a bound shape', () => {
}) })
const newArrow = editor.shapesArray.find( const newArrow = editor.shapesArray.find(
(s) => editor.isShapeOfType(s, ArrowShapeUtil) && s.id !== arrow1 (s) => editor.isShapeOfType<TLArrowShape>(s, 'arrow') && s.id !== arrow1
) )
expect(newArrow).toMatchObject({ expect(newArrow).toMatchObject({
props: { start: { type: 'binding' }, end: { type: 'point' } }, props: { start: { type: 'binding' }, end: { type: 'point' } },

View file

@ -1,5 +1,4 @@
import { import {
ArrowShapeUtil,
AssetRecordType, AssetRecordType,
Editor, Editor,
MAX_SHAPES_PER_PAGE, MAX_SHAPES_PER_PAGE,
@ -518,7 +517,7 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
} }
const v2ShapeId = v1ShapeIdsToV2ShapeIds.get(v1Shape.id)! const v2ShapeId = v1ShapeIdsToV2ShapeIds.get(v1Shape.id)!
const util = editor.getShapeUtil(ArrowShapeUtil) const util = editor.getShapeUtil<TLArrowShape>('arrow')
// dumb but necessary // dumb but necessary
editor.inputs.ctrlKey = false editor.inputs.ctrlKey = false

View file

@ -19,6 +19,7 @@ import { TLEditorAssetUrls } from '@tldraw/editor';
import { TLExportType } from '@tldraw/editor'; import { TLExportType } from '@tldraw/editor';
import { TLLanguage } from '@tldraw/editor'; import { TLLanguage } from '@tldraw/editor';
import { TLShapeId } from '@tldraw/editor'; import { TLShapeId } from '@tldraw/editor';
import { TLShapeId as TLShapeId_2 } from '@tldraw/tlschema';
import { VecLike } from '@tldraw/primitives'; import { VecLike } from '@tldraw/primitives';
// @internal (undocumented) // @internal (undocumented)
@ -643,7 +644,7 @@ export function useDialogs(): TLUiDialogsContextType;
export function useEvents(): TLUiEventContextType; export function useEvents(): TLUiEventContextType;
// @public (undocumented) // @public (undocumented)
export function useExportAs(): (ids?: TLShapeId[], format?: TLExportType) => Promise<void>; export function useExportAs(): (ids?: TLShapeId_2[], format?: TLExportType) => Promise<void>;
// @public (undocumented) // @public (undocumented)
export function useHelpMenuSchema(): TLUiMenuSchema; export function useHelpMenuSchema(): TLUiMenuSchema;

View file

@ -1,4 +1,4 @@
import { ArrowShapeUtil, Editor, useEditor } from '@tldraw/editor' import { Editor, TLArrowShape, useEditor } from '@tldraw/editor'
import { useValue } from '@tldraw/state' import { useValue } from '@tldraw/state'
import { assert, exhaustiveSwitchError } from '@tldraw/utils' import { assert, exhaustiveSwitchError } from '@tldraw/utils'
import { TLUiActionItem } from './useActions' import { TLUiActionItem } from './useActions'
@ -139,10 +139,13 @@ function shapesWithUnboundArrows(editor: Editor) {
return selectedShapes.filter((shape) => { return selectedShapes.filter((shape) => {
if (!shape) return false if (!shape) return false
if (editor.isShapeOfType(shape, ArrowShapeUtil) && shape.props.start.type === 'binding') { if (
editor.isShapeOfType<TLArrowShape>(shape, 'arrow') &&
shape.props.start.type === 'binding'
) {
return false return false
} }
if (editor.isShapeOfType(shape, ArrowShapeUtil) && shape.props.end.type === 'binding') { if (editor.isShapeOfType<TLArrowShape>(shape, 'arrow') && shape.props.end.type === 'binding') {
return false return false
} }
return true return true

View file

@ -1,20 +1,14 @@
import { ANIMATION_MEDIUM_MS, Editor, getEmbedInfo, openWindow, useEditor } from '@tldraw/editor'
import { Box2d, TAU, Vec2d, approximately } from '@tldraw/primitives'
import { import {
ANIMATION_MEDIUM_MS, TLBookmarkShape,
BookmarkShapeUtil,
Editor,
EmbedShapeUtil,
GroupShapeUtil,
TLEmbedShape, TLEmbedShape,
TLGroupShape,
TLShapeId, TLShapeId,
TLShapePartial, TLShapePartial,
TLTextShape, TLTextShape,
TextShapeUtil,
createShapeId, createShapeId,
getEmbedInfo, } from '@tldraw/tlschema'
openWindow,
useEditor,
} from '@tldraw/editor'
import { Box2d, TAU, Vec2d, approximately } from '@tldraw/primitives'
import { compact } from '@tldraw/utils' import { compact } from '@tldraw/utils'
import * as React from 'react' import * as React from 'react'
import { EditLinkDialog } from '../components/EditLinkDialog' import { EditLinkDialog } from '../components/EditLinkDialog'
@ -214,7 +208,7 @@ export function ActionsProvider({ overrides, children }: ActionsProviderProps) {
editor.selectedShapes editor.selectedShapes
.filter( .filter(
(shape): shape is TLTextShape => (shape): shape is TLTextShape =>
editor.isShapeOfType(shape, TextShapeUtil) && shape.props.autoSize === false editor.isShapeOfType<TLTextShape>(shape, 'text') && shape.props.autoSize === false
) )
.map((shape) => { .map((shape) => {
return { return {
@ -243,7 +237,7 @@ export function ActionsProvider({ overrides, children }: ActionsProviderProps) {
return return
} }
const shape = editor.getShapeById(ids[0]) const shape = editor.getShapeById(ids[0])
if (!shape || !editor.isShapeOfType(shape, EmbedShapeUtil)) { if (!shape || !editor.isShapeOfType<TLEmbedShape>(shape, 'embed')) {
console.error(warnMsg) console.error(warnMsg)
return return
} }
@ -262,7 +256,8 @@ export function ActionsProvider({ overrides, children }: ActionsProviderProps) {
const createList: TLShapePartial[] = [] const createList: TLShapePartial[] = []
const deleteList: TLShapeId[] = [] const deleteList: TLShapeId[] = []
for (const shape of shapes) { for (const shape of shapes) {
if (!shape || !editor.isShapeOfType(shape, EmbedShapeUtil) || !shape.props.url) continue if (!shape || !editor.isShapeOfType<TLEmbedShape>(shape, 'embed') || !shape.props.url)
continue
const newPos = new Vec2d(shape.x, shape.y) const newPos = new Vec2d(shape.x, shape.y)
newPos.rot(-shape.rotation) newPos.rot(-shape.rotation)
@ -300,7 +295,7 @@ export function ActionsProvider({ overrides, children }: ActionsProviderProps) {
const createList: TLShapePartial[] = [] const createList: TLShapePartial[] = []
const deleteList: TLShapeId[] = [] const deleteList: TLShapeId[] = []
for (const shape of shapes) { for (const shape of shapes) {
if (!editor.isShapeOfType(shape, BookmarkShapeUtil)) continue if (!editor.isShapeOfType<TLBookmarkShape>(shape, 'bookmark')) continue
const { url } = shape.props const { url } = shape.props
@ -383,7 +378,7 @@ export function ActionsProvider({ overrides, children }: ActionsProviderProps) {
onSelect(source) { onSelect(source) {
trackEvent('group-shapes', { source }) trackEvent('group-shapes', { source })
const { onlySelectedShape } = editor const { onlySelectedShape } = editor
if (onlySelectedShape && editor.isShapeOfType(onlySelectedShape, GroupShapeUtil)) { if (onlySelectedShape && editor.isShapeOfType<TLGroupShape>(onlySelectedShape, 'group')) {
editor.mark('ungroup') editor.mark('ungroup')
editor.ungroupShapes(editor.selectedIds) editor.ungroupShapes(editor.selectedIds)
} else { } else {

View file

@ -1,11 +1,11 @@
import { import {
ArrowShapeUtil,
BookmarkShapeUtil,
Editor, Editor,
EmbedShapeUtil, TLArrowShape,
GeoShapeUtil, TLBookmarkShape,
TLContent, TLContent,
TextShapeUtil, TLEmbedShape,
TLGeoShape,
TLTextShape,
getValidHttpURLList, getValidHttpURLList,
isSvgText, isSvgText,
isValidHttpURL, isValidHttpURL,
@ -508,15 +508,15 @@ const handleNativeOrMenuCopy = (editor: Editor) => {
const textItems = content.shapes const textItems = content.shapes
.map((shape) => { .map((shape) => {
if ( if (
editor.isShapeOfType(shape, TextShapeUtil) || editor.isShapeOfType<TLTextShape>(shape, 'text') ||
editor.isShapeOfType(shape, GeoShapeUtil) || editor.isShapeOfType<TLGeoShape>(shape, 'geo') ||
editor.isShapeOfType(shape, ArrowShapeUtil) editor.isShapeOfType<TLArrowShape>(shape, 'arrow')
) { ) {
return shape.props.text return shape.props.text
} }
if ( if (
editor.isShapeOfType(shape, BookmarkShapeUtil) || editor.isShapeOfType<TLBookmarkShape>(shape, 'bookmark') ||
editor.isShapeOfType(shape, EmbedShapeUtil) editor.isShapeOfType<TLEmbedShape>(shape, 'embed')
) { ) {
return shape.props.url return shape.props.url
} }

View file

@ -1,4 +1,4 @@
import { BookmarkShapeUtil, Editor, EmbedShapeUtil, getEmbedInfo, useEditor } from '@tldraw/editor' import { Editor, TLBookmarkShape, TLEmbedShape, getEmbedInfo, useEditor } from '@tldraw/editor'
import { track, useValue } from '@tldraw/state' import { track, useValue } from '@tldraw/state'
import React, { useMemo } from 'react' import React, { useMemo } from 'react'
import { import {
@ -65,7 +65,7 @@ export const TLUiContextMenuSchemaProvider = track(function TLUiContextMenuSchem
if (editor.selectedIds.length !== 1) return false if (editor.selectedIds.length !== 1) return false
return editor.selectedIds.some((selectedId) => { return editor.selectedIds.some((selectedId) => {
const shape = editor.getShapeById(selectedId) const shape = editor.getShapeById(selectedId)
return shape && editor.isShapeOfType(shape, EmbedShapeUtil) && shape.props.url return shape && editor.isShapeOfType<TLEmbedShape>(shape, 'embed') && shape.props.url
}) })
}, },
[] []
@ -78,7 +78,7 @@ export const TLUiContextMenuSchemaProvider = track(function TLUiContextMenuSchem
const shape = editor.getShapeById(selectedId) const shape = editor.getShapeById(selectedId)
return ( return (
shape && shape &&
editor.isShapeOfType(shape, BookmarkShapeUtil) && editor.isShapeOfType<TLBookmarkShape>(shape, 'bookmark') &&
shape.props.url && shape.props.url &&
getEmbedInfo(shape.props.url) getEmbedInfo(shape.props.url)
) )

View file

@ -1,12 +1,11 @@
import { import {
FrameShapeUtil,
TLExportType, TLExportType,
TLShapeId,
downloadDataURLAsFile, downloadDataURLAsFile,
getSvgAsDataUrl, getSvgAsDataUrl,
getSvgAsImage, getSvgAsImage,
useEditor, useEditor,
} from '@tldraw/editor' } from '@tldraw/editor'
import { TLFrameShape, TLShapeId } from '@tldraw/tlschema'
import { useCallback } from 'react' import { useCallback } from 'react'
import { useToasts } from './useToastsProvider' import { useToasts } from './useToastsProvider'
import { useTranslation } from './useTranslation/useTranslation' import { useTranslation } from './useTranslation/useTranslation'
@ -38,7 +37,7 @@ export function useExportAs() {
if (ids.length === 1) { if (ids.length === 1) {
const first = editor.getShapeById(ids[0])! const first = editor.getShapeById(ids[0])!
if (editor.isShapeOfType(first, FrameShapeUtil)) { if (editor.isShapeOfType<TLFrameShape>(first, 'frame')) {
name = first.props.name ?? 'frame' name = first.props.name ?? 'frame'
} else { } else {
name = first.id.replace(/:/, '_') name = first.id.replace(/:/, '_')

View file

@ -1,11 +1,6 @@
import { import { TLArrowShape, TLGroupShape, TLLineShape, useEditor } from '@tldraw/editor'
ArrowShapeUtil,
DrawShapeUtil,
GroupShapeUtil,
LineShapeUtil,
useEditor,
} from '@tldraw/editor'
import { useValue } from '@tldraw/state' import { useValue } from '@tldraw/state'
import { TLDrawShape } from '@tldraw/tlschema'
export function useOnlyFlippableShape() { export function useOnlyFlippableShape() {
const editor = useEditor() const editor = useEditor()
@ -17,10 +12,10 @@ export function useOnlyFlippableShape() {
selectedShapes.length === 1 && selectedShapes.length === 1 &&
selectedShapes.every( selectedShapes.every(
(shape) => (shape) =>
editor.isShapeOfType(shape, GroupShapeUtil) || editor.isShapeOfType<TLGroupShape>(shape, 'group') ||
editor.isShapeOfType(shape, ArrowShapeUtil) || editor.isShapeOfType<TLArrowShape>(shape, 'arrow') ||
editor.isShapeOfType(shape, LineShapeUtil) || editor.isShapeOfType<TLLineShape>(shape, 'line') ||
editor.isShapeOfType(shape, DrawShapeUtil) editor.isShapeOfType<TLDrawShape>(shape, 'draw')
) )
) )
}, },

View file

@ -1,4 +1,4 @@
import { TextShapeUtil, useEditor } from '@tldraw/editor' import { TLTextShape, useEditor } from '@tldraw/editor'
import { useValue } from '@tldraw/state' import { useValue } from '@tldraw/state'
export function useShowAutoSizeToggle() { export function useShowAutoSizeToggle() {
@ -9,7 +9,7 @@ export function useShowAutoSizeToggle() {
const { selectedShapes } = editor const { selectedShapes } = editor
return ( return (
selectedShapes.length === 1 && selectedShapes.length === 1 &&
editor.isShapeOfType(selectedShapes[0], TextShapeUtil) && editor.isShapeOfType<TLTextShape>(selectedShapes[0], 'text') &&
selectedShapes[0].props.autoSize === false selectedShapes[0].props.autoSize === false
) )
}, },