Add optional generic to updateShapes / createShapes (#1579)

This PR adds a generic that we can use with `updateShapes` and
`createShapes` in order to type the partials being passed into those
methods. By default, the partials are typed as `TLUnknownShape`, which
accepts any props.

### Change Type

- [x] `minor` — New feature

### Test Plan

- [x] Unit Tests

### Release Notes

- [editor] adds an optional shape generic to `updateShapes` and
`createShapes`
This commit is contained in:
Steve Ruiz 2023-06-13 19:02:17 +01:00 committed by GitHub
parent 69e5b248ca
commit ce1cf82029
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 436 additions and 274 deletions

View file

@ -1,4 +1,11 @@
import { createShapeId, Editor, Tldraw, TLGeoShape, useEditor } from '@tldraw/tldraw' import {
createShapeId,
Editor,
Tldraw,
TLGeoShape,
TLShapePartial,
useEditor,
} from '@tldraw/tldraw'
import '@tldraw/tldraw/editor.css' import '@tldraw/tldraw/editor.css'
import '@tldraw/tldraw/ui.css' import '@tldraw/tldraw/ui.css'
import { useEffect } from 'react' import { useEffect } from 'react'
@ -18,7 +25,7 @@ export default function APIExample() {
editor.focus() editor.focus()
// Create a shape // Create a shape
editor.createShapes([ editor.createShapes<TLGeoShape>([
{ {
id, id,
type: 'geo', type: 'geo',
@ -38,17 +45,17 @@ export default function APIExample() {
// Get the created shape // Get the created shape
const shape = editor.getShapeById<TLGeoShape>(id)! const shape = editor.getShapeById<TLGeoShape>(id)!
// Update the shape const shapeUpdate: TLShapePartial<TLGeoShape> = {
editor.updateShapes([
{
id, id,
type: 'geo', type: 'geo',
props: { props: {
h: shape.props.h * 3, h: shape.props.h * 3,
text: 'hello world!', text: 'hello world!',
}, },
}, }
])
// Update the shape
editor.updateShapes([shapeUpdate])
// Select the shape // Select the shape
editor.select(id) editor.select(id)

View file

@ -1,4 +1,4 @@
import { createShapeId, Tldraw } from '@tldraw/tldraw' import { createShapeId, Tldraw, TLShapePartial } from '@tldraw/tldraw'
import '@tldraw/tldraw/editor.css' import '@tldraw/tldraw/editor.css'
import '@tldraw/tldraw/ui.css' import '@tldraw/tldraw/ui.css'
import { ErrorShape } from './ErrorShape' import { ErrorShape } from './ErrorShape'
@ -16,16 +16,16 @@ export default function ErrorBoundaryExample() {
ShapeErrorFallback: ({ error }) => <div>Shape error! {String(error)}</div>, // use a custom error fallback for shapes ShapeErrorFallback: ({ error }) => <div>Shape error! {String(error)}</div>, // use a custom error fallback for shapes
}} }}
onMount={(editor) => { onMount={(editor) => {
// When the app starts, create our error shape so we can see. const errorShapePartial: TLShapePartial<ErrorShape> = {
editor.createShapes([
{
type: 'error', type: 'error',
id: createShapeId(), id: createShapeId(),
x: 0, x: 0,
y: 0, y: 0,
props: { message: 'Something has gone wrong' }, props: { message: 'Something has gone wrong' },
}, }
])
// When the app starts, create our error shape so we can see.
editor.createShapes<ErrorShape>([errorShapePartial])
// Center the camera on the error shape // Center the camera on the error shape
editor.zoomToFit() editor.zoomToFit()

View file

@ -424,7 +424,7 @@ export class Editor extends EventEmitter<TLEventMap> {
}; };
}; };
createPage(title: string, id?: TLPageId, belowPageIndex?: string): this; createPage(title: string, id?: TLPageId, belowPageIndex?: string): this;
createShapes(partials: TLShapePartial[], select?: boolean): this; createShapes<T extends TLUnknownShape>(partials: TLShapePartial<T>[], select?: boolean): this;
get croppingId(): null | TLShapeId; get croppingId(): null | TLShapeId;
get cullingBounds(): Box2d; get cullingBounds(): Box2d;
// @internal (undocumented) // @internal (undocumented)
@ -760,7 +760,7 @@ export class Editor extends EventEmitter<TLEventMap> {
updateDocumentSettings(settings: Partial<TLDocument>): void; updateDocumentSettings(settings: Partial<TLDocument>): void;
updateInstanceState(partial: Partial<Omit<TLInstance, 'currentPageId'>>, ephemeral?: boolean, squashing?: boolean): this; updateInstanceState(partial: Partial<Omit<TLInstance, 'currentPageId'>>, ephemeral?: boolean, squashing?: boolean): this;
updatePage(partial: RequiredKeys<TLPage, 'id'>, squashing?: boolean): this; updatePage(partial: RequiredKeys<TLPage, 'id'>, squashing?: boolean): this;
updateShapes(partials: (null | TLShapePartial | undefined)[], squashing?: boolean): this; updateShapes<T extends TLUnknownShape>(partials: (null | TLShapePartial<T> | undefined)[], squashing?: boolean): this;
updateViewportScreenBounds(center?: boolean): this; updateViewportScreenBounds(center?: boolean): this;
// (undocumented) // (undocumented)
readonly user: UserPreferencesManager; readonly user: UserPreferencesManager;

View file

@ -1348,7 +1348,7 @@ export class Editor extends EventEmitter<TLEventMap> {
} }
if (finalIndex !== reparentedArrow.index) { if (finalIndex !== reparentedArrow.index) {
this.updateShapes([{ id: arrowId, type: 'arrow', index: finalIndex }]) this.updateShapes<TLArrowShape>([{ id: arrowId, type: 'arrow', index: finalIndex }])
} }
} }
@ -4636,14 +4636,14 @@ export class Editor extends EventEmitter<TLEventMap> {
* @example * @example
* *
* ```ts * ```ts
* editor.createShapes([{ id: 'box1', type: 'box' }]) * editor.createShapes([{ id: 'box1', type: 'text', props: { text: "ok" } }])
* ``` * ```
* *
* @param partials - The shape partials to create. * @param partials - The shape partials to create.
* @param select - Whether to select the created shapes. Defaults to false. * @param select - Whether to select the created shapes. Defaults to false.
* @public * @public
*/ */
createShapes(partials: TLShapePartial[], select = false) { createShapes<T extends TLUnknownShape>(partials: TLShapePartial<T>[], select = false) {
this._createShapes(partials, select) this._createShapes(partials, select)
return this return this
} }
@ -4934,14 +4934,17 @@ export class Editor extends EventEmitter<TLEventMap> {
* @example * @example
* *
* ```ts * ```ts
* editor.updateShapes([{ id: 'box1', type: 'box', x: 100, y: 100 }]) * editor.updateShapes([{ id: 'box1', type: 'geo', props: { w: 100, h: 100 } }])
* ``` * ```
* *
* @param partials - The shape partials to update. * @param partials - The shape partials to update.
* @param squashing - Whether the change is ephemeral. * @param squashing - Whether the change is ephemeral.
* @public * @public
*/ */
updateShapes(partials: (TLShapePartial | null | undefined)[], squashing = false) { updateShapes<T extends TLUnknownShape>(
partials: (TLShapePartial<T> | null | undefined)[],
squashing = false
) {
let compactedPartials = compact(partials) let compactedPartials = compact(partials)
if (this.animatingShapes.size > 0) { if (this.animatingShapes.size > 0) {
compactedPartials.forEach((p) => this.animatingShapes.delete(p.id)) compactedPartials.forEach((p) => this.animatingShapes.delete(p.id))
@ -8985,7 +8988,7 @@ export class Editor extends EventEmitter<TLEventMap> {
const highestIndex = shapesWithRootParent[shapesWithRootParent.length - 1]?.index const highestIndex = shapesWithRootParent[shapesWithRootParent.length - 1]?.index
this.batch(() => { this.batch(() => {
this.createShapes([ this.createShapes<TLGroupShape>([
{ {
id: groupId, id: groupId,
type: 'group', type: 'group',

View file

@ -1,4 +1,4 @@
import { TLShapeId } from '@tldraw/tlschema' import { TLArrowShape, TLShapeId } from '@tldraw/tlschema'
import { TestEditor } from '../../test/TestEditor' import { TestEditor } from '../../test/TestEditor'
import { TL } from '../../test/jsx' import { TL } from '../../test/jsx'
@ -212,7 +212,7 @@ describe('arrowBindingsIndex', () => {
) )
// move arrowA from box2 to box3 // move arrowA from box2 to box3
editor.updateShapes([ editor.updateShapes<TLArrowShape>([
{ {
id: arrowAId, id: arrowAId,
type: 'arrow', type: 'arrow',

View file

@ -6,6 +6,8 @@ import {
TLAssetId, TLAssetId,
TLEmbedShape, TLEmbedShape,
TLShapePartial, TLShapePartial,
TLTextShape,
TLTextShapeProps,
createShapeId, createShapeId,
} from '@tldraw/tlschema' } from '@tldraw/tlschema'
import { compact, getHashForString } from '@tldraw/utils' import { compact, getHashForString } from '@tldraw/utils'
@ -248,7 +250,7 @@ export class ExternalContentManager {
let w: number let w: number
let h: number let h: number
let autoSize: boolean let autoSize: boolean
let align = 'middle' let align = 'middle' as TLTextShapeProps['align']
const isMultiLine = textToPaste.split('\n').length > 1 const isMultiLine = textToPaste.split('\n').length > 1
@ -293,7 +295,7 @@ export class ExternalContentManager {
p.y = editor.viewportPageBounds.minY + 40 + h / 2 p.y = editor.viewportPageBounds.minY + 40 + h / 2
} }
editor.createShapes([ editor.createShapes<TLTextShape>([
{ {
id: createShapeId(), id: createShapeId(),
type: 'text', type: 'text',

View file

@ -907,7 +907,7 @@ export class ArrowShapeUtil extends ShapeUtil<TLArrowShape> {
} = shape } = shape
if (text.trimEnd() !== shape.props.text) { if (text.trimEnd() !== shape.props.text) {
this.editor.updateShapes([ this.editor.updateShapes<TLArrowShape>([
{ {
id, id,
type, type,

View file

@ -2,7 +2,6 @@ import { createShapeId, TLArrowShape } from '@tldraw/tlschema'
import { ArrowShapeUtil } from '../../../shapes/arrow/ArrowShapeUtil' import { ArrowShapeUtil } from '../../../shapes/arrow/ArrowShapeUtil'
import { StateNode } from '../../../tools/StateNode' import { StateNode } from '../../../tools/StateNode'
import { TLEventHandlers } from '../../../types/event-types' import { TLEventHandlers } from '../../../types/event-types'
import { ArrowShapeTool } from '../ArrowShapeTool'
export class Pointing extends StateNode { export class Pointing extends StateNode {
static override id = 'pointing' static override id = 'pointing'
@ -31,16 +30,14 @@ export class Pointing extends StateNode {
this.didTimeout = false this.didTimeout = false
const shapeType = (this.parent as ArrowShapeTool).shapeType
this.editor.mark('creating') this.editor.mark('creating')
const id = createShapeId() const id = createShapeId()
this.editor.createShapes([ this.editor.createShapes<TLArrowShape>([
{ {
id, id,
type: shapeType, type: 'arrow',
x: currentPagePoint.x, x: currentPagePoint.x,
y: currentPagePoint.y, y: currentPagePoint.y,
}, },

View file

@ -139,7 +139,7 @@ function updateBookmarkAssetOnUrlChange(editor: Editor, shape: TLBookmarkShape)
if (editor.getAssetById(assetId)) { if (editor.getAssetById(assetId)) {
// Existing asset for this URL? // Existing asset for this URL?
if (shape.props.assetId !== assetId) { if (shape.props.assetId !== assetId) {
editor.updateShapes([ editor.updateShapes<TLBookmarkShape>([
{ {
id: shape.id, id: shape.id,
type: shape.type, type: shape.type,
@ -151,7 +151,7 @@ function updateBookmarkAssetOnUrlChange(editor: Editor, shape: TLBookmarkShape)
// No asset for this URL? // No asset for this URL?
// First, clear out the existing asset reference // First, clear out the existing asset reference
editor.updateShapes([ editor.updateShapes<TLBookmarkShape>([
{ {
id: shape.id, id: shape.id,
type: shape.type, type: shape.type,
@ -181,7 +181,7 @@ const createBookmarkAssetOnUrlChange = debounce(async (editor: Editor, shape: TL
editor.createAssets([asset]) editor.createAssets([asset])
// And update the shape // And update the shape
editor.updateShapes([ editor.updateShapes<TLBookmarkShape>([
{ {
id: shape.id, id: shape.id,
type: shape.type, type: shape.type,

View file

@ -4,6 +4,7 @@ import {
TLDrawShape, TLDrawShape,
TLDrawShapeSegment, TLDrawShapeSegment,
TLHighlightShape, TLHighlightShape,
TLShapePartial,
TLSizeType, TLSizeType,
Vec2dModel, Vec2dModel,
} from '@tldraw/tlschema' } from '@tldraw/tlschema'
@ -25,7 +26,7 @@ export class Drawing extends StateNode {
initialShape?: DrawableShape initialShape?: DrawableShape
shapeType: 'draw' | 'highlight' = this.parent.id === 'highlight' ? 'highlight' : 'draw' shapeType: DrawableShape['type'] = this.parent.id === 'highlight' ? 'highlight' : 'draw'
util = util =
this.shapeType === 'highlight' this.shapeType === 'highlight'
@ -219,16 +220,22 @@ export class Drawing extends StateNode {
this.currentLineLength = this.getLineLength(segments) this.currentLineLength = this.getLineLength(segments)
this.editor.updateShapes([ const shapePartial: TLShapePartial<DrawableShape> = {
{
id: shape.id, id: shape.id,
type: this.shapeType, type: this.shapeType,
props: { props: {
segments, segments,
isClosed: this.canClose() ? this.getIsClosed(segments, shape.props.size) : undefined,
}, },
}, }
])
if (this.canClose()) {
;(shapePartial as TLShapePartial<TLDrawShape>).props!.isClosed = this.getIsClosed(
segments,
shape.props.size
)
}
this.editor.updateShapes<TLDrawShape | TLHighlightShape>([shapePartial])
return return
} }
@ -238,7 +245,8 @@ export class Drawing extends StateNode {
this.pagePointWhereCurrentSegmentChanged = originPagePoint.clone() this.pagePointWhereCurrentSegmentChanged = originPagePoint.clone()
const id = createShapeId() const id = createShapeId()
this.editor.createShapes([
this.editor.createShapes<DrawableShape>([
{ {
id, id,
type: this.shapeType, type: this.shapeType,
@ -261,7 +269,6 @@ export class Drawing extends StateNode {
}, },
}, },
]) ])
this.currentLineLength = 0 this.currentLineLength = 0
this.initialShape = this.editor.getShapeById<DrawableShape>(id) this.initialShape = this.editor.getShapeById<DrawableShape>(id)
} }
@ -343,20 +350,23 @@ export class Drawing extends StateNode {
} }
} }
this.editor.updateShapes( const shapePartial: TLShapePartial<DrawableShape> = {
[
{
id, id,
type: this.shapeType, type: this.shapeType,
props: { props: {
segments: [...segments, newSegment], segments: [...segments, newSegment],
isClosed: this.canClose() ? this.getIsClosed(segments, size) : undefined,
}, },
}, }
],
true if (this.canClose()) {
;(shapePartial as TLShapePartial<TLDrawShape>).props!.isClosed = this.getIsClosed(
segments,
size
) )
} }
this.editor.updateShapes<TLDrawShape | TLHighlightShape>([shapePartial], true)
}
break break
} }
case 'starting_free': { case 'starting_free': {
@ -400,21 +410,24 @@ export class Drawing extends StateNode {
const finalSegments = [...newSegments, newFreeSegment] const finalSegments = [...newSegments, newFreeSegment]
this.currentLineLength = this.getLineLength(finalSegments) this.currentLineLength = this.getLineLength(finalSegments)
this.editor.updateShapes( const shapePartial: TLShapePartial<DrawableShape> = {
[
{
id, id,
type: this.shapeType, type: this.shapeType,
props: { props: {
segments: finalSegments, segments: finalSegments,
isClosed: this.canClose() ? this.getIsClosed(finalSegments, size) : undefined,
}, },
}, }
],
true if (this.canClose()) {
;(shapePartial as TLShapePartial<TLDrawShape>).props!.isClosed = this.getIsClosed(
finalSegments,
size
) )
} }
this.editor.updateShapes([shapePartial], true)
}
break break
} }
case 'straight': { case 'straight': {
@ -539,19 +552,22 @@ export class Drawing extends StateNode {
points: [newSegment.points[0], newPoint], points: [newSegment.points[0], newPoint],
} }
this.editor.updateShapes( const shapePartial: TLShapePartial<DrawableShape> = {
[
{
id, id,
type: this.shapeType, type: this.shapeType,
props: { props: {
segments: newSegments, segments: newSegments,
isClosed: this.canClose() ? this.getIsClosed(segments, size) : undefined,
}, },
}, }
],
true if (this.canClose()) {
;(shapePartial as TLShapePartial<TLDrawShape>).props!.isClosed = this.getIsClosed(
segments,
size
) )
}
this.editor.updateShapes([shapePartial], true)
break break
} }
@ -581,19 +597,22 @@ export class Drawing extends StateNode {
this.currentLineLength = this.getLineLength(newSegments) this.currentLineLength = this.getLineLength(newSegments)
this.editor.updateShapes( const shapePartial: TLShapePartial<DrawableShape> = {
[
{
id, id,
type: this.shapeType, type: this.shapeType,
props: { props: {
segments: newSegments, segments: newSegments,
isClosed: this.canClose() ? this.getIsClosed(segments, size) : undefined,
}, },
}, }
],
true if (this.canClose()) {
;(shapePartial as TLShapePartial<TLDrawShape>).props!.isClosed = this.getIsClosed(
newSegments,
size
) )
}
this.editor.updateShapes([shapePartial], true)
// Set a maximum length for the lines array; after 200 points, complete the line. // Set a maximum length for the lines array; after 200 points, complete the line.
if (newPoints.length > 500) { if (newPoints.length > 500) {
@ -603,7 +622,7 @@ export class Drawing extends StateNode {
const newShapeId = createShapeId() const newShapeId = createShapeId()
this.editor.createShapes([ this.editor.createShapes<DrawableShape>([
{ {
id: newShapeId, id: newShapeId,
type: this.shapeType, type: this.shapeType,

View file

@ -14,7 +14,7 @@ export class Pointing extends StateNode {
this.editor.mark('creating') this.editor.mark('creating')
this.editor.createShapes([ this.editor.createShapes<TLGeoShape>([
{ {
id, id,
type: 'geo', type: 'geo',
@ -63,7 +63,7 @@ export class Pointing extends StateNode {
this.editor.mark('creating') this.editor.mark('creating')
this.editor.createShapes([ this.editor.createShapes<TLGeoShape>([
{ {
id, id,
type: 'geo', type: 'geo',
@ -85,7 +85,7 @@ export class Pointing extends StateNode {
const delta = this.editor.getDeltaInParentSpace(shape, bounds.center) const delta = this.editor.getDeltaInParentSpace(shape, bounds.center)
this.editor.select(id) this.editor.select(id)
this.editor.updateShapes([ this.editor.updateShapes<TLGeoShape>([
{ {
id: shape.id, id: shape.id,
type: 'geo', type: 'geo',

View file

@ -4,7 +4,6 @@ import { TLHandle, TLLineShape, TLShapeId, createShapeId } from '@tldraw/tlschem
import { last, structuredClone } from '@tldraw/utils' import { last, structuredClone } from '@tldraw/utils'
import { StateNode } from '../../../tools/StateNode' import { StateNode } from '../../../tools/StateNode'
import { TLEventHandlers, TLInterruptEvent } from '../../../types/event-types' import { TLEventHandlers, TLInterruptEvent } from '../../../types/event-types'
import { LineShapeTool } from '../LineShapeTool'
export class Pointing extends StateNode { export class Pointing extends StateNode {
static override id = 'pointing' static override id = 'pointing'
@ -79,10 +78,10 @@ export class Pointing extends StateNode {
} else { } else {
const id = createShapeId() const id = createShapeId()
this.editor.createShapes([ this.editor.createShapes<TLLineShape>([
{ {
id, id,
type: (this.parent as LineShapeTool).shapeType, type: 'line',
x: currentPagePoint.x, x: currentPagePoint.x,
y: currentPagePoint.y, y: currentPagePoint.y,
}, },

View file

@ -1,5 +1,5 @@
/* eslint-disable no-inner-declarations */ /* eslint-disable no-inner-declarations */
import { TLShape } from '@tldraw/tlschema' import { TLShape, TLUnknownShape } from '@tldraw/tlschema'
import React, { useCallback, useEffect, useRef } from 'react' import React, { useCallback, useEffect, useRef } from 'react'
import { useValue } from 'signia-react' import { useValue } from 'signia-react'
import { useEditor } from '../../../hooks/useEditor' import { useEditor } from '../../../hooks/useEditor'
@ -145,7 +145,9 @@ export function useEditableText<T extends Extract<TLShape, { props: { text: stri
} }
// ---------------------------- // ----------------------------
editor.updateShapes([{ id, type, props: { text } }]) editor.updateShapes<TLUnknownShape & { props: { text: string } }>([
{ id, type, props: { text } },
])
}, },
[editor, id, type] [editor, id, type]
) )

View file

@ -21,7 +21,7 @@ export class Pointing extends StateNode {
this.editor.mark('creating') this.editor.mark('creating')
this.editor.createShapes([ this.editor.createShapes<TLTextShape>([
{ {
id, id,
type: 'text', type: 'text',

View file

@ -27,7 +27,8 @@ export class Pointing extends StateNode {
this.editor.mark(this.markId) this.editor.mark(this.markId)
this.editor.createShapes([ this.editor.createShapes<TLBaseBoxShape>(
[
{ {
id, id,
type: shapeType, type: shapeType,
@ -38,9 +39,9 @@ export class Pointing extends StateNode {
h: 1, h: 1,
}, },
}, },
]) ],
true
this.editor.setSelectedIds([id]) )
this.editor.setSelectedTool('select.resizing', { this.editor.setSelectedTool('select.resizing', {
...info, ...info,
target: 'selection', target: 'selection',
@ -83,7 +84,7 @@ export class Pointing extends StateNode {
this.editor.mark(this.markId) this.editor.mark(this.markId)
this.editor.createShapes([ this.editor.createShapes<TLBaseBoxShape>([
{ {
id, id,
type: shapeType, type: shapeType,
@ -96,7 +97,7 @@ export class Pointing extends StateNode {
const { w, h } = this.editor.getShapeUtil(shape).defaultProps() as TLBaseBoxShape['props'] const { w, h } = this.editor.getShapeUtil(shape).defaultProps() as TLBaseBoxShape['props']
const delta = this.editor.getDeltaInParentSpace(shape, new Vec2d(w / 2, h / 2)) const delta = this.editor.getDeltaInParentSpace(shape, new Vec2d(w / 2, h / 2))
this.editor.updateShapes([ this.editor.updateShapes<TLBaseBoxShape>([
{ {
id, id,
type: shapeType, type: shapeType,

View file

@ -185,7 +185,7 @@ export class Idle extends StateNode {
this.editor.mark('translate crop') this.editor.mark('translate crop')
} }
this.editor.updateShapes([partial]) this.editor.updateShapes<ShapeWithCrop>([partial])
} }
} }
} }

View file

@ -1,5 +1,5 @@
import { Vec2d } from '@tldraw/primitives' import { Vec2d } from '@tldraw/primitives'
import { TLGeoShape, TLShape, createShapeId } from '@tldraw/tlschema' import { TLGeoShape, TLShape, TLTextShape, createShapeId } from '@tldraw/tlschema'
import { debugFlags } from '../../../../utils/debug-flags' import { debugFlags } from '../../../../utils/debug-flags'
import { import {
TLClickEventInfo, TLClickEventInfo,
@ -382,7 +382,7 @@ export class Idle extends StateNode {
const { x, y } = this.editor.inputs.currentPagePoint const { x, y } = this.editor.inputs.currentPagePoint
this.editor.createShapes([ this.editor.createShapes<TLTextShape>([
{ {
id, id,
type: 'text', type: 'text',

View file

@ -1,4 +1,4 @@
import { TLGeoShape, createShapeId } from '@tldraw/tlschema' import { TLArrowShape, TLGeoShape, createShapeId } from '@tldraw/tlschema'
import { TestEditor } from '../TestEditor' import { TestEditor } from '../TestEditor'
let editor: TestEditor let editor: TestEditor
@ -16,6 +16,72 @@ beforeEach(() => {
editor = new TestEditor() editor = new TestEditor()
}) })
it('Uses typescript generics', () => {
expect(() => {
// No error here because no generic, the editor doesn't know what this guy is
editor.createShapes([
{
id: ids.box1,
type: 'geo',
props: { w: 'OH NO' },
},
])
// Yep error here because we are giving the wrong props to the shape
editor.createShapes<TLGeoShape>([
{
id: ids.box1,
type: 'geo',
//@ts-expect-error
props: { w: 'OH NO' },
},
])
// Yep error here because we are giving the wrong generic
editor.createShapes<TLArrowShape>([
{
id: ids.box1,
//@ts-expect-error
type: 'geo',
//@ts-expect-error
props: { w: 'OH NO' },
},
])
// All good, correct match of generic and shape type
editor.createShapes<TLGeoShape>([
{
id: ids.box1,
type: 'geo',
props: { w: 100 },
},
])
editor.createShapes<TLGeoShape>([
{
id: ids.box1,
type: 'geo',
},
{
id: ids.box1,
// @ts-expect-error - wrong type
type: 'arrow',
},
])
// Unions are supported just fine
editor.createShapes<TLGeoShape | TLArrowShape>([
{
id: ids.box1,
type: 'geo',
},
{
id: ids.box1,
type: 'arrow',
},
])
}).toThrowError()
})
it('Parents shapes to the current page if the parent is not found', () => { it('Parents shapes to the current page if the parent is not found', () => {
editor.createShapes([{ id: ids.box1, parentId: ids.missing, type: 'geo' }]) editor.createShapes([{ id: ids.box1, parentId: ids.missing, type: 'geo' }])
expect(editor.getShapeById(ids.box1)!.parentId).toEqual(editor.currentPageId) expect(editor.getShapeById(ids.box1)!.parentId).toEqual(editor.currentPageId)

View file

@ -1,4 +1,4 @@
import { createShapeId } from '@tldraw/tlschema' import { createShapeId, TLArrowShape, TLGeoShape } from '@tldraw/tlschema'
import { createDefaultShapes, TestEditor } from '../TestEditor' import { createDefaultShapes, TestEditor } from '../TestEditor'
let editor: TestEditor let editor: TestEditor
@ -13,6 +13,73 @@ beforeEach(() => {
editor.createShapes(createDefaultShapes()) editor.createShapes(createDefaultShapes())
}) })
it('Uses typescript generics', () => {
expect(() => {
// No error here because no generic, the editor doesn't know what this guy is
editor.updateShapes([
{
id: ids.box1,
type: 'geo',
props: { w: 'OH NO' },
},
])
// Yep error here because we are giving the wrong props to the shape
editor.updateShapes<TLGeoShape>([
{
id: ids.box1,
type: 'geo',
//@ts-expect-error
props: { w: 'OH NO' },
},
])
// Yep error here because we are giving the wrong generic
editor.updateShapes<TLArrowShape>([
{
id: ids.box1,
//@ts-expect-error
type: 'geo',
//@ts-expect-error
props: { w: 'OH NO' },
},
])
// All good, correct match of generic and shape type
editor.updateShapes<TLGeoShape>([
{
id: ids.box1,
type: 'geo',
props: { w: 100 },
},
])
editor.updateShapes<TLGeoShape>([
{
id: ids.box1,
type: 'geo',
},
{
id: ids.box1,
// @ts-expect-error - wrong type
type: 'arrow',
},
])
// Unions are supported just fine
editor.updateShapes<TLGeoShape | TLArrowShape>([
{
id: ids.box1,
type: 'geo',
},
{
id: ids.box1,
type: 'arrow',
},
])
}).toThrowError()
})
it('updates shapes', () => { it('updates shapes', () => {
editor.mark('update shapes') editor.mark('update shapes')
editor.updateShapes([ editor.updateShapes([

View file

@ -19,7 +19,6 @@ import {
TLNoteShape, TLNoteShape,
TLPageId, TLPageId,
TLShapeId, TLShapeId,
TLShapePartial,
TLSizeType, TLSizeType,
TLTextShape, TLTextShape,
TLVideoShape, TLVideoShape,
@ -182,7 +181,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
switch (v1Shape.type) { switch (v1Shape.type) {
case TDShapeType.Sticky: { case TDShapeType.Sticky: {
const partial: TLShapePartial<TLNoteShape> = { editor.createShapes<TLNoteShape>([
{
...inCommon, ...inCommon,
type: 'note', type: 'note',
props: { props: {
@ -192,13 +192,13 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
font: getV2Font(v1Shape.style.font), font: getV2Font(v1Shape.style.font),
align: getV2Align(v1Shape.style.textAlign), align: getV2Align(v1Shape.style.textAlign),
}, },
} },
])
editor.createShapes([partial])
break break
} }
case TDShapeType.Rectangle: { case TDShapeType.Rectangle: {
const partial: TLShapePartial<TLGeoShape> = { editor.createShapes<TLGeoShape>([
{
...inCommon, ...inCommon,
type: 'geo', type: 'geo',
props: { props: {
@ -214,9 +214,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
dash: getV2Dash(v1Shape.style.dash), dash: getV2Dash(v1Shape.style.dash),
align: 'middle', align: 'middle',
}, },
} },
])
editor.createShapes([partial])
const pageBoundsBeforeLabel = editor.getPageBoundsById(inCommon.id)! const pageBoundsBeforeLabel = editor.getPageBoundsById(inCommon.id)!
@ -254,7 +253,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
break break
} }
case TDShapeType.Triangle: { case TDShapeType.Triangle: {
const partial: TLShapePartial<TLGeoShape> = { editor.createShapes<TLGeoShape>([
{
...inCommon, ...inCommon,
type: 'geo', type: 'geo',
props: { props: {
@ -269,9 +269,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
dash: getV2Dash(v1Shape.style.dash), dash: getV2Dash(v1Shape.style.dash),
align: 'middle', align: 'middle',
}, },
} },
])
editor.createShapes([partial])
const pageBoundsBeforeLabel = editor.getPageBoundsById(inCommon.id)! const pageBoundsBeforeLabel = editor.getPageBoundsById(inCommon.id)!
@ -309,7 +308,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
break break
} }
case TDShapeType.Ellipse: { case TDShapeType.Ellipse: {
const partial: TLShapePartial<TLGeoShape> = { editor.createShapes<TLGeoShape>([
{
...inCommon, ...inCommon,
type: 'geo', type: 'geo',
props: { props: {
@ -324,9 +324,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
dash: getV2Dash(v1Shape.style.dash), dash: getV2Dash(v1Shape.style.dash),
align: 'middle', align: 'middle',
}, },
} },
])
editor.createShapes([partial])
const pageBoundsBeforeLabel = editor.getPageBoundsById(inCommon.id)! const pageBoundsBeforeLabel = editor.getPageBoundsById(inCommon.id)!
@ -370,7 +369,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
break break
} }
const partial: TLShapePartial<TLDrawShape> = { editor.createShapes<TLDrawShape>([
{
...inCommon, ...inCommon,
type: 'draw', type: 'draw',
props: { props: {
@ -382,9 +382,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
isComplete: v1Shape.isComplete, isComplete: v1Shape.isComplete,
segments: [{ type: 'free', points: v1Shape.points.map(getV2Point) }], segments: [{ type: 'free', points: v1Shape.points.map(getV2Point) }],
}, },
} },
])
editor.createShapes([partial])
break break
} }
case TDShapeType.Arrow: { case TDShapeType.Arrow: {
@ -395,7 +394,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
const v2Bend = (dist * -v1Bend) / 2 const v2Bend = (dist * -v1Bend) / 2
// Could also be a line... but we'll use it as an arrow anyway // Could also be a line... but we'll use it as an arrow anyway
const partial: TLShapePartial<TLArrowShape> = { editor.createShapes<TLArrowShape>([
{
...inCommon, ...inCommon,
type: 'arrow', type: 'arrow',
props: { props: {
@ -419,14 +419,14 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
}, },
bend: v2Bend, bend: v2Bend,
}, },
} },
])
editor.createShapes([partial])
break break
} }
case TDShapeType.Text: { case TDShapeType.Text: {
const partial: TLShapePartial<TLTextShape> = { editor.createShapes<TLTextShape>([
{
...inCommon, ...inCommon,
type: 'text', type: 'text',
props: { props: {
@ -437,9 +437,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
align: getV2Align(v1Shape.style.textAlign), align: getV2Align(v1Shape.style.textAlign),
scale: v1Shape.style.scale ?? 1, scale: v1Shape.style.scale ?? 1,
}, },
} },
])
editor.createShapes([partial])
break break
} }
case TDShapeType.Image: { case TDShapeType.Image: {
@ -450,7 +449,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
return return
} }
const partial: TLShapePartial<TLImageShape> = { editor.createShapes<TLImageShape>([
{
...inCommon, ...inCommon,
type: 'image', type: 'image',
props: { props: {
@ -458,9 +458,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
h: coerceDimension(v1Shape.size[1]), h: coerceDimension(v1Shape.size[1]),
assetId, assetId,
}, },
} },
])
editor.createShapes([partial])
break break
} }
case TDShapeType.Video: { case TDShapeType.Video: {
@ -471,7 +470,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
return return
} }
const partial: TLShapePartial<TLVideoShape> = { editor.createShapes<TLVideoShape>([
{
...inCommon, ...inCommon,
type: 'video', type: 'video',
props: { props: {
@ -479,9 +479,8 @@ export function buildFromV1Document(editor: Editor, document: LegacyTldrawDocume
h: coerceDimension(v1Shape.size[1]), h: coerceDimension(v1Shape.size[1]),
assetId, assetId,
}, },
} },
])
editor.createShapes([partial])
break break
} }
} }