big refactor

This commit is contained in:
Steve Ruiz 2021-09-13 16:38:42 +01:00
parent 612269ab38
commit 2f4a1f97a2
64 changed files with 1769 additions and 2004 deletions

3
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"typescript.tsdk": "node_modules/typescript/lib"
}

View file

@ -1,11 +1,11 @@
import { render } from '@testing-library/react'
import * as React from 'react'
import { renderWithSvg } from '+test'
import { Binding } from './binding'
jest.spyOn(console, 'error').mockImplementation(() => void null)
describe('binding', () => {
test('mounts component without crashing', () => {
renderWithSvg(<Binding point={[0, 0]} type={'anchor'} />)
render(<Binding point={[0, 0]} type={'anchor'} />)
})
})

View file

@ -1,10 +1,10 @@
import { render } from '@testing-library/react'
import * as React from 'react'
import { renderWithSvg } from '+test'
import { Bounds } from './bounds'
describe('bounds', () => {
test('mounts component without crashing', () => {
renderWithSvg(
render(
<Bounds
zoom={1}
bounds={{ minX: 0, minY: 0, maxX: 100, maxY: 100, width: 100, height: 100 }}

View file

@ -1,10 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import * as React from 'react'
import type {
TLShape,
TLPage,
TLPageState,
TLCallbacks,
TLShapeUtils,
TLTheme,
TLBounds,
TLBinding,
@ -12,16 +12,15 @@ import type {
import { Canvas } from '../canvas'
import { Inputs } from '../../inputs'
import { useTLTheme, TLContext, TLContextType } from '../../hooks'
import type { TLShapeUtil } from '+index'
export interface RendererProps<
T extends TLShape,
E extends Element,
M extends Record<string, unknown>
> extends Partial<TLCallbacks<T>> {
export interface RendererProps<T extends TLShape, E extends Element, M = any>
extends Partial<TLCallbacks<T>> {
/**
* An object containing instances of your shape classes.
*/
shapeUtils: TLShapeUtils<T, E>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
shapeUtils: Record<T['type'], TLShapeUtil<T, E, M>>
/**
* The current page, containing shapes and bindings.
*/
@ -89,7 +88,7 @@ export function Renderer<T extends TLShape, E extends Element, M extends Record<
rest
const [context] = React.useState<TLContextType<T, E>>(() => ({
const [context] = React.useState<TLContextType<T, E, M>>(() => ({
callbacks: rest,
shapeUtils,
rScreenBounds,

View file

@ -17,7 +17,9 @@ export const ShapeIndicator = React.memo(
}
>
<svg width="100%" height="100%">
<g className="tl-centered-g">{utils.renderIndicator(shape)}</g>
<g className="tl-centered-g">
<utils.Indicator shape={shape} />
</g>
</svg>
</div>
)

View file

@ -3,7 +3,7 @@ import * as React from 'react'
import type { TLShapeUtil, TLRenderInfo, TLShape } from '+types'
export const RenderedShape = React.memo(
<T extends TLShape, E extends Element, M extends Record<string, unknown>>({
<T extends TLShape, E extends Element, M = any>({
shape,
utils,
isEditing,
@ -15,14 +15,18 @@ export const RenderedShape = React.memo(
onShapeBlur,
events,
meta,
}: TLRenderInfo<T, M, E> & {
}: TLRenderInfo<T, E, M> & {
shape: T
utils: TLShapeUtil<T, E>
utils: TLShapeUtil<T, E, M> & {
_Component: React.ForwardRefExoticComponent<
{ shape: T; ref: React.ForwardedRef<E> } & TLRenderInfo<T> & React.RefAttributes<E>
>
}
}) => {
const ref = utils.getRef(shape)
return (
<utils.render
<utils._Component
ref={ref}
shape={shape}
isEditing={isEditing}

View file

@ -1,10 +1,10 @@
import * as React from 'react'
import { mockUtils, renderWithSvg } from '+test'
import { mockUtils, renderWithContext } from '+test'
import { Shape } from './shape'
describe('shape', () => {
test('mounts component without crashing', () => {
renderWithSvg(
renderWithContext(
<Shape
shape={mockUtils.box.create({ id: 'box' })}
utils={mockUtils[mockUtils.box.type]}

View file

@ -7,7 +7,7 @@ import { RenderedShape } from './rendered-shape'
import { Container } from '+components/container'
import { useTLContext } from '+hooks'
export const Shape = <T extends TLShape, E extends Element, M extends Record<string, unknown>>({
export const Shape = <T extends TLShape, E extends Element, M = any>({
shape,
utils,
isEditing,
@ -17,7 +17,7 @@ export const Shape = <T extends TLShape, E extends Element, M extends Record<str
isCurrentParent,
meta,
}: IShapeTreeNode<T, M> & {
utils: TLShapeUtil<T, E>
utils: TLShapeUtil<T, E, M>
}) => {
const { callbacks } = useTLContext()
const bounds = utils.getBounds(shape)

View file

@ -2,10 +2,10 @@ import * as React from 'react'
import type { Inputs } from '+inputs'
import type { TLCallbacks, TLShape, TLBounds, TLPageState, TLShapeUtils } from '+types'
export interface TLContextType<T extends TLShape, E extends Element> {
export interface TLContextType<T extends TLShape, E extends Element, M = any> {
id?: string
callbacks: Partial<TLCallbacks<T>>
shapeUtils: TLShapeUtils<T, E>
shapeUtils: TLShapeUtils<T, E, M>
rPageState: React.MutableRefObject<TLPageState>
rScreenBounds: React.MutableRefObject<TLBounds | null>
inputs: Inputs

View file

@ -2,3 +2,4 @@ export * from './components'
export * from './types'
export * from './utils'
export * from './inputs'
export * from './shapes'

View file

@ -0,0 +1,142 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import * as React from 'react'
/* eslint-disable @typescript-eslint/no-unused-vars */
import type { TLShape, TLBounds } from '+types'
import { ShapeUtil } from './createShape'
import { render } from '@testing-library/react'
import { SVGContainer } from '+components'
import Utils from '+utils'
export interface BoxShape extends TLShape {
size: number[]
}
export const Box = new ShapeUtil<BoxShape, SVGSVGElement, null>(() => {
return {
type: 'box',
defaultProps: {
id: 'example1',
type: 'box',
parentId: 'page',
childIndex: 0,
name: 'Example Shape',
point: [0, 0],
size: [100, 100],
rotation: 0,
},
Component({ shape, events, meta }, ref) {
return (
<SVGContainer ref={ref}>
<g {...events}>
<rect width={shape.size[0]} height={shape.size[1]} fill="none" stroke="black" />
</g>
</SVGContainer>
)
},
Indicator({ shape }) {
return (
<SVGContainer>
<rect width={shape.size[0]} height={shape.size[1]} fill="none" stroke="black" />
</SVGContainer>
)
},
getBounds(shape) {
const bounds = Utils.getFromCache(this.boundsCache, shape, () => {
const [width, height] = shape.size
return {
minX: 0,
maxX: width,
minY: 0,
maxY: height,
width,
height,
} as TLBounds
})
return Utils.translateBounds(bounds, shape.point)
},
getRotatedBounds(shape) {
return {
minX: 0,
minY: 0,
maxX: 100,
maxY: 100,
width: 100,
height: 100,
}
},
shouldRender(prev, next) {
return prev.point !== next.point
},
}
})
const boxShape = {
id: 'example1',
type: 'box',
parentId: 'page',
childIndex: 0,
name: 'Example Shape',
point: [0, 0],
size: [100, 100],
rotation: 0,
}
const box = Box.create({ id: 'box1' })
describe('shape utils', () => {
it('creates a shape utils', () => {
expect(Box).toBeTruthy()
})
it('creates a shape', () => {
expect(Box.create({ id: 'box1' })).toStrictEqual({
...boxShape,
id: 'box1',
})
})
it('sets config', () => {
const bounds = Box.getRotatedBounds(box)
expect(bounds).toStrictEqual({
minX: 0,
minY: 0,
maxX: 100,
maxY: 100,
width: 100,
height: 100,
})
})
test('accesses this in an override method', () => {
expect(Box.shouldRender(box, { ...box, point: [1, 1] })).toBeTruthy()
})
test('mounts component without crashing', () => {
const box = Box.create({ id: 'box1' })
const ref = React.createRef<SVGSVGElement>()
const ref2 = React.createRef<HTMLDivElement>()
const H = React.forwardRef<HTMLDivElement, { message: string }>((props, ref) => {
return <div ref={ref2}>{props.message}</div>
})
render(<H message="Hello" />)
render(
<Box._Component
ref={ref}
shape={box}
isEditing={false}
isBinding={false}
isHovered={false}
isSelected={false}
isCurrentParent={false}
meta={{} as any}
events={{} as any}
/>
)
})
})

View file

@ -0,0 +1,220 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import * as React from 'react'
import { Vec } from '@tldraw/vec'
import type { TLShape, TLShapeUtil } from '+types'
import Utils from '+utils'
import { intersectPolylineBounds, intersectRayBounds } from '@tldraw/intersect'
export const ShapeUtil = function <T extends TLShape, E extends Element, M = any>(
this: TLShapeUtil<T, E, M>,
fn: (
this: TLShapeUtil<T, E, M>
) => Partial<TLShapeUtil<T, E, M>> &
Pick<TLShapeUtil<T, E, M>, 'type' | 'defaultProps' | 'Component' | 'Indicator' | 'getBounds'>
) {
const defaults: Partial<TLShapeUtil<T, E, M>> = {
refMap: new Map(),
boundsCache: new WeakMap(),
canEdit: false,
canBind: false,
isAspectRatioLocked: false,
create: (props) => {
this.refMap.set(props.id, React.createRef())
const defaults = this.defaultProps
return { ...defaults, ...props }
},
getRef: (shape) => {
if (!this.refMap.has(shape.id)) {
this.refMap.set(shape.id, React.createRef<E>())
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.refMap.get(shape.id)!
},
mutate: (shape, props) => {
return { ...shape, ...props }
},
transform: (shape, bounds) => {
return { ...shape, point: [bounds.minX, bounds.minY] }
},
transformSingle: (shape, bounds, info) => {
return this.transform(shape, bounds, info)
},
shouldRender: () => {
return true
},
getRotatedBounds: (shape) => {
return Utils.getBoundsFromPoints(
Utils.getRotatedCorners(this.getBounds(shape), shape.rotation)
)
},
getCenter: (shape) => {
return Utils.getBoundsCenter(this.getBounds(shape))
},
hitTest: (shape, point) => {
return Utils.pointInBounds(point, this.getBounds(shape))
},
hitTestBounds: (shape, bounds) => {
const { minX, minY, maxX, maxY, width, height } = this.getBounds(shape)
const center = [minX + width / 2, minY + height / 2]
const corners = [
[minX, minY],
[maxX, minY],
[maxX, maxY],
[minX, maxY],
].map((point) => Vec.rotWith(point, center, shape.rotation || 0))
return (
corners.every(
(point) =>
!(
point[0] < bounds.minX ||
point[0] > bounds.maxX ||
point[1] < bounds.minY ||
point[1] > bounds.maxY
)
) || intersectPolylineBounds(corners, bounds).length > 0
)
},
getBindingPoint: (shape, fromShape, point, origin, direction, padding, bindAnywhere) => {
// Algorithm time! We need to find the binding point (a normalized point inside of the shape, or around the shape, where the arrow will point to) and the distance from the binding shape to the anchor.
let bindingPoint: number[]
let distance: number
const bounds = this.getBounds(shape)
const expandedBounds = Utils.expandBounds(bounds, padding)
// The point must be inside of the expanded bounding box
if (!Utils.pointInBounds(point, expandedBounds)) return
// The point is inside of the shape, so we'll assume the user is indicating a specific point inside of the shape.
if (bindAnywhere) {
if (Vec.dist(point, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
bindingPoint = Vec.divV(Vec.sub(point, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
distance = 0
} else {
// (1) Binding point
// Find furthest intersection between ray from origin through point and expanded bounds. TODO: What if the shape has a curve? In that case, should we intersect the circle-from-three-points instead?
const intersection = intersectRayBounds(origin, direction, expandedBounds)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(b, origin) - Vec.dist(a, origin))[0]
// The anchor is a point between the handle and the intersection
const anchor = Vec.med(point, intersection)
// If we're close to the center, snap to the center, or else calculate a normalized point based on the anchor and the expanded bounds.
if (Vec.distanceToLineSegment(point, anchor, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
//
bindingPoint = Vec.divV(Vec.sub(anchor, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
// (3) Distance
// If the point is inside of the bounds, set the distance to a fixed value.
if (Utils.pointInBounds(point, bounds)) {
distance = 16
} else {
// If the binding point was close to the shape's center, snap to to the center. Find the distance between the point and the real bounds of the shape
distance = Math.max(
16,
Utils.getBoundsSides(bounds)
.map((side) => Vec.distanceToLineSegment(side[1][0], side[1][1], point))
.sort((a, b) => a - b)[0]
)
}
}
return {
point: Vec.clampV(bindingPoint, 0, 1),
distance,
}
},
onDoubleClickBoundsHandle() {
return
},
onDoubleClickHandle() {
return
},
onHandleChange() {
return
},
onRightPointHandle() {
return
},
onSessionComplete() {
return
},
onStyleChange() {
return
},
onBindingChange() {
return
},
onChildrenChange() {
return
},
updateChildren() {
return
},
}
Object.assign(this, defaults)
Object.assign(this, fn.call(this))
Object.assign(this, fn.call(this))
this.getBounds = this.getBounds.bind(this)
this.Component = this.Component.bind(this)
this._Component = React.forwardRef(this.Component)
return this
} as unknown as {
new <T extends TLShape, E extends Element, M = any>(
fn: (
this: TLShapeUtil<T, E, M>
) => Partial<TLShapeUtil<T, E, M>> &
Pick<TLShapeUtil<T, E, M>, 'type' | 'defaultProps' | 'Component' | 'Indicator' | 'getBounds'>
): TLShapeUtil<T, E, M>
}

View file

@ -0,0 +1 @@
export * from './createShape'

View file

@ -1,7 +0,0 @@
import { Box } from './box'
describe('example shape', () => {
it('should create an instance', () => {
expect(new Box()).toBeTruthy()
})
})

View file

@ -1,72 +0,0 @@
/* eslint-disable @typescript-eslint/no-unused-vars */
import * as React from 'react'
import { TLShapeUtil, TLShape, TLShapeProps, TLBounds, TLRenderInfo, TLTransformInfo } from '+types'
import Utils from '+utils'
export interface BoxShape extends TLShape {
size: number[]
}
export class Box extends TLShapeUtil<BoxShape, SVGGElement> {
type = 'box'
defaultProps = {
id: 'example1',
type: 'box',
parentId: 'page',
childIndex: 0,
name: 'Example Shape',
point: [0, 0],
size: [100, 100],
rotation: 0,
}
render = React.forwardRef<SVGGElement, TLShapeProps<BoxShape, SVGGElement>>(
({ shape, events }, ref) => {
return (
<g ref={ref} {...events}>
<rect width={shape.size[0]} height={shape.size[1]} fill="none" stroke="black" />
</g>
)
}
)
renderIndicator(shape: BoxShape) {
return <rect width={100} height={100} />
}
shouldRender(prev: BoxShape, next: BoxShape): boolean {
return true
}
getBounds(shape: BoxShape): TLBounds {
return Utils.getFromCache(this.boundsCache, shape, () => ({
minX: 0,
minY: 0,
maxX: 0,
maxY: 0,
width: 100,
height: 100,
}))
}
getRotatedBounds(shape: BoxShape) {
return Utils.getBoundsFromPoints(Utils.getRotatedCorners(this.getBounds(shape), shape.rotation))
}
getCenter(shape: BoxShape): number[] {
return Utils.getBoundsCenter(this.getBounds(shape))
}
hitTest(shape: BoxShape, point: number[]) {
return Utils.pointInBounds(point, this.getBounds(shape))
}
transform(shape: BoxShape, bounds: TLBounds, _info: TLTransformInfo<BoxShape>): BoxShape {
return { ...shape, point: [bounds.minX, bounds.minY] }
}
transformSingle(shape: BoxShape, bounds: TLBounds, info: TLTransformInfo<BoxShape>): BoxShape {
return this.transform(shape, bounds, info)
}
}

View file

@ -1,4 +1,3 @@
export * from './box'
export * from './mockDocument'
export * from './mockUtils'
export * from './renderWithContext'

View file

@ -1,5 +1,8 @@
import type { TLBinding, TLPage, TLPageState } from '+types'
import type { BoxShape } from './box'
import type { TLBinding, TLPage, TLPageState, TLShape } from '+types'
interface BoxShape extends TLShape {
size: number[]
}
export const mockDocument: { page: TLPage<BoxShape, TLBinding>; pageState: TLPageState } = {
page: {

View file

@ -1,6 +1,6 @@
import type { TLShapeUtils } from '+types'
import { Box, BoxShape } from './box'
import { Box } from '../shapes/createShape.spec'
export const mockUtils: TLShapeUtils<BoxShape, SVGGElement> = {
box: new Box(),
export const mockUtils: TLShapeUtils = {
box: Box,
}

View file

@ -2,9 +2,8 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* --------------------- Primary -------------------- */
import { Vec } from '@tldraw/vec'
import React, { ForwardedRef } from 'react'
import { intersectPolylineBounds } from '@tldraw/intersect'
import type React from 'react'
import type { ForwardedRef } from 'react'
export type Patch<T> = Partial<{ [P in keyof T]: T | Partial<T> | Patch<T[P]> }>
@ -14,7 +13,6 @@ export interface TLPage<T extends TLShape, B extends TLBinding> {
childIndex?: number
shapes: Record<string, T>
bindings: Record<string, B>
backgroundColor?: string
}
export interface TLPageState {
@ -58,9 +56,13 @@ export interface TLShape {
isAspectRatioLocked?: boolean
}
export type TLShapeUtils<T extends TLShape, E extends Element> = Record<string, TLShapeUtil<T, E>>
export type TLShapeUtils<T extends TLShape = any, E extends Element = any, M = any> = Record<
string,
TLShapeUtil<T, E, M>
>
export interface TLRenderInfo<T extends TLShape, M = any, E = any> {
export interface TLRenderInfo<T extends TLShape, E = any, M = any> {
shape: T
isEditing: boolean
isBinding: boolean
isHovered: boolean
@ -78,7 +80,7 @@ export interface TLRenderInfo<T extends TLShape, M = any, E = any> {
}
}
export interface TLShapeProps<T extends TLShape, E = any, M = any> extends TLRenderInfo<T, M, E> {
export interface TLShapeProps<T extends TLShape, E = any, M = any> extends TLRenderInfo<T, E, M> {
ref: ForwardedRef<E>
shape: T
}
@ -88,11 +90,12 @@ export interface TLTool {
name: string
}
export interface TLBinding {
export interface TLBinding<M = any> {
id: string
type: string
toId: string
fromId: string
meta: M
}
export interface TLTheme {
@ -266,175 +269,135 @@ export interface TLBezierCurveSegment {
/* Shape Utility */
/* -------------------------------------------------- */
export abstract class TLShapeUtil<T extends TLShape, E extends Element> {
refMap = new Map<string, React.RefObject<E>>()
export interface TLShapeUtil<T extends TLShape, E extends Element, M extends any> {
type: T['type']
boundsCache = new WeakMap<TLShape, TLBounds>()
defaultProps: T
isEditableText = false
Component(
this: TLShapeUtil<T, E, M>,
props: TLRenderInfo<T, E, M>,
ref: React.ForwardedRef<E>
): React.ReactElement<TLRenderInfo<T, E, M>, E['tagName']>
isAspectRatioLocked = false
Indicator(this: TLShapeUtil<T, E, M>, props: { shape: T }): React.ReactElement | null
canEdit = false
getBounds(this: TLShapeUtil<T, E, M>, shape: T): TLBounds
canBind = false
refMap: Map<string, React.RefObject<E>>
abstract type: T['type']
boundsCache: WeakMap<TLShape, TLBounds>
abstract defaultProps: T
isAspectRatioLocked: boolean
abstract render: React.ForwardRefExoticComponent<
{ shape: T; ref: React.ForwardedRef<E> } & TLRenderInfo<T> & React.RefAttributes<E>
>
canEdit: boolean
abstract renderIndicator(shape: T): JSX.Element | null
canBind: boolean
abstract getBounds(shape: T): TLBounds
getRotatedBounds(this: TLShapeUtil<T, E, M>, shape: T): TLBounds
abstract getRotatedBounds(shape: T): TLBounds
hitTest(this: TLShapeUtil<T, E, M>, shape: T, point: number[]): boolean
shouldRender(_prev: T, _next: T): boolean {
return true
}
hitTestBounds(this: TLShapeUtil<T, E, M>, shape: T, bounds: TLBounds): boolean
shouldDelete(_shape: T): boolean {
return false
}
shouldRender(this: TLShapeUtil<T, E, M>, prev: T, next: T): boolean
getCenter(shape: T): number[] {
const bounds = this.getBounds(shape)
return [bounds.width / 2, bounds.height / 2]
}
getCenter(this: TLShapeUtil<T, E, M>, shape: T): number[]
getRef(shape: T): React.RefObject<E> {
if (!this.refMap.has(shape.id)) {
this.refMap.set(shape.id, React.createRef<E>())
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.refMap.get(shape.id)!
}
getRef(this: TLShapeUtil<T, E, M>, shape: T): React.RefObject<E>
getBindingPoint(
getBindingPoint<K extends TLShape>(
this: TLShapeUtil<T, E, M>,
shape: T,
fromShape: TLShape,
fromShape: K,
point: number[],
origin: number[],
direction: number[],
padding: number,
anywhere: boolean
): { point: number[]; distance: number } | undefined {
return undefined
}
bindAnywhere: boolean
): { point: number[]; distance: number } | undefined
create(props: { id: string } & Partial<T>): T {
this.refMap.set(props.id, React.createRef<E>())
return { ...this.defaultProps, ...props }
}
create: (this: TLShapeUtil<T, E, M>, props: { id: string } & Partial<T>) => T
mutate(shape: T, props: Partial<T>): T {
return { ...shape, ...props }
}
mutate: (this: TLShapeUtil<T, E, M>, shape: T, props: Partial<T>) => Partial<T>
transform(shape: T, bounds: TLBounds, info: TLTransformInfo<T>): Partial<T> | void {
return undefined
}
transform: (
this: TLShapeUtil<T, E, M>,
shape: T,
bounds: TLBounds,
info: TLTransformInfo<T>
) => Partial<T> | void
transformSingle(shape: T, bounds: TLBounds, info: TLTransformInfo<T>): Partial<T> | void {
return this.transform(shape, bounds, info)
}
transformSingle: (
this: TLShapeUtil<T, E, M>,
shape: T,
bounds: TLBounds,
info: TLTransformInfo<T>
) => Partial<T> | void
updateChildren<K extends TLShape>(shape: T, children: K[]): Partial<K>[] | void {
return
}
updateChildren: <K extends TLShape>(
this: TLShapeUtil<T, E, M>,
shape: T,
children: K[]
) => Partial<K>[] | void
onChildrenChange(shape: T, children: TLShape[]): Partial<T> | void {
return
}
onChildrenChange: (this: TLShapeUtil<T, E, M>, shape: T, children: TLShape[]) => Partial<T> | void
onBindingChange(
onBindingChange: (
this: TLShapeUtil<T, E, M>,
shape: T,
binding: TLBinding,
target: TLShape,
targetBounds: TLBounds,
center: number[]
): Partial<T> | void {
return undefined
}
) => Partial<T> | void
onHandleChange(
onHandleChange: (
this: TLShapeUtil<T, E, M>,
shape: T,
handle: Partial<T['handles']>,
info: Partial<TLPointerInfo>
): Partial<T> | void {
return
}
) => Partial<T> | void
onRightPointHandle(
onRightPointHandle: (
this: TLShapeUtil<T, E, M>,
shape: T,
handle: Partial<T['handles']>,
info: Partial<TLPointerInfo>
): Partial<T> | void {
return
}
) => Partial<T> | void
onDoubleClickHandle(
onDoubleClickHandle: (
this: TLShapeUtil<T, E, M>,
shape: T,
handle: Partial<T['handles']>,
info: Partial<TLPointerInfo>
): Partial<T> | void {
return
}
) => Partial<T> | void
onSessionComplete(shape: T): Partial<T> | void {
return
}
onDoubleClickBoundsHandle: (this: TLShapeUtil<T, E, M>, shape: T) => Partial<T> | void
onBoundsReset(shape: T): Partial<T> | void {
return
}
onSessionComplete: (this: TLShapeUtil<T, E, M>, shape: T) => Partial<T> | void
onStyleChange(shape: T): Partial<T> | void {
return
}
onStyleChange: (this: TLShapeUtil<T, E, M>, shape: T) => Partial<T> | void
hitTest(shape: T, point: number[]) {
const bounds = this.getBounds(shape)
return !(
point[0] < bounds.minX ||
point[0] > bounds.maxX ||
point[1] < bounds.minY ||
point[1] > bounds.maxY
)
}
hitTestBounds(shape: T, bounds: TLBounds) {
const { minX, minY, maxX, maxY, width, height } = this.getBounds(shape)
const center = [minX + width / 2, minY + height / 2]
const corners = [
[minX, minY],
[maxX, minY],
[maxX, maxY],
[minX, maxY],
].map((point) => Vec.rotWith(point, center, shape.rotation || 0))
return (
corners.every(
(point) =>
!(
point[0] < bounds.minX ||
point[0] > bounds.maxX ||
point[1] < bounds.minY ||
point[1] > bounds.maxY
)
) || intersectPolylineBounds(corners, bounds).length > 0
)
}
_Component: React.ForwardRefExoticComponent<any>
}
// export interface TLShapeUtil<T extends TLShape, E extends Element, M = any>
// extends TLShapeUtilRequired<T, E, M>,
// Required<TLShapeUtilDefaults<T, E>> {
// _Component: React.ForwardRefExoticComponent<any> & {
// defaultProps: any
// propTypes: any
// }
// }
// export interface TLShapeUtilConfig<T extends TLShape, E extends Element, M = any>
// extends TLShapeUtilRequired<T, E, M>,
// Partial<TLShapeUtilDefaults<T, E>> {}
/* -------------------- Internal -------------------- */
export interface IShapeTreeNode<T extends TLShape, M extends Record<string, unknown>> {
export interface IShapeTreeNode<T extends TLShape, M = any> {
shape: T
children?: IShapeTreeNode<TLShape, M>[]
isEditing: boolean
@ -450,7 +413,7 @@ export interface IShapeTreeNode<T extends TLShape, M extends Record<string, unkn
/* -------------------------------------------------- */
/** @internal */
export type MappedByType<T extends { type: string }> = {
export type MappedByType<K extends string, T extends { type: K }> = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[P in T['type']]: T extends any ? (P extends T['type'] ? T : never) : never
}

View file

@ -8,9 +8,17 @@
"baseUrl": "src",
"emitDeclarationOnly": false,
"paths": {
"+*": ["./*"]
"+*": ["./*"],
"@tldraw/vec": ["../vec"],
"@tldraw/intersect": ["../intersect"]
}
},
"references": [
{
"path": "../intersect"
},
{ "path": "../vec" }
],
"typedocOptions": {
"entryPoints": ["src/index.ts"],
"out": "docs"

View file

@ -8,9 +8,17 @@
"baseUrl": "src",
"emitDeclarationOnly": false,
"paths": {
"+*": ["./*"]
"+*": ["./*"],
"@tldraw/core": ["../vec"],
"@tldraw/tldraw": ["../intersect"]
}
},
"references": [
{
"path": "../tldraw"
},
{ "path": "../core" }
],
"typedocOptions": {
"entryPoints": ["src/index.ts"],
"out": "docs"

View file

@ -6,8 +6,12 @@
"outDir": "./dist/types",
"rootDir": "src",
"baseUrl": "src",
"emitDeclarationOnly": false
"emitDeclarationOnly": false,
"paths": {
"@tldraw/vec": ["../vec"]
}
},
"references": [{ "path": "../vec" }],
"typedocOptions": {
"entryPoints": ["src/index.ts"],
"out": "docs"

View file

@ -72,4 +72,4 @@
"rko": "^0.5.25"
},
"gitHead": "5cb031ddc264846ec6732d7179511cddea8ef034"
}
}

View file

@ -1,39 +1,18 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { Rectangle, Ellipse, Arrow, Draw, Text, Group, PostIt } from './shapes'
import { TLDrawShapeType, TLDrawShape, TLDrawShapeUtil, TLDrawShapeUtils } from '~types'
import { TLDrawShapeType, TLDrawShape, TLDrawShapeUtil } from '~types'
export const tldrawShapeUtils: TLDrawShapeUtils = {
[TLDrawShapeType.Rectangle]: new Rectangle(),
[TLDrawShapeType.Ellipse]: new Ellipse(),
[TLDrawShapeType.Draw]: new Draw(),
[TLDrawShapeType.Arrow]: new Arrow(),
[TLDrawShapeType.Text]: new Text(),
[TLDrawShapeType.Group]: new Group(),
[TLDrawShapeType.PostIt]: new PostIt(),
} as TLDrawShapeUtils
export type ShapeByType<T extends keyof TLDrawShapeUtils> = TLDrawShapeUtils[T]
export function getShapeUtilsByType<T extends TLDrawShape>(
shape: T
): TLDrawShapeUtil<T, HTMLElement | SVGElement> {
return tldrawShapeUtils[shape.type as T['type']] as unknown as TLDrawShapeUtil<
T,
HTMLElement | SVGElement
>
// This is a bad "any", but the "this" context stuff we're doing doesn't allow us to union the types
export const tldrawShapeUtils: Record<TLDrawShapeType, any> = {
[TLDrawShapeType.Rectangle]: Rectangle,
[TLDrawShapeType.Ellipse]: Ellipse,
[TLDrawShapeType.Draw]: Draw,
[TLDrawShapeType.Arrow]: Arrow,
[TLDrawShapeType.Text]: Text,
[TLDrawShapeType.Group]: Group,
[TLDrawShapeType.PostIt]: PostIt,
}
export function getShapeUtils<T extends TLDrawShape>(
shape: T
): TLDrawShapeUtil<T, HTMLElement | SVGElement> {
return tldrawShapeUtils[shape.type as T['type']] as unknown as TLDrawShapeUtil<
T,
HTMLElement | SVGElement
>
}
export function createShape<TLDrawShape>(
type: TLDrawShapeType,
props: { id: string } & Partial<TLDrawShape>
) {
return tldrawShapeUtils[type].create(props)
export function getShapeUtils<T extends TLDrawShape>(type: TLDrawShapeType) {
return tldrawShapeUtils[type] as TLDrawShapeUtil<T, any>
}

View file

@ -0,0 +1,54 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP
exports[`Arrow shape Creates a shape: arrow 1`] = `
Object {
"bend": 0,
"childIndex": 1,
"decorations": Object {
"end": "Arrow",
},
"handles": Object {
"bend": Object {
"id": "bend",
"index": 2,
"point": Array [
0.5,
0.5,
],
},
"end": Object {
"canBind": true,
"id": "end",
"index": 1,
"point": Array [
1,
1,
],
},
"start": Object {
"canBind": true,
"id": "start",
"index": 0,
"point": Array [
0,
0,
],
},
},
"id": "arrow",
"name": "Arrow",
"parentId": "page",
"point": Array [
0,
0,
],
"rotation": 0,
"style": Object {
"color": "Black",
"dash": "Draw",
"isFilled": false,
"size": "Medium",
},
"type": "arrow",
}
`;

View file

@ -1,7 +1,7 @@
import { Arrow } from './arrow'
describe('Arrow shape', () => {
it('Creates an instance', () => {
new Arrow()
it('Creates a shape', () => {
expect(Arrow.create({ id: 'arrow' })).toMatchSnapshot('arrow')
})
})

View file

@ -1,26 +1,17 @@
import * as React from 'react'
import {
SVGContainer,
TLBounds,
Utils,
TLTransformInfo,
TLHandle,
TLPointerInfo,
TLShapeProps,
} from '@tldraw/core'
import { ShapeUtil, SVGContainer, TLBounds, Utils, TLHandle } from '@tldraw/core'
import { Vec } from '@tldraw/vec'
import getStroke from 'perfect-freehand'
import { defaultStyle, getPerfectDashProps, getShapeStyle } from '~shape/shape-styles'
import {
ArrowShape,
Decoration,
TLDrawShapeUtil,
TLDrawShapeType,
TLDrawToolType,
DashStyle,
TLDrawShape,
ArrowBinding,
TLDrawMeta,
EllipseShape,
} from '~types'
import {
intersectArcBounds,
@ -31,16 +22,20 @@ import {
intersectRayEllipse,
} from '@tldraw/intersect'
export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
type = TLDrawShapeType.Arrow as const
toolType = TLDrawToolType.Handle
canStyleFill = false
simplePathCache = new WeakMap<ArrowShape['handles'], string>()
pathCache = new WeakMap<ArrowShape, string>()
const simplePathCache = new WeakMap<ArrowShape['handles'], string>()
defaultProps = {
export const Arrow = new ShapeUtil<ArrowShape, SVGSVGElement, TLDrawMeta>(() => ({
type: TLDrawShapeType.Arrow,
toolType: TLDrawToolType.Handle,
canStyleFill: false,
pathCache: new WeakMap<ArrowShape, string>(),
defaultProps: {
id: 'id',
type: TLDrawShapeType.Arrow as const,
type: TLDrawShapeType.Arrow,
name: 'Arrow',
parentId: 'page',
childIndex: 1,
@ -73,136 +68,62 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
...defaultStyle,
isFilled: false,
},
}
},
shouldRender = (prev: ArrowShape, next: ArrowShape) => {
return next.handles !== prev.handles || next.style !== prev.style
}
Component({ shape, meta, events }, ref) {
const {
handles: { start, bend, end },
decorations = {},
style,
} = shape
render = React.forwardRef<SVGSVGElement, TLShapeProps<ArrowShape, SVGSVGElement>>(
({ shape, meta, events }, ref) => {
const {
handles: { start, bend, end },
decorations = {},
style,
} = shape
const isDraw = style.dash === DashStyle.Draw
const isDraw = style.dash === DashStyle.Draw
// TODO: Improve drawn arrows
// TODO: Improve drawn arrows
const isStraightLine = Vec.dist(bend.point, Vec.round(Vec.med(start.point, end.point))) < 1
const isStraightLine = Vec.dist(bend.point, Vec.round(Vec.med(start.point, end.point))) < 1
const styles = getShapeStyle(style, meta.isDarkMode)
const styles = getShapeStyle(style, meta.isDarkMode)
const { strokeWidth } = styles
const { strokeWidth } = styles
const arrowDist = Vec.dist(start.point, end.point)
const arrowDist = Vec.dist(start.point, end.point)
const arrowHeadLength = Math.min(arrowDist / 3, strokeWidth * 8)
const arrowHeadLength = Math.min(arrowDist / 3, strokeWidth * 8)
let shaftPath: JSX.Element | null
let startArrowHead: { left: number[]; right: number[] } | undefined
let endArrowHead: { left: number[]; right: number[] } | undefined
let shaftPath: JSX.Element | null
let startArrowHead: { left: number[]; right: number[] } | undefined
let endArrowHead: { left: number[]; right: number[] } | undefined
if (isStraightLine) {
const sw = strokeWidth * (isDraw ? 1.25 : 1.618)
if (isStraightLine) {
const sw = strokeWidth * (isDraw ? 1.25 : 1.618)
const path = isDraw
? renderFreehandArrowShaft(shape)
: 'M' + Vec.round(start.point) + 'L' + Vec.round(end.point)
const path = Utils.getFromCache(this.pathCache, shape, () =>
isDraw
? renderFreehandArrowShaft(shape)
: 'M' + Vec.round(start.point) + 'L' + Vec.round(end.point)
)
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
arrowDist,
sw,
shape.style.dash,
2
)
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
arrowDist,
sw,
shape.style.dash,
2
)
if (decorations.start) {
startArrowHead = getStraightArrowHeadPoints(start.point, end.point, arrowHeadLength)
}
if (decorations.start) {
startArrowHead = getStraightArrowHeadPoints(start.point, end.point, arrowHeadLength)
}
if (decorations.end) {
endArrowHead = getStraightArrowHeadPoints(end.point, start.point, arrowHeadLength)
}
if (decorations.end) {
endArrowHead = getStraightArrowHeadPoints(end.point, start.point, arrowHeadLength)
}
// Straight arrow path
shaftPath =
arrowDist > 2 ? (
<>
<path
d={path}
fill="none"
strokeWidth={Math.max(8, strokeWidth * 2)}
strokeDasharray="none"
strokeDashoffset="none"
strokeLinecap="round"
strokeLinejoin="round"
pointerEvents="stroke"
/>
<path
d={path}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={sw}
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
strokeLinecap="round"
strokeLinejoin="round"
pointerEvents="stroke"
/>
</>
) : null
} else {
const circle = getCtp(shape)
const sw = strokeWidth * (isDraw ? 1.25 : 1.618)
const path = Utils.getFromCache(this.pathCache, shape, () =>
isDraw
? renderCurvedFreehandArrowShaft(shape, circle)
: getArrowArcPath(start, end, circle, shape.bend)
)
const { center, radius, length } = getArrowArc(shape)
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
length - 1,
sw,
shape.style.dash,
2
)
if (decorations.start) {
startArrowHead = getCurvedArrowHeadPoints(
start.point,
arrowHeadLength,
center,
radius,
length < 0
)
}
if (decorations.end) {
endArrowHead = getCurvedArrowHeadPoints(
end.point,
arrowHeadLength,
center,
radius,
length >= 0
)
}
// Curved arrow path
shaftPath = (
// Straight arrow path
shaftPath =
arrowDist > 2 ? (
<>
<path
d={path}
fill="none"
stroke="transparent"
strokeWidth={Math.max(8, strokeWidth * 2)}
strokeDasharray="none"
strokeDashoffset="none"
@ -212,7 +133,7 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
/>
<path
d={path}
fill={isDraw ? styles.stroke : 'none'}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={sw}
strokeDasharray={strokeDasharray}
@ -222,81 +143,145 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
pointerEvents="stroke"
/>
</>
) : null
} else {
const circle = getCtp(shape)
const sw = strokeWidth * (isDraw ? 1.25 : 1.618)
const path = isDraw
? renderCurvedFreehandArrowShaft(shape, circle)
: getArrowArcPath(start, end, circle, shape.bend)
const { center, radius, length } = getArrowArc(shape)
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
length - 1,
sw,
shape.style.dash,
2
)
if (decorations.start) {
startArrowHead = getCurvedArrowHeadPoints(
start.point,
arrowHeadLength,
center,
radius,
length < 0
)
}
const sw = strokeWidth * 1.618
if (decorations.end) {
endArrowHead = getCurvedArrowHeadPoints(
end.point,
arrowHeadLength,
center,
radius,
length >= 0
)
}
return (
<SVGContainer ref={ref} {...events}>
<g pointerEvents="none">
{shaftPath}
{startArrowHead && (
<path
d={`M ${startArrowHead.left} L ${start.point} ${startArrowHead.right}`}
fill="none"
stroke={styles.stroke}
strokeWidth={sw}
strokeDashoffset="none"
strokeDasharray="none"
strokeLinecap="round"
strokeLinejoin="round"
pointerEvents="stroke"
/>
)}
{endArrowHead && (
<path
d={`M ${endArrowHead.left} L ${end.point} ${endArrowHead.right}`}
fill="none"
stroke={styles.stroke}
strokeWidth={sw}
strokeDashoffset="none"
strokeDasharray="none"
strokeLinecap="round"
strokeLinejoin="round"
pointerEvents="stroke"
/>
)}
</g>
</SVGContainer>
// Curved arrow path
shaftPath = (
<>
<path
d={path}
fill="none"
stroke="transparent"
strokeWidth={Math.max(8, strokeWidth * 2)}
strokeDasharray="none"
strokeDashoffset="none"
strokeLinecap="round"
strokeLinejoin="round"
pointerEvents="stroke"
/>
<path
d={path}
fill={isDraw ? styles.stroke : 'none'}
stroke={styles.stroke}
strokeWidth={sw}
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
strokeLinecap="round"
strokeLinejoin="round"
pointerEvents="stroke"
/>
</>
)
}
)
renderIndicator(shape: ArrowShape) {
const path = Utils.getFromCache(this.simplePathCache, shape.handles, () => getArrowPath(shape))
const sw = strokeWidth * 1.618
return (
<SVGContainer ref={ref} {...events}>
<g pointerEvents="none">
{shaftPath}
{startArrowHead && (
<path
d={`M ${startArrowHead.left} L ${start.point} ${startArrowHead.right}`}
fill="none"
stroke={styles.stroke}
strokeWidth={sw}
strokeDashoffset="none"
strokeDasharray="none"
strokeLinecap="round"
strokeLinejoin="round"
pointerEvents="stroke"
/>
)}
{endArrowHead && (
<path
d={`M ${endArrowHead.left} L ${end.point} ${endArrowHead.right}`}
fill="none"
stroke={styles.stroke}
strokeWidth={sw}
strokeDashoffset="none"
strokeDasharray="none"
strokeLinecap="round"
strokeLinejoin="round"
pointerEvents="stroke"
/>
)}
</g>
</SVGContainer>
)
},
Indicator({ shape }) {
const path = Utils.getFromCache(simplePathCache, shape.handles, () => getArrowPath(shape))
return <path d={path} />
}
},
getBounds = (shape: ArrowShape) => {
shouldRender(prev, next) {
return next.handles !== prev.handles || next.style !== prev.style
},
getBounds(shape) {
const bounds = Utils.getFromCache(this.boundsCache, shape, () => {
const { start, bend, end } = shape.handles
return Utils.getBoundsFromPoints([start.point, bend.point, end.point])
})
return Utils.translateBounds(bounds, shape.point)
}
},
getRotatedBounds = (shape: ArrowShape) => {
getRotatedBounds(shape) {
const { start, bend, end } = shape.handles
return Utils.translateBounds(
Utils.getBoundsFromPoints([start.point, bend.point, end.point], shape.rotation),
shape.point
)
}
},
getCenter = (shape: ArrowShape) => {
getCenter(shape) {
const { start, end } = shape.handles
return Vec.add(shape.point, Vec.med(start.point, end.point))
}
},
hitTest = () => {
return true
}
hitTestBounds = (shape: ArrowShape, brushBounds: TLBounds) => {
hitTestBounds(shape, brushBounds: TLBounds) {
const { start, end, bend } = shape.handles
const sp = Vec.add(shape.point, start.point)
@ -314,13 +299,9 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
return intersectArcBounds(cp, r, sp, ep, brushBounds).length > 0
}
}
},
transform = (
_shape: ArrowShape,
bounds: TLBounds,
{ initialShape, scaleX, scaleY }: TLTransformInfo<ArrowShape>
): Partial<ArrowShape> => {
transform(_shape, bounds, { initialShape, scaleX, scaleY }) {
const initialShapeBounds = this.getBounds(initialShape)
const handles: (keyof ArrowShape['handles'])[] = ['start', 'end']
@ -362,9 +343,9 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
point: [bounds.minX, bounds.minY],
handles: nextHandles,
}
}
},
onDoubleClickHandle = (shape: ArrowShape, handle: Partial<ArrowShape['handles']>) => {
onDoubleClickHandle(shape, handle) {
switch (handle) {
case 'bend': {
return {
@ -397,16 +378,10 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
}
return this
}
},
onBindingChange = (
shape: ArrowShape,
binding: ArrowBinding,
target: TLDrawShape,
targetBounds: TLBounds,
center: number[]
): void | Partial<ArrowShape> => {
const handle = shape.handles[binding.handleId]
onBindingChange(shape, binding: ArrowBinding, target, targetBounds, center) {
const handle = shape.handles[binding.meta.handleId as keyof ArrowShape['handles']]
const expandedBounds = Utils.expandBounds(targetBounds, 32)
// The anchor is the "actual" point in the target shape
@ -416,7 +391,7 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
[expandedBounds.minX, expandedBounds.minY],
Vec.mulV(
[expandedBounds.width, expandedBounds.height],
Vec.rotWith(binding.point, [0.5, 0.5], target.rotation || 0)
Vec.rotWith(binding.meta.point, [0.5, 0.5], target.rotation || 0)
)
),
shape.point
@ -425,8 +400,8 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
// We're looking for the point to put the dragging handle
let handlePoint = anchor
if (binding.distance) {
const intersectBounds = Utils.expandBounds(targetBounds, binding.distance)
if (binding.meta.distance) {
const intersectBounds = Utils.expandBounds(targetBounds, binding.meta.distance)
// The direction vector starts from the arrow's opposite handle
const origin = Vec.add(
@ -437,7 +412,9 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
// And passes through the dragging handle
const direction = Vec.uni(Vec.sub(Vec.add(anchor, shape.point), origin))
if ([TLDrawShapeType.Rectangle, TLDrawShapeType.Text].includes(target.type)) {
if (
[TLDrawShapeType.Rectangle, TLDrawShapeType.Text].includes(target.type as TLDrawShapeType)
) {
let hits = intersectRayBounds(origin, direction, intersectBounds, target.rotation)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
@ -461,8 +438,8 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
origin,
direction,
center,
target.radius[0] + binding.distance,
target.radius[1] + binding.distance,
(target as EllipseShape).radius[0] + binding.meta.distance,
(target as EllipseShape).radius[1] + binding.meta.distance,
target.rotation || 0
).points.sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin))
@ -484,13 +461,9 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
},
{ shiftKey: false }
)
}
},
onHandleChange = (
shape: ArrowShape,
handles: Partial<ArrowShape['handles']>,
{ shiftKey }: Partial<TLPointerInfo>
) => {
onHandleChange(shape, handles, { shiftKey }) {
let nextHandles = Utils.deepMerge<ArrowShape['handles']>(shape.handles, handles)
let nextBend = shape.bend
@ -586,8 +559,12 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape, SVGSVGElement> {
}
return nextShape
}
}
},
}))
/* -------------------------------------------------- */
/* Helpers */
/* -------------------------------------------------- */
function getArrowArcPath(start: TLHandle, end: TLHandle, circle: number[], bend: number) {
return [

View file

@ -0,0 +1,23 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP
exports[`Draw shape Creates a shape: draw 1`] = `
Object {
"childIndex": 1,
"id": "draw",
"name": "Draw",
"parentId": "page",
"point": Array [
0,
0,
],
"points": Array [],
"rotation": 0,
"style": Object {
"color": "Black",
"dash": "Draw",
"isFilled": false,
"size": "Medium",
},
"type": "draw",
}
`;

View file

@ -1,7 +1,7 @@
import { Draw } from './draw'
describe('Draw shape', () => {
it('Creates an instance', () => {
new Draw()
it('Creates a shape', () => {
expect(Draw.create({ id: 'draw' })).toMatchSnapshot('draw')
})
})

View file

@ -1,31 +1,25 @@
import * as React from 'react'
import { SVGContainer, TLBounds, Utils, TLTransformInfo } from '@tldraw/core'
import { SVGContainer, TLBounds, Utils, TLTransformInfo, ShapeUtil } from '@tldraw/core'
import { Vec } from '@tldraw/vec'
import { intersectBoundsBounds, intersectBoundsPolyline } from '@tldraw/intersect'
import getStroke, { getStrokePoints } from 'perfect-freehand'
import { defaultStyle, getShapeStyle } from '~shape/shape-styles'
import {
DrawShape,
DashStyle,
TLDrawShapeUtil,
TLDrawShapeType,
TLDrawToolType,
TLDrawShapeProps,
} from '~types'
import { DrawShape, DashStyle, TLDrawShapeType, TLDrawToolType, TLDrawMeta } from '~types'
export class Draw extends TLDrawShapeUtil<DrawShape, SVGSVGElement> {
type = TLDrawShapeType.Draw as const
toolType = TLDrawToolType.Draw
const pointsBoundsCache = new WeakMap<DrawShape['points'], TLBounds>([])
const rotatedCache = new WeakMap<DrawShape, number[][]>([])
const drawPathCache = new WeakMap<DrawShape['points'], string>([])
const simplePathCache = new WeakMap<DrawShape['points'], string>([])
const polygonCache = new WeakMap<DrawShape['points'], string>([])
pointsBoundsCache = new WeakMap<DrawShape['points'], TLBounds>([])
rotatedCache = new WeakMap<DrawShape, number[][]>([])
drawPathCache = new WeakMap<DrawShape['points'], string>([])
simplePathCache = new WeakMap<DrawShape['points'], string>([])
polygonCache = new WeakMap<DrawShape['points'], string>([])
export const Draw = new ShapeUtil<DrawShape, SVGSVGElement, TLDrawMeta>(() => ({
type: TLDrawShapeType.Draw,
defaultProps: DrawShape = {
toolType: TLDrawToolType.Draw,
defaultProps: {
id: 'id',
type: TLDrawShapeType.Draw as const,
type: TLDrawShapeType.Draw,
name: 'Draw',
parentId: 'page',
childIndex: 1,
@ -33,130 +27,122 @@ export class Draw extends TLDrawShapeUtil<DrawShape, SVGSVGElement> {
points: [],
rotation: 0,
style: defaultStyle,
}
},
shouldRender(prev: DrawShape, next: DrawShape): boolean {
return next.points !== prev.points || next.style !== prev.style
}
Component({ shape, meta, events, isEditing }, ref) {
const { points, style } = shape
render = React.forwardRef<SVGSVGElement, TLDrawShapeProps<DrawShape, SVGSVGElement>>(
({ shape, meta, events, isEditing }, ref) => {
const { points, style } = shape
const styles = getShapeStyle(style, meta.isDarkMode)
const styles = getShapeStyle(style, meta.isDarkMode)
const strokeWidth = styles.strokeWidth
const strokeWidth = styles.strokeWidth
// For very short lines, draw a point instead of a line
const bounds = this.getBounds(shape)
// For very short lines, draw a point instead of a line
const bounds = this.getBounds(shape)
const verySmall = bounds.width < strokeWidth / 2 && bounds.height < strokeWidth / 2
const verySmall = bounds.width < strokeWidth / 2 && bounds.height < strokeWidth / 2
if (!isEditing && verySmall) {
const sw = strokeWidth * 0.618
return (
<SVGContainer ref={ref} {...events}>
<circle
r={strokeWidth * 0.618}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={sw}
pointerEvents="all"
/>
</SVGContainer>
)
}
const shouldFill =
style.isFilled &&
points.length > 3 &&
Vec.dist(points[0], points[points.length - 1]) < +styles.strokeWidth * 2
// For drawn lines, draw a line from the path cache
if (shape.style.dash === DashStyle.Draw) {
const polygonPathData = Utils.getFromCache(this.polygonCache, points, () =>
getFillPath(shape)
)
const drawPathData = isEditing
? getDrawStrokePath(shape, true)
: Utils.getFromCache(this.drawPathCache, points, () => getDrawStrokePath(shape, false))
return (
<SVGContainer ref={ref} {...events}>
{shouldFill && (
<path
d={polygonPathData}
stroke="none"
fill={styles.fill}
strokeLinejoin="round"
strokeLinecap="round"
pointerEvents="fill"
/>
)}
<path
d={drawPathData}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={strokeWidth}
strokeLinejoin="round"
strokeLinecap="round"
pointerEvents="all"
/>
</SVGContainer>
)
}
// For solid, dash and dotted lines, draw a regular stroke path
const strokeDasharray = {
[DashStyle.Draw]: 'none',
[DashStyle.Solid]: `none`,
[DashStyle.Dotted]: `${strokeWidth / 10} ${strokeWidth * 3}`,
[DashStyle.Dashed]: `${strokeWidth * 3} ${strokeWidth * 3}`,
}[style.dash]
const strokeDashoffset = {
[DashStyle.Draw]: 'none',
[DashStyle.Solid]: `none`,
[DashStyle.Dotted]: `-${strokeWidth / 20}`,
[DashStyle.Dashed]: `-${strokeWidth}`,
}[style.dash]
const path = Utils.getFromCache(this.simplePathCache, points, () => getSolidStrokePath(shape))
const sw = strokeWidth * 1.618
if (!isEditing && verySmall) {
const sw = strokeWidth * 0.618
return (
<SVGContainer ref={ref} {...events}>
<path
d={path}
fill={shouldFill ? styles.fill : 'none'}
stroke="transparent"
strokeWidth={Math.min(4, strokeWidth * 2)}
strokeLinejoin="round"
strokeLinecap="round"
pointerEvents={shouldFill ? 'all' : 'stroke'}
/>
<path
d={path}
fill="transparent"
<circle
r={strokeWidth * 0.618}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={sw}
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
strokeLinejoin="round"
strokeLinecap="round"
pointerEvents="stroke"
pointerEvents="all"
/>
</SVGContainer>
)
}
)
renderIndicator(shape: DrawShape): JSX.Element {
const shouldFill =
style.isFilled &&
points.length > 3 &&
Vec.dist(points[0], points[points.length - 1]) < +styles.strokeWidth * 2
// For drawn lines, draw a line from the path cache
if (shape.style.dash === DashStyle.Draw) {
const polygonPathData = Utils.getFromCache(polygonCache, points, () => getFillPath(shape))
const drawPathData = isEditing
? getDrawStrokePath(shape, true)
: Utils.getFromCache(drawPathCache, points, () => getDrawStrokePath(shape, false))
return (
<SVGContainer ref={ref} {...events}>
{shouldFill && (
<path
d={polygonPathData}
stroke="none"
fill={styles.fill}
strokeLinejoin="round"
strokeLinecap="round"
pointerEvents="fill"
/>
)}
<path
d={drawPathData}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={strokeWidth}
strokeLinejoin="round"
strokeLinecap="round"
pointerEvents="all"
/>
</SVGContainer>
)
}
// For solid, dash and dotted lines, draw a regular stroke path
const strokeDasharray = {
[DashStyle.Draw]: 'none',
[DashStyle.Solid]: `none`,
[DashStyle.Dotted]: `${strokeWidth / 10} ${strokeWidth * 3}`,
[DashStyle.Dashed]: `${strokeWidth * 3} ${strokeWidth * 3}`,
}[style.dash]
const strokeDashoffset = {
[DashStyle.Draw]: 'none',
[DashStyle.Solid]: `none`,
[DashStyle.Dotted]: `-${strokeWidth / 20}`,
[DashStyle.Dashed]: `-${strokeWidth}`,
}[style.dash]
const path = Utils.getFromCache(simplePathCache, points, () => getSolidStrokePath(shape))
const sw = strokeWidth * 1.618
return (
<SVGContainer ref={ref} {...events}>
<path
d={path}
fill={shouldFill ? styles.fill : 'none'}
stroke="transparent"
strokeWidth={Math.min(4, strokeWidth * 2)}
strokeLinejoin="round"
strokeLinecap="round"
pointerEvents={shouldFill ? 'all' : 'stroke'}
/>
<path
d={path}
fill="transparent"
stroke={styles.stroke}
strokeWidth={sw}
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
strokeLinejoin="round"
strokeLinecap="round"
pointerEvents="stroke"
/>
</SVGContainer>
)
},
Indicator({ shape }) {
const { points } = shape
const bounds = this.getBounds(shape)
@ -167,34 +153,23 @@ export class Draw extends TLDrawShapeUtil<DrawShape, SVGSVGElement> {
return <circle x={bounds.width / 2} y={bounds.height / 2} r={1} />
}
const path = Utils.getFromCache(this.simplePathCache, points, () => getSolidStrokePath(shape))
const path = Utils.getFromCache(simplePathCache, points, () => getSolidStrokePath(shape))
return <path d={path} />
}
},
getBounds(shape: DrawShape): TLBounds {
return Utils.translateBounds(
Utils.getFromCache(this.pointsBoundsCache, shape.points, () =>
Utils.getFromCache(pointsBoundsCache, shape.points, () =>
Utils.getBoundsFromPoints(shape.points)
),
shape.point
)
}
},
getRotatedBounds(shape: DrawShape): TLBounds {
return Utils.translateBounds(
Utils.getBoundsFromPoints(shape.points, shape.rotation),
shape.point
)
}
getCenter(shape: DrawShape): number[] {
return Utils.getBoundsCenter(this.getBounds(shape))
}
hitTest(): boolean {
return true
}
shouldRender(prev: DrawShape, next: DrawShape): boolean {
return next.points !== prev.points || next.style !== prev.style
},
hitTestBounds(shape: DrawShape, brushBounds: TLBounds): boolean {
// Test axis-aligned shape
@ -215,7 +190,7 @@ export class Draw extends TLDrawShapeUtil<DrawShape, SVGSVGElement> {
// Test rotated shape
const rBounds = this.getRotatedBounds(shape)
const rotatedBounds = Utils.getFromCache(this.rotatedCache, shape, () => {
const rotatedBounds = Utils.getFromCache(rotatedCache, shape, () => {
const c = Utils.getBoundsCenter(Utils.getBoundsFromPoints(shape.points))
return shape.points.map((pt) => Vec.rotWith(pt, c, shape.rotation || 0))
})
@ -227,7 +202,7 @@ export class Draw extends TLDrawShapeUtil<DrawShape, SVGSVGElement> {
rotatedBounds
).length > 0
)
}
},
transform(
shape: DrawShape,
@ -260,16 +235,12 @@ export class Draw extends TLDrawShapeUtil<DrawShape, SVGSVGElement> {
points,
point,
}
}
},
}))
transformSingle(
shape: DrawShape,
bounds: TLBounds,
info: TLTransformInfo<DrawShape>
): Partial<DrawShape> {
return this.transform(shape, bounds, info)
}
}
/* -------------------------------------------------- */
/* Helpers */
/* -------------------------------------------------- */
const simulatePressureSettings = {
simulatePressure: true,

View file

@ -0,0 +1,26 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP
exports[`Ellipse shape Creates a shape: ellipse 1`] = `
Object {
"childIndex": 1,
"id": "ellipse",
"name": "Ellipse",
"parentId": "page",
"point": Array [
0,
0,
],
"radius": Array [
1,
1,
],
"rotation": 0,
"style": Object {
"color": "Black",
"dash": "Draw",
"isFilled": false,
"size": "Medium",
},
"type": "ellipse",
}
`;

View file

@ -1,7 +1,7 @@
import { Ellipse } from './ellipse'
describe('Ellipse shape', () => {
it('Creates an instance', () => {
new Ellipse()
it('Creates a shape', () => {
expect(Ellipse.create({ id: 'ellipse' })).toMatchSnapshot('ellipse')
})
})

View file

@ -1,34 +1,27 @@
import * as React from 'react'
import { SVGContainer, Utils, TLTransformInfo, TLBounds, TLShapeProps } from '@tldraw/core'
import { SVGContainer, Utils, ShapeUtil, TLTransformInfo, TLBounds } from '@tldraw/core'
import { Vec } from '@tldraw/vec'
import {
ArrowShape,
DashStyle,
EllipseShape,
TLDrawShapeType,
TLDrawShapeUtil,
TLDrawToolType,
} from '~types'
import { DashStyle, EllipseShape, TLDrawShapeType, TLDrawMeta, TLDrawToolType } from '~types'
import { defaultStyle, getPerfectDashProps, getShapeStyle } from '~shape/shape-styles'
import getStroke from 'perfect-freehand'
import {
intersectBoundsEllipse,
intersectLineSegmentEllipse,
intersectPolylineBounds,
intersectRayEllipse,
} from '@tldraw/intersect'
// TODO
// [ ] Improve indicator shape for drawn shapes
export const Ellipse = new ShapeUtil<EllipseShape, SVGSVGElement, TLDrawMeta>(() => ({
type: TLDrawShapeType.Ellipse,
export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
type = TLDrawShapeType.Ellipse as const
toolType = TLDrawToolType.Bounds
pathCache = new WeakMap<EllipseShape, string>([])
canBind = true
toolType: TLDrawToolType.Bounds,
defaultProps = {
pathCache: new WeakMap<EllipseShape, string>([]),
canBind: true,
defaultProps: {
id: 'id',
type: TLDrawShapeType.Ellipse as const,
type: TLDrawShapeType.Ellipse,
name: 'Ellipse',
parentId: 'page',
childIndex: 1,
@ -36,75 +29,22 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
radius: [1, 1],
rotation: 0,
style: defaultStyle,
}
},
shouldRender(prev: EllipseShape, next: EllipseShape) {
return next.radius !== prev.radius || next.style !== prev.style
}
Component({ shape, meta, isBinding, events }, ref) {
const {
radius: [radiusX, radiusY],
style,
} = shape
render = React.forwardRef<SVGSVGElement, TLShapeProps<EllipseShape, SVGSVGElement>>(
({ shape, meta, isBinding, events }, ref) => {
const {
radius: [radiusX, radiusY],
style,
} = shape
const styles = getShapeStyle(style, meta.isDarkMode)
const strokeWidth = +styles.strokeWidth
const styles = getShapeStyle(style, meta.isDarkMode)
const strokeWidth = +styles.strokeWidth
const rx = Math.max(0, radiusX - strokeWidth / 2)
const ry = Math.max(0, radiusY - strokeWidth / 2)
const rx = Math.max(0, radiusX - strokeWidth / 2)
const ry = Math.max(0, radiusY - strokeWidth / 2)
if (style.dash === DashStyle.Draw) {
const path = Utils.getFromCache(this.pathCache, shape, () =>
renderPath(shape, this.getCenter(shape))
)
return (
<SVGContainer ref={ref} {...events}>
{isBinding && (
<ellipse
className="tl-binding-indicator"
cx={radiusX}
cy={radiusY}
rx={rx + 2}
ry={ry + 2}
/>
)}
<ellipse
cx={radiusX}
cy={radiusY}
rx={rx}
ry={ry}
stroke="none"
fill={style.isFilled ? styles.fill : 'none'}
pointerEvents="all"
/>
<path
d={path}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={strokeWidth}
pointerEvents="all"
strokeLinecap="round"
strokeLinejoin="round"
/>
</SVGContainer>
)
}
const h = Math.pow(rx - ry, 2) / Math.pow(rx + ry, 2)
const perimeter = Math.PI * (rx + ry) * (1 + (3 * h) / (10 + Math.sqrt(4 - 3 * h)))
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
perimeter,
strokeWidth * 1.618,
shape.style.dash,
4
)
const sw = strokeWidth * 1.618
if (style.dash === DashStyle.Draw) {
const path = renderPath(shape, this.getCenter(shape))
return (
<SVGContainer ref={ref} {...events}>
@ -113,8 +53,8 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
className="tl-binding-indicator"
cx={radiusX}
cy={radiusY}
rx={rx + 32}
ry={ry + 32}
rx={rx + 2}
ry={ry + 2}
/>
)}
<ellipse
@ -122,11 +62,15 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
cy={radiusY}
rx={rx}
ry={ry}
fill={styles.fill}
stroke="none"
fill={style.isFilled ? styles.fill : 'none'}
pointerEvents="all"
/>
<path
d={path}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={sw}
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
strokeWidth={strokeWidth}
pointerEvents="all"
strokeLinecap="round"
strokeLinejoin="round"
@ -134,9 +78,50 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
</SVGContainer>
)
}
)
renderIndicator(shape: EllipseShape) {
const h = Math.pow(rx - ry, 2) / Math.pow(rx + ry, 2)
const perimeter = Math.PI * (rx + ry) * (1 + (3 * h) / (10 + Math.sqrt(4 - 3 * h)))
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
perimeter,
strokeWidth * 1.618,
shape.style.dash,
4
)
const sw = strokeWidth * 1.618
return (
<SVGContainer ref={ref} {...events}>
{isBinding && (
<ellipse
className="tl-binding-indicator"
cx={radiusX}
cy={radiusY}
rx={rx + 32}
ry={ry + 32}
/>
)}
<ellipse
cx={radiusX}
cy={radiusY}
rx={rx}
ry={ry}
fill={styles.fill}
stroke={styles.stroke}
strokeWidth={sw}
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
pointerEvents="all"
strokeLinecap="round"
strokeLinejoin="round"
/>
</SVGContainer>
)
},
Indicator({ shape }) {
const {
style,
radius: [rx, ry],
@ -147,10 +132,16 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
const sw = strokeWidth
// TODO Improve indicator shape for drawn shapes, which are
// intentionally not perfect circles.
return <ellipse cx={rx} cy={ry} rx={rx - sw / 2} ry={ry - sw / 2} />
}
},
getBounds(shape: EllipseShape) {
shouldRender(prev, next) {
return next.radius !== prev.radius || next.style !== prev.style
},
getBounds(shape) {
return Utils.getFromCache(this.boundsCache, shape, () => {
return Utils.getRotatedEllipseBounds(
shape.point[0],
@ -160,38 +151,39 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
shape.rotation || 0
)
})
}
},
getRotatedBounds(shape: EllipseShape) {
getRotatedBounds(shape) {
return Utils.getBoundsFromPoints(Utils.getRotatedCorners(this.getBounds(shape), shape.rotation))
}
},
getCenter(shape: EllipseShape): number[] {
getCenter(shape): number[] {
return [shape.point[0] + shape.radius[0], shape.point[1] + shape.radius[1]]
}
},
hitTest(shape: EllipseShape, point: number[]) {
return Utils.pointInBounds(point, this.getBounds(shape))
}
hitTestBounds(shape: EllipseShape, bounds: TLBounds) {
const rotatedCorners = Utils.getRotatedCorners(this.getBounds(shape), shape.rotation)
return (
rotatedCorners.every((point) => Utils.pointInBounds(point, bounds)) ||
intersectPolylineBounds(rotatedCorners, bounds).length > 0
hitTest(shape, point: number[]) {
return Utils.pointInEllipse(
point,
this.getCenter(shape),
shape.radius[0],
shape.radius[1],
shape.rotation
)
}
},
getBindingPoint(
shape: EllipseShape,
fromShape: ArrowShape,
point: number[],
origin: number[],
direction: number[],
padding: number,
anywhere: boolean
) {
hitTestBounds(shape, bounds) {
return (
intersectBoundsEllipse(
bounds,
this.getCenter(shape),
shape.radius[0],
shape.radius[1],
shape.rotation
).length > 0
)
},
getBindingPoint(shape, fromShape, point, origin, direction, padding, anywhere) {
{
const bounds = this.getBounds(shape)
@ -216,14 +208,6 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
distance = 0
} else {
// Find furthest intersection between ray from
// origin through point and expanded bounds.
// const intersection = Intersect.ray
// .bounds(origin, direction, expandedBounds)
// .filter((int) => int.didIntersect)
// .map((int) => int.points[0])
// .sort((a, b) => Vec.dist(b, origin) - Vec.dist(a, origin))[0]
let intersection = intersectRayEllipse(
origin,
direction,
@ -287,10 +271,10 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
distance,
}
}
}
},
transform(
_shape: EllipseShape,
_shape,
bounds: TLBounds,
{ scaleX, scaleY, initialShape }: TLTransformInfo<EllipseShape>
) {
@ -304,15 +288,19 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape, SVGSVGElement> {
? -(rotation || 0)
: rotation || 0,
}
}
},
transformSingle(shape: EllipseShape, bounds: TLBounds) {
transformSingle(shape, bounds: TLBounds) {
return {
point: Vec.round([bounds.minX, bounds.minY]),
radius: Vec.div([bounds.width, bounds.height], 2),
}
}
}
},
}))
/* -------------------------------------------------- */
/* Helpers */
/* -------------------------------------------------- */
function renderPath(shape: EllipseShape, boundsCenter: number[]) {
const {

View file

@ -0,0 +1,27 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP
exports[`Group shape Creates a shape: group 1`] = `
Object {
"childIndex": 1,
"children": Array [],
"id": "group",
"name": "Group",
"parentId": "page",
"point": Array [
0,
0,
],
"rotation": 0,
"size": Array [
100,
100,
],
"style": Object {
"color": "Black",
"dash": "Draw",
"isFilled": false,
"size": "Medium",
},
"type": "group",
}
`;

View file

@ -1,7 +1,7 @@
import { Group } from './group'
describe('Group shape', () => {
it('Creates an instance', () => {
new Group()
it('Creates a shape', () => {
expect(Group.create({ id: 'group' })).toMatchSnapshot('group')
})
})

View file

@ -1,31 +1,26 @@
import * as React from 'react'
import { SVGContainer, TLBounds, Utils, TLShapeProps } from '@tldraw/core'
import { Vec } from '@tldraw/vec'
import { intersectRayBounds, intersectPolylineBounds } from '@tldraw/intersect'
import { SVGContainer, Utils, ShapeUtil } from '@tldraw/core'
import { defaultStyle, getPerfectDashProps } from '~shape/shape-styles'
import {
GroupShape,
TLDrawShapeUtil,
TLDrawShapeType,
TLDrawToolType,
ColorStyle,
DashStyle,
ArrowShape,
TLDrawMeta,
} from '~types'
import { getBoundsRectangle } from '../shared'
// TODO
// [ ] - Find bounds based on common bounds of descendants
export const Group = new ShapeUtil<GroupShape, SVGSVGElement, TLDrawMeta>(() => ({
type: TLDrawShapeType.Group,
export class Group extends TLDrawShapeUtil<GroupShape, SVGSVGElement> {
type = TLDrawShapeType.Group as const
toolType = TLDrawToolType.Bounds
canBind = true
toolType: TLDrawToolType.Bounds,
pathCache = new WeakMap<number[], string>([])
canBind: true,
defaultProps: GroupShape = {
defaultProps: {
id: 'id',
type: TLDrawShapeType.Group as const,
type: TLDrawShapeType.Group,
name: 'Group',
parentId: 'page',
childIndex: 1,
@ -34,76 +29,63 @@ export class Group extends TLDrawShapeUtil<GroupShape, SVGSVGElement> {
rotation: 0,
children: [],
style: defaultStyle,
}
},
shouldRender(prev: GroupShape, next: GroupShape) {
return next.size !== prev.size || next.style !== prev.style
}
Component({ shape, isBinding, isHovered, isSelected, events }, ref) {
const { id, size } = shape
render = React.forwardRef<SVGSVGElement, TLShapeProps<GroupShape, SVGSVGElement>>(
({ shape, isBinding, isHovered, isSelected, events }, ref) => {
const { id, size } = shape
const sw = 2
const w = Math.max(0, size[0] - sw / 2)
const h = Math.max(0, size[1] - sw / 2)
const sw = 2
const w = Math.max(0, size[0] - sw / 2)
const h = Math.max(0, size[1] - sw / 2)
const strokes: [number[], number[], number][] = [
[[sw / 2, sw / 2], [w, sw / 2], w - sw / 2],
[[w, sw / 2], [w, h], h - sw / 2],
[[w, h], [sw / 2, h], w - sw / 2],
[[sw / 2, h], [sw / 2, sw / 2], h - sw / 2],
]
const strokes: [number[], number[], number][] = [
[[sw / 2, sw / 2], [w, sw / 2], w - sw / 2],
[[w, sw / 2], [w, h], h - sw / 2],
[[w, h], [sw / 2, h], w - sw / 2],
[[sw / 2, h], [sw / 2, sw / 2], h - sw / 2],
]
const paths = strokes.map(([start, end, length], i) => {
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
length,
sw,
DashStyle.Dotted
)
return (
<line
key={id + '_' + i}
x1={start[0]}
y1={start[1]}
x2={end[0]}
y2={end[1]}
stroke={ColorStyle.Black}
strokeWidth={isHovered || isSelected ? sw : 0}
strokeLinecap="round"
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
/>
)
})
const paths = strokes.map(([start, end, length], i) => {
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
length,
sw,
DashStyle.Dotted
)
return (
<SVGContainer ref={ref} {...events}>
{isBinding && (
<rect
className="tl-binding-indicator"
x={-32}
y={-32}
width={size[0] + 64}
height={size[1] + 64}
/>
)}
<rect
x={0}
y={0}
width={size[0]}
height={size[1]}
fill="transparent"
pointerEvents="all"
/>
<g pointerEvents="stroke">{paths}</g>
</SVGContainer>
<line
key={id + '_' + i}
x1={start[0]}
y1={start[1]}
x2={end[0]}
y2={end[1]}
stroke={ColorStyle.Black}
strokeWidth={isHovered || isSelected ? sw : 0}
strokeLinecap="round"
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
/>
)
}
)
})
renderIndicator(shape: GroupShape) {
return (
<SVGContainer ref={ref} {...events}>
{isBinding && (
<rect
className="tl-binding-indicator"
x={-32}
y={-32}
width={size[0] + 64}
height={size[1] + 64}
/>
)}
<rect x={0} y={0} width={size[0]} height={size[1]} fill="transparent" pointerEvents="all" />
<g pointerEvents="stroke">{paths}</g>
</SVGContainer>
)
},
Indicator({ shape }) {
const [width, height] = shape.size
const sw = 2
@ -118,126 +100,13 @@ export class Group extends TLDrawShapeUtil<GroupShape, SVGSVGElement> {
height={Math.max(1, height - sw)}
/>
)
}
},
getBounds(shape: GroupShape) {
const bounds = Utils.getFromCache(this.boundsCache, shape, () => {
const [width, height] = shape.size
return {
minX: 0,
maxX: width,
minY: 0,
maxY: height,
width,
height,
}
})
shouldRender(prev, next) {
return next.size !== prev.size || next.style !== prev.style
},
return Utils.translateBounds(bounds, shape.point)
}
getRotatedBounds(shape: GroupShape) {
return Utils.getBoundsFromPoints(Utils.getRotatedCorners(this.getBounds(shape), shape.rotation))
}
getCenter(shape: GroupShape): number[] {
return Utils.getBoundsCenter(this.getBounds(shape))
}
getBindingPoint(
shape: GroupShape,
fromShape: ArrowShape,
point: number[],
origin: number[],
direction: number[],
padding: number,
anywhere: boolean
) {
const bounds = this.getBounds(shape)
const expandedBounds = Utils.expandBounds(bounds, padding)
let bindingPoint: number[]
let distance: number
// The point must be inside of the expanded bounding box
if (!Utils.pointInBounds(point, expandedBounds)) return
// The point is inside of the shape, so we'll assume the user is
// indicating a specific point inside of the shape.
if (anywhere) {
if (Vec.dist(point, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
bindingPoint = Vec.divV(Vec.sub(point, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
distance = 0
} else {
// Find furthest intersection between ray from
// origin through point and expanded bounds.
// TODO: Make this a ray vs rounded rect intersection
const intersection = intersectRayBounds(origin, direction, expandedBounds)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(b, origin) - Vec.dist(a, origin))[0]
// The anchor is a point between the handle and the intersection
const anchor = Vec.med(point, intersection)
// If we're close to the center, snap to the center
if (Vec.distanceToLineSegment(point, anchor, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
// Or else calculate a normalized point
bindingPoint = Vec.divV(Vec.sub(anchor, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
if (Utils.pointInBounds(point, bounds)) {
distance = 16
} else {
// If the binding point was close to the shape's center, snap to the center
// Find the distance between the point and the real bounds of the shape
distance = Math.max(
16,
Utils.getBoundsSides(bounds)
.map((side) => Vec.distanceToLineSegment(side[1][0], side[1][1], point))
.sort((a, b) => a - b)[0]
)
}
}
return {
point: Vec.clampV(bindingPoint, 0, 1),
distance,
}
}
hitTest(shape: GroupShape, point: number[]) {
return Utils.pointInBounds(point, this.getBounds(shape))
}
hitTestBounds(shape: GroupShape, bounds: TLBounds) {
const rotatedCorners = Utils.getRotatedCorners(this.getBounds(shape), shape.rotation)
return (
rotatedCorners.every((point) => Utils.pointInBounds(point, bounds)) ||
intersectPolylineBounds(rotatedCorners, bounds).length > 0
)
}
transform() {
return {}
}
transformSingle() {
return {}
}
}
getBounds(shape) {
return getBoundsRectangle(shape, this.boundsCache)
},
}))

View file

@ -1,7 +1,8 @@
import { PostIt } from './post-it'
describe('Post-It shape', () => {
it('Creates an instance', () => {
new PostIt()
it('Creates a shape', () => {
expect(PostIt.create).toBeDefined()
// expect(PostIt.create({ id: 'postit' })).toMatchSnapshot('postit')
})
})

View file

@ -1,22 +1,21 @@
import * as React from 'react'
import { TLBounds, Utils, TLTransformInfo, TLShapeProps, HTMLContainer } from '@tldraw/core'
import { Vec } from '@tldraw/vec'
import { HTMLContainer, ShapeUtil } from '@tldraw/core'
import { defaultStyle, getShapeStyle } from '~shape/shape-styles'
import { PostItShape, TLDrawShapeUtil, TLDrawShapeType, TLDrawToolType, ArrowShape } from '~types'
import { intersectPolylineBounds, intersectRayBounds } from '@tldraw/intersect'
import { PostItShape, TLDrawMeta, TLDrawShapeType, TLDrawToolType } from '~types'
import { getBoundsRectangle, transformRectangle, transformSingleRectangle } from '../shared'
// TODO
// [ ] - Make sure that fill does not extend drawn shape at corners
export const PostIt = new ShapeUtil<PostItShape, HTMLDivElement, TLDrawMeta>(() => ({
type: TLDrawShapeType.PostIt,
export class PostIt extends TLDrawShapeUtil<PostItShape, HTMLDivElement> {
type = TLDrawShapeType.PostIt as const
toolType = TLDrawToolType.Bounds
canBind = true
pathCache = new WeakMap<number[], string>([])
toolType: TLDrawToolType.Bounds,
defaultProps: PostItShape = {
canBind: true,
pathCache: new WeakMap<number[], string>([]),
defaultProps: {
id: 'id',
type: TLDrawShapeType.PostIt as const,
type: TLDrawShapeType.PostIt,
name: 'PostIt',
parentId: 'page',
childIndex: 1,
@ -25,43 +24,41 @@ export class PostIt extends TLDrawShapeUtil<PostItShape, HTMLDivElement> {
text: '',
rotation: 0,
style: defaultStyle,
}
},
shouldRender(prev: PostItShape, next: PostItShape) {
shouldRender(prev, next) {
return next.size !== prev.size || next.style !== prev.style
}
},
render = React.forwardRef<HTMLDivElement, TLShapeProps<PostItShape, HTMLDivElement>>(
({ shape, isBinding, meta, events }, ref) => {
const [count, setCount] = React.useState(0)
Component({ events }, ref) {
const [count, setCount] = React.useState(0)
return (
<HTMLContainer ref={ref} {...events}>
<div
style={{
pointerEvents: 'all',
backgroundColor: 'rgba(255, 220, 100)',
border: '1px solid black',
fontFamily: 'sans-serif',
height: '100%',
width: '100%',
}}
>
<div onPointerDown={(e) => e.preventDefault()}>
<input
type="textarea"
style={{ width: '100%', height: '50%', background: 'none' }}
onPointerDown={(e) => e.stopPropagation()}
/>
<button onPointerDown={() => setCount((count) => count + 1)}>{count}</button>
</div>
return (
<HTMLContainer ref={ref} {...events}>
<div
style={{
pointerEvents: 'all',
backgroundColor: 'rgba(255, 220, 100)',
border: '1px solid black',
fontFamily: 'sans-serif',
height: '100%',
width: '100%',
}}
>
<div onPointerDown={(e) => e.preventDefault()}>
<input
type="textarea"
style={{ width: '100%', height: '50%', background: 'none' }}
onPointerDown={(e) => e.stopPropagation()}
/>
<button onPointerDown={() => setCount((count) => count + 1)}>{count}</button>
</div>
</HTMLContainer>
)
}
)
</div>
</HTMLContainer>
)
},
renderIndicator(shape: PostItShape) {
Indicator({ shape }) {
const {
style,
size: [width, height],
@ -82,162 +79,13 @@ export class PostIt extends TLDrawShapeUtil<PostItShape, HTMLDivElement> {
height={Math.max(1, height - sw)}
/>
)
}
},
getBounds(shape: PostItShape) {
const bounds = Utils.getFromCache(this.boundsCache, shape, () => {
const [width, height] = shape.size
return {
minX: 0,
maxX: width,
minY: 0,
maxY: height,
width,
height,
}
})
getBounds(shape) {
return getBoundsRectangle(shape, this.boundsCache)
},
return Utils.translateBounds(bounds, shape.point)
}
transform: transformRectangle,
getRotatedBounds(shape: PostItShape) {
return Utils.getBoundsFromPoints(Utils.getRotatedCorners(this.getBounds(shape), shape.rotation))
}
getCenter(shape: PostItShape): number[] {
return Utils.getBoundsCenter(this.getBounds(shape))
}
getBindingPoint(
shape: PostItShape,
fromShape: ArrowShape,
point: number[],
origin: number[],
direction: number[],
padding: number,
anywhere: boolean
) {
const bounds = this.getBounds(shape)
const expandedBounds = Utils.expandBounds(bounds, padding)
let bindingPoint: number[]
let distance: number
// The point must be inside of the expanded bounding box
if (!Utils.pointInBounds(point, expandedBounds)) return
// The point is inside of the shape, so we'll assume the user is
// indicating a specific point inside of the shape.
if (anywhere) {
if (Vec.dist(point, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
bindingPoint = Vec.divV(Vec.sub(point, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
distance = 0
} else {
// TODO: What if the shape has a curve? In that case, should we
// intersect the circle-from-three-points instead?
// Find furthest intersection between ray from
// origin through point and expanded bounds.
// TODO: Make this a ray vs rounded rect intersection
const intersection = intersectRayBounds(origin, direction, expandedBounds)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(b, origin) - Vec.dist(a, origin))[0]
// The anchor is a point between the handle and the intersection
const anchor = Vec.med(point, intersection)
// If we're close to the center, snap to the center
if (Vec.distanceToLineSegment(point, anchor, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
// Or else calculate a normalized point
bindingPoint = Vec.divV(Vec.sub(anchor, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
if (Utils.pointInBounds(point, bounds)) {
distance = 16
} else {
// If the binding point was close to the shape's center, snap to the center
// Find the distance between the point and the real bounds of the shape
distance = Math.max(
16,
Utils.getBoundsSides(bounds)
.map((side) => Vec.distanceToLineSegment(side[1][0], side[1][1], point))
.sort((a, b) => a - b)[0]
)
}
}
return {
point: Vec.clampV(bindingPoint, 0, 1),
distance,
}
}
hitTestBounds(shape: PostItShape, bounds: TLBounds) {
const rotatedCorners = Utils.getRotatedCorners(this.getBounds(shape), shape.rotation)
return (
rotatedCorners.every((point) => Utils.pointInBounds(point, bounds)) ||
intersectPolylineBounds(rotatedCorners, bounds).length > 0
)
}
transform(
shape: PostItShape,
bounds: TLBounds,
{ initialShape, transformOrigin, scaleX, scaleY }: TLTransformInfo<PostItShape>
) {
if (!shape.rotation && !shape.isAspectRatioLocked) {
return {
point: Vec.round([bounds.minX, bounds.minY]),
size: Vec.round([bounds.width, bounds.height]),
}
} else {
const size = Vec.round(
Vec.mul(initialShape.size, Math.min(Math.abs(scaleX), Math.abs(scaleY)))
)
const point = Vec.round([
bounds.minX +
(bounds.width - shape.size[0]) *
(scaleX < 0 ? 1 - transformOrigin[0] : transformOrigin[0]),
bounds.minY +
(bounds.height - shape.size[1]) *
(scaleY < 0 ? 1 - transformOrigin[1] : transformOrigin[1]),
])
const rotation =
(scaleX < 0 && scaleY >= 0) || (scaleY < 0 && scaleX >= 0)
? initialShape.rotation
? -initialShape.rotation
: 0
: initialShape.rotation
return {
size,
point,
rotation,
}
}
}
transformSingle(_shape: PostItShape, bounds: TLBounds) {
return {
size: Vec.round([bounds.width, bounds.height]),
point: Vec.round([bounds.minX, bounds.minY]),
}
}
}
transformSingle: transformSingleRectangle,
}))

View file

@ -0,0 +1,26 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP
exports[`Rectangle shape Creates a shape: rectangle 1`] = `
Object {
"childIndex": 1,
"id": "rectangle",
"name": "Rectangle",
"parentId": "page",
"point": Array [
0,
0,
],
"rotation": 0,
"size": Array [
1,
1,
],
"style": Object {
"color": "Black",
"dash": "Draw",
"isFilled": false,
"size": "Medium",
},
"type": "rectangle",
}
`;

View file

@ -1,7 +1,7 @@
import { Rectangle } from './rectangle'
describe('Rectangle shape', () => {
it('Creates an instance', () => {
new Rectangle()
it('Creates a shape', () => {
expect(Rectangle.create({ id: 'rectangle' })).toMatchSnapshot('rectangle')
})
})

View file

@ -1,30 +1,23 @@
import * as React from 'react'
import { TLBounds, Utils, TLTransformInfo, TLShapeProps, SVGContainer } from '@tldraw/core'
import { intersectRayBounds } from '@tldraw/intersect'
import { Utils, SVGContainer, ShapeUtil } from '@tldraw/core'
import { Vec } from '@tldraw/vec'
import getStroke from 'perfect-freehand'
import { getPerfectDashProps, defaultStyle, getShapeStyle } from '~shape/shape-styles'
import {
RectangleShape,
DashStyle,
TLDrawShapeUtil,
TLDrawShapeType,
TLDrawToolType,
ArrowShape,
} from '~types'
import { RectangleShape, DashStyle, TLDrawShapeType, TLDrawToolType, TLDrawMeta } from '~types'
import { getBoundsRectangle, transformRectangle, transformSingleRectangle } from '../shared'
// TODO
// [ ] - Make sure that fill does not extend drawn shape at corners
const pathCache = new WeakMap<number[], string>([])
export class Rectangle extends TLDrawShapeUtil<RectangleShape, SVGSVGElement> {
type = TLDrawShapeType.Rectangle as const
toolType = TLDrawToolType.Bounds
canBind = true
pathCache = new WeakMap<number[], string>([])
export const Rectangle = new ShapeUtil<RectangleShape, SVGSVGElement, TLDrawMeta>(() => ({
type: TLDrawShapeType.Rectangle,
defaultProps: RectangleShape = {
toolType: TLDrawToolType.Bounds,
canBind: true,
defaultProps: {
id: 'id',
type: TLDrawShapeType.Rectangle as const,
type: TLDrawShapeType.Rectangle,
name: 'Rectangle',
parentId: 'page',
childIndex: 1,
@ -32,115 +25,116 @@ export class Rectangle extends TLDrawShapeUtil<RectangleShape, SVGSVGElement> {
size: [1, 1],
rotation: 0,
style: defaultStyle,
}
},
shouldRender(prev: RectangleShape, next: RectangleShape) {
shouldRender(prev, next) {
return next.size !== prev.size || next.style !== prev.style
}
},
render = React.forwardRef<SVGSVGElement, TLShapeProps<RectangleShape, SVGSVGElement>>(
({ shape, isBinding, meta, events }, ref) => {
const { id, size, style } = shape
const styles = getShapeStyle(style, meta.isDarkMode)
const strokeWidth = +styles.strokeWidth
Component({ shape, isBinding, meta, events }, ref) {
const { id, size, style } = shape
const styles = getShapeStyle(style, meta.isDarkMode)
const strokeWidth = +styles.strokeWidth
if (style.dash === DashStyle.Draw) {
const pathData = Utils.getFromCache(this.pathCache, shape.size, () => renderPath(shape))
this
return (
<SVGContainer ref={ref} {...events}>
{isBinding && (
<rect
className="tl-binding-indicator"
x={strokeWidth / 2 - 32}
y={strokeWidth / 2 - 32}
width={Math.max(0, size[0] - strokeWidth / 2) + 64}
height={Math.max(0, size[1] - strokeWidth / 2) + 64}
/>
)}
<rect
x={+styles.strokeWidth / 2}
y={+styles.strokeWidth / 2}
width={Math.max(0, size[0] - strokeWidth)}
height={Math.max(0, size[1] - strokeWidth)}
fill={style.isFilled ? styles.fill : 'none'}
stroke="none"
pointerEvents="all"
/>
<path
d={pathData}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={styles.strokeWidth}
pointerEvents="all"
/>
</SVGContainer>
)
}
const sw = strokeWidth * 1.618
const w = Math.max(0, size[0] - sw / 2)
const h = Math.max(0, size[1] - sw / 2)
const strokes: [number[], number[], number][] = [
[[sw / 2, sw / 2], [w, sw / 2], w - sw / 2],
[[w, sw / 2], [w, h], h - sw / 2],
[[w, h], [sw / 2, h], w - sw / 2],
[[sw / 2, h], [sw / 2, sw / 2], h - sw / 2],
]
const paths = strokes.map(([start, end, length], i) => {
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
length,
sw,
shape.style.dash
)
return (
<line
key={id + '_' + i}
x1={start[0]}
y1={start[1]}
x2={end[0]}
y2={end[1]}
stroke={styles.stroke}
strokeWidth={sw}
strokeLinecap="round"
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
/>
)
})
if (style.dash === DashStyle.Draw) {
const pathData = Utils.getFromCache(pathCache, shape.size, () => renderPath(shape))
return (
<SVGContainer ref={ref} {...events}>
{isBinding && (
<rect
className="tl-binding-indicator"
x={sw / 2 - 32}
y={sw / 2 - 32}
width={w + 64}
height={h + 64}
x={strokeWidth / 2 - 32}
y={strokeWidth / 2 - 32}
width={Math.max(0, size[0] - strokeWidth / 2) + 64}
height={Math.max(0, size[1] - strokeWidth / 2) + 64}
/>
)}
<rect
x={sw / 2}
y={sw / 2}
width={w}
height={h}
fill={styles.fill}
stroke="transparent"
strokeWidth={sw}
x={+styles.strokeWidth / 2}
y={+styles.strokeWidth / 2}
width={Math.max(0, size[0] - strokeWidth)}
height={Math.max(0, size[1] - strokeWidth)}
fill={style.isFilled ? styles.fill : 'none'}
radius={strokeWidth}
stroke="none"
pointerEvents="all"
/>
<path
d={pathData}
fill={styles.stroke}
stroke={styles.stroke}
strokeWidth={styles.strokeWidth}
pointerEvents="all"
/>
<g pointerEvents="stroke">{paths}</g>
</SVGContainer>
)
}
)
renderIndicator(shape: RectangleShape) {
const sw = strokeWidth * 1.618
const w = Math.max(0, size[0] - sw / 2)
const h = Math.max(0, size[1] - sw / 2)
const strokes: [number[], number[], number][] = [
[[sw / 2, sw / 2], [w, sw / 2], w - sw / 2],
[[w, sw / 2], [w, h], h - sw / 2],
[[w, h], [sw / 2, h], w - sw / 2],
[[sw / 2, h], [sw / 2, sw / 2], h - sw / 2],
]
const paths = strokes.map(([start, end, length], i) => {
const { strokeDasharray, strokeDashoffset } = getPerfectDashProps(
length,
sw,
shape.style.dash
)
return (
<line
key={id + '_' + i}
x1={start[0]}
y1={start[1]}
x2={end[0]}
y2={end[1]}
stroke={styles.stroke}
strokeWidth={sw}
strokeLinecap="round"
strokeDasharray={strokeDasharray}
strokeDashoffset={strokeDashoffset}
/>
)
})
return (
<SVGContainer ref={ref} {...events}>
{isBinding && (
<rect
className="tl-binding-indicator"
x={sw / 2 - 32}
y={sw / 2 - 32}
width={w + 64}
height={h + 64}
/>
)}
<rect
x={sw / 2}
y={sw / 2}
width={w}
height={h}
fill={styles.fill}
stroke="transparent"
strokeWidth={sw}
pointerEvents="all"
/>
<g pointerEvents="stroke">{paths}</g>
</SVGContainer>
)
},
Indicator({ shape }) {
const {
style,
size: [width, height],
@ -161,156 +155,20 @@ export class Rectangle extends TLDrawShapeUtil<RectangleShape, SVGSVGElement> {
height={Math.max(1, height - sw)}
/>
)
}
},
getBounds(shape: RectangleShape) {
const bounds = Utils.getFromCache(this.boundsCache, shape, () => {
const [width, height] = shape.size
return {
minX: 0,
maxX: width,
minY: 0,
maxY: height,
width,
height,
}
})
getBounds(shape) {
return getBoundsRectangle(shape, this.boundsCache)
},
return Utils.translateBounds(bounds, shape.point)
}
transform: transformRectangle,
getRotatedBounds(shape: RectangleShape) {
return Utils.getBoundsFromPoints(Utils.getRotatedCorners(this.getBounds(shape), shape.rotation))
}
transformSingle: transformSingleRectangle,
}))
getCenter(shape: RectangleShape): number[] {
return Utils.getBoundsCenter(this.getBounds(shape))
}
getBindingPoint(
shape: RectangleShape,
fromShape: ArrowShape,
point: number[],
origin: number[],
direction: number[],
padding: number,
anywhere: boolean
) {
const bounds = this.getBounds(shape)
const expandedBounds = Utils.expandBounds(bounds, padding)
let bindingPoint: number[]
let distance: number
// The point must be inside of the expanded bounding box
if (!Utils.pointInBounds(point, expandedBounds)) return
// The point is inside of the shape, so we'll assume the user is
// indicating a specific point inside of the shape.
if (anywhere) {
if (Vec.dist(point, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
bindingPoint = Vec.divV(Vec.sub(point, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
distance = 0
} else {
// TODO: What if the shape has a curve? In that case, should we
// intersect the circle-from-three-points instead?
// Find furthest intersection between ray from
// origin through point and expanded bounds.
// TODO: Make this a ray vs rounded rect intersection
const intersection = intersectRayBounds(origin, direction, expandedBounds)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(b, origin) - Vec.dist(a, origin))[0]
// The anchor is a point between the handle and the intersection
const anchor = Vec.med(point, intersection)
// If we're close to the center, snap to the center
if (Vec.distanceToLineSegment(point, anchor, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
// Or else calculate a normalized point
bindingPoint = Vec.divV(Vec.sub(anchor, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
if (Utils.pointInBounds(point, bounds)) {
distance = 16
} else {
// If the binding point was close to the shape's center, snap to the center
// Find the distance between the point and the real bounds of the shape
distance = Math.max(
16,
Utils.getBoundsSides(bounds)
.map((side) => Vec.distanceToLineSegment(side[1][0], side[1][1], point))
.sort((a, b) => a - b)[0]
)
}
}
return {
point: Vec.clampV(bindingPoint, 0, 1),
distance,
}
}
transform(
shape: RectangleShape,
bounds: TLBounds,
{ initialShape, transformOrigin, scaleX, scaleY }: TLTransformInfo<RectangleShape>
) {
if (shape.rotation || shape.isAspectRatioLocked) {
const size = Vec.round(
Vec.mul(initialShape.size, Math.min(Math.abs(scaleX), Math.abs(scaleY)))
)
const point = Vec.round([
bounds.minX +
(bounds.width - shape.size[0]) *
(scaleX < 0 ? 1 - transformOrigin[0] : transformOrigin[0]),
bounds.minY +
(bounds.height - shape.size[1]) *
(scaleY < 0 ? 1 - transformOrigin[1] : transformOrigin[1]),
])
const rotation =
(scaleX < 0 && scaleY >= 0) || (scaleY < 0 && scaleX >= 0)
? initialShape.rotation
? -initialShape.rotation
: 0
: initialShape.rotation
return {
size,
point,
rotation,
}
} else {
return {
point: Vec.round([bounds.minX, bounds.minY]),
size: Vec.round([bounds.width, bounds.height]),
}
}
}
transformSingle(_shape: RectangleShape, bounds: TLBounds) {
return {
size: Vec.round([bounds.width, bounds.height]),
point: Vec.round([bounds.minX, bounds.minY]),
}
}
}
/* -------------------------------------------------- */
/* Helpers */
/* -------------------------------------------------- */
function renderPath(shape: RectangleShape) {
const styles = getShapeStyle(shape.style)

View file

@ -0,0 +1,83 @@
import { Vec } from '@tldraw/vec'
import { TLBounds, TLShape, TLTransformInfo, Utils } from '@tldraw/core'
/**
* Transform a rectangular shape.
* @param shape
* @param bounds
* @param param2
*/
export function transformRectangle<T extends TLShape & { size: number[] }>(
shape: T,
bounds: TLBounds,
{ initialShape, transformOrigin, scaleX, scaleY }: TLTransformInfo<T>
) {
if (shape.rotation || initialShape.isAspectRatioLocked) {
const size = Vec.round(Vec.mul(initialShape.size, Math.min(Math.abs(scaleX), Math.abs(scaleY))))
const point = Vec.round([
bounds.minX +
(bounds.width - shape.size[0]) * (scaleX < 0 ? 1 - transformOrigin[0] : transformOrigin[0]),
bounds.minY +
(bounds.height - shape.size[1]) *
(scaleY < 0 ? 1 - transformOrigin[1] : transformOrigin[1]),
])
const rotation =
(scaleX < 0 && scaleY >= 0) || (scaleY < 0 && scaleX >= 0)
? initialShape.rotation
? -initialShape.rotation
: 0
: initialShape.rotation
return {
size,
point,
rotation,
}
} else {
return {
point: Vec.round([bounds.minX, bounds.minY]),
size: Vec.round([bounds.width, bounds.height]),
}
}
}
/**
* Transform a single rectangular shape.
* @param shape
* @param bounds
*/
export function transformSingleRectangle<T extends TLShape & { size: number[] }>(
shape: T,
bounds: TLBounds
) {
return {
size: Vec.round([bounds.width, bounds.height]),
point: Vec.round([bounds.minX, bounds.minY]),
}
}
/**
* Find the bounds of a rectangular shape.
* @param shape
* @param boundsCache
*/
export function getBoundsRectangle<T extends TLShape & { size: number[] }>(
shape: T,
boundsCache: WeakMap<T, TLBounds>
) {
const bounds = Utils.getFromCache(boundsCache, shape, () => {
const [width, height] = shape.size
return {
minX: 0,
maxX: width,
minY: 0,
maxY: height,
width,
height,
}
})
return Utils.translateBounds(bounds, shape.point)
}

View file

@ -0,0 +1,23 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP
exports[`Text shape Creates a shape: text 1`] = `
Object {
"childIndex": 1,
"id": "text",
"name": "Text",
"parentId": "page",
"point": Array [
-0.5,
-0.5,
],
"rotation": 0,
"style": Object {
"color": "Black",
"dash": "Draw",
"isFilled": false,
"size": "Medium",
},
"text": " ",
"type": "text",
}
`;

View file

@ -1,7 +1,7 @@
import { Text } from './text'
describe('Text shape', () => {
it('Creates an instance', () => {
new Text()
it('Creates a shape', () => {
expect(Text.create({ id: 'text' })).toMatchSnapshot('text')
})
})

View file

@ -1,19 +1,11 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import * as React from 'react'
import { HTMLContainer, TLBounds, Utils, TLTransformInfo } from '@tldraw/core'
import { HTMLContainer, TLBounds, Utils, TLTransformInfo, ShapeUtil } from '@tldraw/core'
import { Vec } from '@tldraw/vec'
import { getShapeStyle, getFontStyle, defaultStyle } from '~shape/shape-styles'
import {
TextShape,
TLDrawShapeUtil,
TLDrawShapeType,
TLDrawToolType,
ArrowShape,
TLDrawShapeProps,
} from '~types'
import { TextShape, TLDrawShapeType, TLDrawToolType, TLDrawMeta } from '~types'
import styled from '~styles'
import TextAreaUtils from './text-utils'
import { intersectPolylineBounds, intersectRayBounds } from '@tldraw/intersect'
const LETTER_SPACING = -1.5
@ -59,18 +51,20 @@ if (typeof window !== 'undefined') {
melm = getMeasurementDiv()
}
export class Text extends TLDrawShapeUtil<TextShape, HTMLDivElement> {
type = TLDrawShapeType.Text as const
toolType = TLDrawToolType.Text
isAspectRatioLocked = true
isEditableText = true
canBind = true
export const Text = new ShapeUtil<TextShape, HTMLDivElement, TLDrawMeta>(() => ({
type: TLDrawShapeType.Text,
pathCache = new WeakMap<number[], string>([])
toolType: TLDrawToolType.Text,
defaultProps = {
isAspectRatioLocked: true,
isEditableText: true,
canBind: true,
defaultProps: {
id: 'id',
type: TLDrawShapeType.Text as const,
type: TLDrawShapeType.Text,
name: 'Text',
parentId: 'page',
childIndex: 1,
@ -78,142 +72,135 @@ export class Text extends TLDrawShapeUtil<TextShape, HTMLDivElement> {
rotation: 0,
text: ' ',
style: defaultStyle,
}
},
create(props: Partial<TextShape>): TextShape {
create(props) {
const shape = { ...this.defaultProps, ...props }
const bounds = this.getBounds(shape)
shape.point = Vec.sub(shape.point, [bounds.width / 2, bounds.height / 2])
return shape
}
},
shouldRender(prev: TextShape, next: TextShape): boolean {
shouldRender(prev, next): boolean {
return (
next.text !== prev.text || next.style.scale !== prev.style.scale || next.style !== prev.style
)
}
},
render = React.forwardRef<HTMLDivElement, TLDrawShapeProps<TextShape, HTMLDivElement>>(
({ shape, meta, isEditing, isBinding, onShapeChange, onShapeBlur, events }, ref) => {
const rInput = React.useRef<HTMLTextAreaElement>(null)
const { text, style } = shape
const styles = getShapeStyle(style, meta.isDarkMode)
const font = getFontStyle(shape.style)
Component({ shape, meta, isEditing, isBinding, onShapeChange, onShapeBlur, events }, ref) {
const rInput = React.useRef<HTMLTextAreaElement>(null)
const { text, style } = shape
const styles = getShapeStyle(style, meta.isDarkMode)
const font = getFontStyle(shape.style)
const handleChange = React.useCallback(
(e: React.ChangeEvent<HTMLTextAreaElement>) => {
onShapeChange?.({ ...shape, text: normalizeText(e.currentTarget.value) })
},
[shape]
)
const handleKeyDown = React.useCallback(
(e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Escape') return
e.stopPropagation()
if (e.key === 'Tab') {
e.preventDefault()
if (e.shiftKey) {
TextAreaUtils.unindent(e.currentTarget)
} else {
TextAreaUtils.indent(e.currentTarget)
}
const handleChange = React.useCallback(
(e: React.ChangeEvent<HTMLTextAreaElement>) => {
onShapeChange?.({ ...shape, text: normalizeText(e.currentTarget.value) })
},
[shape]
)
const handleKeyDown = React.useCallback(
(e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Escape') return
e.stopPropagation()
if (e.key === 'Tab') {
e.preventDefault()
if (e.shiftKey) {
TextAreaUtils.unindent(e.currentTarget)
} else {
TextAreaUtils.indent(e.currentTarget)
}
onShapeChange?.({ ...shape, text: normalizeText(e.currentTarget.value) })
}
},
[shape, onShapeChange]
)
const handleBlur = React.useCallback(
(e: React.FocusEvent<HTMLTextAreaElement>) => {
e.currentTarget.setSelectionRange(0, 0)
onShapeBlur?.()
},
[isEditing, shape]
)
const handleFocus = React.useCallback(
(e: React.FocusEvent<HTMLTextAreaElement>) => {
if (!isEditing) return
if (document.activeElement === e.currentTarget) {
e.currentTarget.select()
}
},
[isEditing]
)
const handlePointerDown = React.useCallback(
(e) => {
if (isEditing) {
e.stopPropagation()
}
},
[isEditing]
)
React.useEffect(() => {
if (isEditing) {
setTimeout(() => {
const elm = rInput.current!
elm.focus()
elm.select()
}, 0)
} else {
const elm = rInput.current!
elm.setSelectionRange(0, 0)
}
}, [isEditing])
},
[shape, onShapeChange]
)
return (
<HTMLContainer ref={ref} {...events}>
<StyledWrapper isEditing={isEditing} onPointerDown={handlePointerDown}>
<StyledTextArea
ref={rInput}
style={{
font,
color: styles.stroke,
}}
name="text"
defaultValue={text}
tabIndex={-1}
autoComplete="false"
autoCapitalize="false"
autoCorrect="false"
autoSave="false"
placeholder=""
color={styles.stroke}
onFocus={handleFocus}
onBlur={handleBlur}
onChange={handleChange}
onKeyDown={handleKeyDown}
onPointerDown={handlePointerDown}
autoFocus={isEditing}
isEditing={isEditing}
isBinding={isBinding}
readOnly={!isEditing}
wrap="off"
dir="auto"
datatype="wysiwyg"
/>
</StyledWrapper>
</HTMLContainer>
)
}
)
const handleBlur = React.useCallback(
(e: React.FocusEvent<HTMLTextAreaElement>) => {
e.currentTarget.setSelectionRange(0, 0)
onShapeBlur?.()
},
[isEditing, shape]
)
renderIndicator(): JSX.Element | null {
const handleFocus = React.useCallback(
(e: React.FocusEvent<HTMLTextAreaElement>) => {
if (!isEditing) return
if (document.activeElement === e.currentTarget) {
e.currentTarget.select()
}
},
[isEditing]
)
const handlePointerDown = React.useCallback(
(e) => {
if (isEditing) {
e.stopPropagation()
}
},
[isEditing]
)
React.useEffect(() => {
if (isEditing) {
setTimeout(() => {
const elm = rInput.current!
elm.focus()
elm.select()
}, 0)
} else {
const elm = rInput.current!
elm.setSelectionRange(0, 0)
}
}, [isEditing])
return (
<HTMLContainer ref={ref} {...events}>
<StyledWrapper isEditing={isEditing} onPointerDown={handlePointerDown}>
<StyledTextArea
ref={rInput}
style={{
font,
color: styles.stroke,
}}
name="text"
defaultValue={text}
tabIndex={-1}
autoComplete="false"
autoCapitalize="false"
autoCorrect="false"
autoSave="false"
placeholder=""
color={styles.stroke}
onFocus={handleFocus}
onBlur={handleBlur}
onChange={handleChange}
onKeyDown={handleKeyDown}
onPointerDown={handlePointerDown}
autoFocus={isEditing}
isEditing={isEditing}
isBinding={isBinding}
readOnly={!isEditing}
wrap="off"
dir="auto"
datatype="wysiwyg"
/>
</StyledWrapper>
</HTMLContainer>
)
},
Indicator() {
return null
// if (isEditing) return null
},
// const { width, height } = this.getBounds(shape)
// return <rect className="tl-selected" width={width} height={height} />
}
getBounds(shape: TextShape): TLBounds {
getBounds(shape): TLBounds {
const bounds = Utils.getFromCache(this.boundsCache, shape, () => {
if (!melm) {
// We're in SSR
@ -238,34 +225,13 @@ export class Text extends TLDrawShapeUtil<TextShape, HTMLDivElement> {
})
return Utils.translateBounds(bounds, shape.point)
}
getRotatedBounds(shape: TextShape): TLBounds {
return Utils.getBoundsFromPoints(Utils.getRotatedCorners(this.getBounds(shape), shape.rotation))
}
getCenter(shape: TextShape): number[] {
return Utils.getBoundsCenter(this.getBounds(shape))
}
hitTest(shape: TextShape, point: number[]): boolean {
return Utils.pointInBounds(point, this.getBounds(shape))
}
hitTestBounds(shape: TextShape, bounds: TLBounds): boolean {
const rotatedCorners = Utils.getRotatedCorners(this.getBounds(shape), shape.rotation)
return (
rotatedCorners.every((point) => Utils.pointInBounds(point, bounds)) ||
intersectPolylineBounds(rotatedCorners, bounds).length > 0
)
}
},
transform(
_shape: TextShape,
_shape,
bounds: TLBounds,
{ initialShape, scaleX, scaleY }: TLTransformInfo<TextShape>
): Partial<TextShape> {
) {
const {
rotation = 0,
style: { scale = 1 },
@ -282,13 +248,13 @@ export class Text extends TLDrawShapeUtil<TextShape, HTMLDivElement> {
scale: nextScale,
},
}
}
},
transformSingle(
_shape: TextShape,
_shape,
bounds: TLBounds,
{ initialShape, scaleX, scaleY }: TLTransformInfo<TextShape>
): Partial<TextShape> {
) {
const {
style: { scale = 1 },
} = initialShape
@ -300,9 +266,9 @@ export class Text extends TLDrawShapeUtil<TextShape, HTMLDivElement> {
scale: scale * Math.max(Math.abs(scaleY), Math.abs(scaleX)),
},
}
}
},
onBoundsReset(shape: TextShape): Partial<TextShape> {
onDoubleClickBoundsHandle(shape) {
const center = this.getCenter(shape)
const newCenter = this.getCenter({
@ -320,9 +286,9 @@ export class Text extends TLDrawShapeUtil<TextShape, HTMLDivElement> {
},
point: Vec.round(Vec.add(shape.point, Vec.sub(center, newCenter))),
}
}
},
onStyleChange(shape: TextShape): Partial<TextShape> {
onStyleChange(shape) {
const center = this.getCenter(shape)
this.boundsCache.delete(shape)
@ -332,88 +298,12 @@ export class Text extends TLDrawShapeUtil<TextShape, HTMLDivElement> {
return {
point: Vec.round(Vec.add(shape.point, Vec.sub(center, newCenter))),
}
}
},
}))
shouldDelete(shape: TextShape): boolean {
return shape.text.trim().length === 0
}
getBindingPoint(
shape: TextShape,
fromShape: ArrowShape,
point: number[],
origin: number[],
direction: number[],
padding: number,
anywhere: boolean
) {
const bounds = this.getBounds(shape)
const expandedBounds = Utils.expandBounds(bounds, padding)
let bindingPoint: number[]
let distance: number
// The point must be inside of the expanded bounding box
if (!Utils.pointInBounds(point, expandedBounds)) return
// The point is inside of the shape, so we'll assume the user is
// indicating a specific point inside of the shape.
if (anywhere) {
if (Vec.dist(point, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
bindingPoint = Vec.divV(Vec.sub(point, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
distance = 0
} else {
// Find furthest intersection between ray from
// origin through point and expanded bounds.
// TODO: Make this a ray vs rounded rect intersection
const intersection = intersectRayBounds(origin, direction, expandedBounds)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(b, origin) - Vec.dist(a, origin))[0]
// The anchor is a point between the handle and the intersection
const anchor = Vec.med(point, intersection)
// If we're close to the center, snap to the center
if (Vec.distanceToLineSegment(point, anchor, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
// Or else calculate a normalized point
bindingPoint = Vec.divV(Vec.sub(anchor, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
if (Utils.pointInBounds(point, bounds)) {
distance = 16
} else {
// If the binding point was close to the shape's center, snap to the center
// Find the distance between the point and the real bounds of the shape
distance = Math.max(
16,
Utils.getBoundsSides(bounds)
.map((side) => Vec.distanceToLineSegment(side[1][0], side[1][1], point))
.sort((a, b) => a - b)[0]
)
}
}
return {
point: Vec.clampV(bindingPoint, 0, 1),
distance,
}
}
}
/* -------------------------------------------------- */
/* Helpers */
/* -------------------------------------------------- */
const StyledWrapper = styled('div', {
width: '100%',

View file

@ -1,7 +1,7 @@
import { TLDR } from '~state/tldr'
import { TLDrawState } from '~state'
import { mockDocument } from '~test'
import type { TLDrawShape } from '~types'
import { TLDrawShape, TLDrawShapeType } from '~types'
describe('Delete command', () => {
const tlstate = new TLDrawState()
@ -61,12 +61,10 @@ describe('Delete command', () => {
tlstate
.deselectAll()
.create(
TLDR.getShapeUtils({ type: 'arrow' } as TLDrawShape).create({
id: 'arrow1',
parentId: 'page1',
})
)
.createShapes({
id: 'arrow1',
type: TLDrawShapeType.Arrow,
})
.select('arrow1')
.startHandleSession([0, 0], 'start')
.updateHandleSession([110, 110])
@ -77,7 +75,7 @@ describe('Delete command', () => {
expect(binding).toBeTruthy()
expect(binding.fromId).toBe('arrow1')
expect(binding.toId).toBe('rect3')
expect(binding.handleId).toBe('start')
expect(binding.meta.handleId).toBe('start')
expect(tlstate.getShape('arrow1').handles?.start.bindingId).toBe(binding.id)
tlstate.select('rect3').delete()

View file

@ -76,7 +76,7 @@ export function group(
// Create the group
beforeShapes[groupId] = undefined
afterShapes[groupId] = TLDR.getShapeUtils({ type: TLDrawShapeType.Group } as TLDrawShape).create({
afterShapes[groupId] = TLDR.getShapeUtils(TLDrawShapeType.Group).create({
id: groupId,
childIndex: groupChildIndex,
parentId: groupParentId,

View file

@ -1,11 +1,12 @@
import type { Data, TLDrawCommand } from '~types'
import { TLDR } from '~state/tldr'
import { TLBoundsEdge } from '~../../core/src/types'
export function resetBounds(data: Data, ids: string[], pageId: string): TLDrawCommand {
const { before, after } = TLDR.mutateShapes(
data,
ids,
(shape) => TLDR.getShapeUtils(shape).onBoundsReset(shape),
(shape) => TLDR.getShapeUtils(shape).onDoubleClickBoundsHandle(shape),
pageId
)

View file

@ -1,7 +1,7 @@
import { TLDR } from '~state/tldr'
import { TLDrawState } from '~state'
import { mockDocument } from '~test'
import { ArrowShape, Decoration, TLDrawShape } from '~types'
import { ArrowShape, Decoration, TLDrawShape, TLDrawShapeType } from '~types'
describe('Toggle decoration command', () => {
const tlstate = new TLDrawState()
@ -32,12 +32,10 @@ describe('Toggle decoration command', () => {
it('does, undoes and redoes command', () => {
tlstate
.create(
TLDR.getShapeUtils({ type: 'arrow' } as TLDrawShape).create({
id: 'arrow1',
parentId: 'page1',
})
)
.createShapes({
id: 'arrow1',
type: TLDrawShapeType.Arrow,
})
.select('arrow1')
expect(tlstate.getShape<ArrowShape>('arrow1').decorations?.end).toBe(Decoration.Arrow)

View file

@ -28,7 +28,7 @@ describe('Arrow session', () => {
expect(binding).toBeTruthy()
expect(binding.fromId).toBe('arrow1')
expect(binding.toId).toBe('target1')
expect(binding.handleId).toBe('start')
expect(binding.meta.handleId).toBe('start')
expect(tlstate.appState.status.current).toBe(TLDrawStatus.Idle)
expect(tlstate.getShape('arrow1').handles?.start.bindingId).toBe(binding.id)
@ -62,7 +62,7 @@ describe('Arrow session', () => {
.select('arrow1')
.startHandleSession([200, 200], 'start')
.updateHandleSession([50, 50])
expect(tlstate.bindings[0].point).toStrictEqual([0.5, 0.5])
expect(tlstate.bindings[0].meta.point).toStrictEqual([0.5, 0.5])
})
it('Snaps to the center', () => {
@ -71,7 +71,7 @@ describe('Arrow session', () => {
.select('arrow1')
.startHandleSession([200, 200], 'start')
.updateHandleSession([55, 55])
expect(tlstate.bindings[0].point).toStrictEqual([0.5, 0.5])
expect(tlstate.bindings[0].meta.point).toStrictEqual([0.5, 0.5])
})
it('Binds at the bottom left', () => {
@ -80,7 +80,7 @@ describe('Arrow session', () => {
.select('arrow1')
.startHandleSession([200, 200], 'start')
.updateHandleSession([132, -32])
expect(tlstate.bindings[0].point).toStrictEqual([1, 0])
expect(tlstate.bindings[0].meta.point).toStrictEqual([1, 0])
})
it('Cancels the bind when off of the expanded bounds', () => {
@ -100,7 +100,7 @@ describe('Arrow session', () => {
.startHandleSession([200, 200], 'start')
.updateHandleSession([91, 9])
expect(tlstate.bindings[0].point).toStrictEqual([0.68, 0.13])
expect(tlstate.bindings[0].meta.point).toStrictEqual([0.68, 0.13])
tlstate.updateHandleSession([91, 9], false, false, true)
})
@ -112,11 +112,11 @@ describe('Arrow session', () => {
.startHandleSession([200, 200], 'start')
.updateHandleSession([91, 9])
expect(tlstate.bindings[0].point).toStrictEqual([0.68, 0.13])
expect(tlstate.bindings[0].meta.point).toStrictEqual([0.68, 0.13])
tlstate.updateHandleSession([91, 9], false, false, true)
expect(tlstate.bindings[0].point).toStrictEqual([0.75, 0.25])
expect(tlstate.bindings[0].meta.point).toStrictEqual([0.75, 0.25])
})
it('ignores binding when alt is held', () => {
@ -126,11 +126,11 @@ describe('Arrow session', () => {
.startHandleSession([200, 200], 'start')
.updateHandleSession([55, 45])
expect(tlstate.bindings[0].point).toStrictEqual([0.5, 0.5])
expect(tlstate.bindings[0].meta.point).toStrictEqual([0.5, 0.5])
tlstate.updateHandleSession([55, 45], false, false, true)
expect(tlstate.bindings[0].point).toStrictEqual([0.5, 0.5])
expect(tlstate.bindings[0].meta.point).toStrictEqual([0.5, 0.5])
})
})

View file

@ -66,7 +66,7 @@ export class ArrowSession implements Session {
}
// First update the handle's next point
const change = TLDR.getShapeUtils(shape).onHandleChange(
const change = TLDR.getShapeUtils<ArrowShape>(shape.type).onHandleChange(
shape,
{
[handleId]: handle,
@ -77,7 +77,7 @@ export class ArrowSession implements Session {
// If the handle changed produced no change, bail here
if (!change) return
// If we've made it this far, the shape should be a new objet reference
// If we've made it this far, the shape should be a new object reference
// that incorporates the changes we've made due to the handle movement.
let nextShape = { ...shape, ...change }
@ -124,7 +124,7 @@ export class ArrowSession implements Session {
target = TLDR.getShape(data, id, data.appState.currentPageId)
const util = TLDR.getShapeUtils(target)
const util = TLDR.getShapeUtils<TLDrawShape>(target.type)
const bindingPoint = util.getBindingPoint(
target,
@ -143,10 +143,12 @@ export class ArrowSession implements Session {
id: this.newBindingId,
type: 'arrow',
fromId: initialShape.id,
handleId: this.handleId,
toId: target.id,
point: Vec.round(bindingPoint.point),
distance: bindingPoint.distance,
meta: {
handleId: this.handleId,
point: Vec.round(bindingPoint.point),
distance: bindingPoint.distance,
},
}
break
@ -191,7 +193,7 @@ export class ArrowSession implements Session {
// Now update the arrow in response to the new binding
const targetUtils = TLDR.getShapeUtils(target)
const arrowChange = TLDR.getShapeUtils(nextShape).onBindingChange(
const arrowChange = TLDR.getShapeUtils<ArrowShape>(nextShape.type).onBindingChange(
nextShape,
binding,
target,
@ -300,9 +302,7 @@ export class ArrowSession implements Session {
[data.appState.currentPageId]: {
shapes: {
[initialShape.id]: TLDR.onSessionComplete(
data,
TLDR.getShape(data, initialShape.id, data.appState.currentPageId),
data.appState.currentPageId
TLDR.getShape(data, initialShape.id, data.appState.currentPageId)
),
},
bindings: afterBindings,

View file

@ -123,8 +123,8 @@ export function getBrushSnapshot(data: Data) {
)
.map((shape) => ({
id: shape.id,
util: getShapeUtils(shape),
bounds: getShapeUtils(shape).getBounds(shape),
util: TLDR.getShapeUtils(shape),
bounds: TLDR.getShapeUtils(shape).getBounds(shape),
selectId: TLDR.getTopParentId(data, shape.id, currentPageId),
}))

View file

@ -1,7 +1,7 @@
import { TLDrawState } from '~state'
import { mockDocument } from '~test'
import { TLDR } from '~state/tldr'
import { TLDrawShape, TLDrawStatus } from '~types'
import { TLDrawShape, TLDrawShapeType, TLDrawStatus } from '~types'
describe('Handle session', () => {
const tlstate = new TLDrawState()
@ -9,12 +9,10 @@ describe('Handle session', () => {
it('begins, updates and completes session', () => {
tlstate
.loadDocument(mockDocument)
.create(
TLDR.getShapeUtils({ type: 'arrow' } as TLDrawShape).create({
id: 'arrow1',
parentId: 'page1',
})
)
.createShapes({
id: 'arrow1',
type: TLDrawShapeType.Arrow,
})
.select('arrow1')
.startHandleSession([-10, -10], 'end')
.updateHandleSession([10, 10])
@ -28,10 +26,9 @@ describe('Handle session', () => {
it('cancels session', () => {
tlstate
.loadDocument(mockDocument)
.create({
...TLDR.getShapeUtils({ type: 'arrow' } as TLDrawShape).defaultProps,
.createShapes({
type: TLDrawShapeType.Arrow,
id: 'arrow1',
parentId: 'page1',
})
.select('arrow1')
.startHandleSession([-10, -10], 'end')

View file

@ -106,9 +106,7 @@ export class HandleSession implements Session {
[pageId]: {
shapes: {
[initialShape.id]: TLDR.onSessionComplete(
data,
TLDR.getShape(data, this.initialShape.id, pageId),
pageId
TLDR.getShape(data, this.initialShape.id, pageId)
),
},
},

View file

@ -9,12 +9,10 @@ describe('Text session', () => {
it('begins, updates and completes session', () => {
tlstate
.loadDocument(mockDocument)
.create(
TLDR.getShapeUtils({ type: TLDrawShapeType.Text } as TLDrawShape).create({
id: 'text1',
parentId: 'page1',
})
)
.createShapes({
id: 'text1',
type: TLDrawShapeType.Text,
})
.select('text1')
.startTextSession('text1')
.updateTextSession('Hello world')
@ -28,12 +26,10 @@ describe('Text session', () => {
it('cancels session', () => {
tlstate
.loadDocument(mockDocument)
.create(
TLDR.getShapeUtils({ type: TLDrawShapeType.Text } as TLDrawShape).create({
id: 'text1',
parentId: 'page1',
})
)
.createShapes({
id: 'text1',
type: TLDrawShapeType.Text,
})
.select('text1')
.startTextSession('text1')
.updateTextSession('Hello world')

View file

@ -65,9 +65,7 @@ export class TextSession implements Session {
[pageId]: {
shapes: {
[initialShape.id]: TLDR.onSessionComplete(
data,
TLDR.getShape(data, initialShape.id, pageId),
pageId
TLDR.getShape(data, initialShape.id, pageId)
),
},
},
@ -157,9 +155,7 @@ export class TextSession implements Session {
[pageId]: {
shapes: {
[initialShape.id]: TLDR.onSessionComplete(
data,
TLDR.getShape(data, initialShape.id, pageId),
pageId
TLDR.getShape(data, initialShape.id, pageId)
),
},
},

View file

@ -99,9 +99,7 @@ export class TransformSingleSession implements Session {
beforeShapes[initialShape.id] = initialShape
afterShapes[initialShape.id] = TLDR.onSessionComplete(
data,
TLDR.getShape(data, initialShape.id, data.appState.currentPageId),
data.appState.currentPageId
TLDR.getShape(data, initialShape.id, data.appState.currentPageId)
)
return {

View file

@ -59,7 +59,6 @@ export class TransformSession implements Session {
)
shapes[id] = TLDR.transform(
data,
TLDR.getShape(data, id, data.appState.currentPageId),
newShapeBounds,
{
@ -68,8 +67,7 @@ export class TransformSession implements Session {
scaleX: this.scaleX,
scaleY: this.scaleY,
transformOrigin,
},
data.appState.currentPageId
}
)
})

View file

@ -5,35 +5,37 @@ import type {
ShapeStyles,
ShapesWithProp,
TLDrawShape,
TLDrawShapeUtil,
TLDrawBinding,
TLDrawPage,
TLDrawCommand,
TLDrawPatch,
TLDrawShapeUtil,
} from '~types'
import { Vec } from '@tldraw/vec'
export class TLDR {
static getShapeUtils<T extends TLDrawShape>(
shape: T | T['type']
): TLDrawShapeUtil<T, HTMLElement | SVGElement> {
return getShapeUtils(typeof shape === 'string' ? ({ type: shape } as T) : shape)
// eslint-disable-next-line @typescript-eslint/no-explicit-any
static getShapeUtils<T extends TLDrawShape>(type: T['type']): TLDrawShapeUtil<T, any>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
static getShapeUtils<T extends TLDrawShape>(shape: T): TLDrawShapeUtil<T, any>
static getShapeUtils<T extends TLDrawShape>(shape: T | T['type']) {
return getShapeUtils<T>(typeof shape === 'string' ? shape : shape.type)
}
static getSelectedShapes(data: Data, pageId: string) {
const page = this.getPage(data, pageId)
const selectedIds = this.getSelectedIds(data, pageId)
const page = TLDR.getPage(data, pageId)
const selectedIds = TLDR.getSelectedIds(data, pageId)
return selectedIds.map((id) => page.shapes[id])
}
static screenToWorld(data: Data, point: number[]) {
const camera = this.getPageState(data, data.appState.currentPageId).camera
const camera = TLDR.getPageState(data, data.appState.currentPageId).camera
return Vec.sub(Vec.div(point, camera.zoom), camera.point)
}
static getViewport(data: Data): TLBounds {
const [minX, minY] = this.screenToWorld(data, [0, 0])
const [maxX, maxY] = this.screenToWorld(data, [window.innerWidth, window.innerHeight])
const [minX, minY] = TLDR.screenToWorld(data, [0, 0])
const [maxX, maxY] = TLDR.screenToWorld(data, [window.innerWidth, window.innerHeight])
return {
minX,
@ -58,15 +60,15 @@ export class TLDR {
}
static getSelectedIds(data: Data, pageId: string): string[] {
return this.getPageState(data, pageId).selectedIds
return TLDR.getPageState(data, pageId).selectedIds
}
static getShapes(data: Data, pageId: string): TLDrawShape[] {
return Object.values(this.getPage(data, pageId).shapes)
return Object.values(TLDR.getPage(data, pageId).shapes)
}
static getCamera(data: Data, pageId: string): TLPageState['camera'] {
return this.getPageState(data, pageId).camera
return TLDR.getPageState(data, pageId).camera
}
static getShape<T extends TLDrawShape = TLDrawShape>(
@ -74,56 +76,56 @@ export class TLDR {
shapeId: string,
pageId: string
): T {
return this.getPage(data, pageId).shapes[shapeId] as T
return TLDR.getPage(data, pageId).shapes[shapeId] as T
}
static getBounds<T extends TLDrawShape>(shape: T) {
return getShapeUtils(shape).getBounds(shape)
return TLDR.getShapeUtils(shape).getBounds(shape)
}
static getRotatedBounds<T extends TLDrawShape>(shape: T) {
return getShapeUtils(shape).getRotatedBounds(shape)
return TLDR.getShapeUtils(shape).getRotatedBounds(shape)
}
static getSelectedBounds(data: Data): TLBounds {
return Utils.getCommonBounds(
this.getSelectedShapes(data, data.appState.currentPageId).map((shape) =>
getShapeUtils(shape).getBounds(shape)
TLDR.getSelectedShapes(data, data.appState.currentPageId).map((shape) =>
TLDR.getShapeUtils(shape).getBounds(shape)
)
)
}
static getParentId(data: Data, id: string, pageId: string) {
return this.getShape(data, id, pageId).parentId
return TLDR.getShape(data, id, pageId).parentId
}
static getPointedId(data: Data, id: string, pageId: string): string {
const page = this.getPage(data, pageId)
const pageState = this.getPageState(data, data.appState.currentPageId)
const shape = this.getShape(data, id, pageId)
const page = TLDR.getPage(data, pageId)
const pageState = TLDR.getPageState(data, data.appState.currentPageId)
const shape = TLDR.getShape(data, id, pageId)
if (!shape) return id
return shape.parentId === pageState.currentParentId || shape.parentId === page.id
? id
: this.getPointedId(data, shape.parentId, pageId)
: TLDR.getPointedId(data, shape.parentId, pageId)
}
static getDrilledPointedId(data: Data, id: string, pageId: string): string {
const shape = this.getShape(data, id, pageId)
const shape = TLDR.getShape(data, id, pageId)
const { currentPageId } = data.appState
const { currentParentId, pointedId } = this.getPageState(data, data.appState.currentPageId)
const { currentParentId, pointedId } = TLDR.getPageState(data, data.appState.currentPageId)
return shape.parentId === currentPageId ||
shape.parentId === pointedId ||
shape.parentId === currentParentId
? id
: this.getDrilledPointedId(data, shape.parentId, pageId)
: TLDR.getDrilledPointedId(data, shape.parentId, pageId)
}
static getTopParentId(data: Data, id: string, pageId: string): string {
const page = this.getPage(data, pageId)
const pageState = this.getPageState(data, pageId)
const shape = this.getShape(data, id, pageId)
const page = TLDR.getPage(data, pageId)
const pageState = TLDR.getPageState(data, pageId)
const shape = TLDR.getShape(data, id, pageId)
if (shape.parentId === shape.id) {
throw Error(`Shape has the same id as its parent! ${shape.id}`)
@ -131,18 +133,18 @@ export class TLDR {
return shape.parentId === page.id || shape.parentId === pageState.currentParentId
? id
: this.getTopParentId(data, shape.parentId, pageId)
: TLDR.getTopParentId(data, shape.parentId, pageId)
}
// Get an array of a shape id and its descendant shapes' ids
static getDocumentBranch(data: Data, id: string, pageId: string): string[] {
const shape = this.getShape(data, id, pageId)
const shape = TLDR.getShape(data, id, pageId)
if (shape.children === undefined) return [id]
return [
id,
...shape.children.flatMap((childId) => this.getDocumentBranch(data, childId, pageId)),
...shape.children.flatMap((childId) => TLDR.getDocumentBranch(data, childId, pageId)),
]
}
@ -158,10 +160,10 @@ export class TLDR {
pageId: string,
fn?: (shape: TLDrawShape) => K
): (TLDrawShape | K)[] {
const page = this.getPage(data, pageId)
const page = TLDR.getPage(data, pageId)
const copies = this.getSelectedIds(data, pageId)
.flatMap((id) => this.getDocumentBranch(data, id, pageId).map((id) => page.shapes[id]))
const copies = TLDR.getSelectedIds(data, pageId)
.flatMap((id) => TLDR.getDocumentBranch(data, id, pageId).map((id) => page.shapes[id]))
.filter((shape) => !shape.isLocked)
.map(Utils.deepClone)
@ -184,7 +186,7 @@ export class TLDR {
pageId: string,
fn?: (shape: TLDrawShape) => K
): (TLDrawShape | K)[] {
const copies = this.getSelectedShapes(data, pageId)
const copies = TLDR.getSelectedShapes(data, pageId)
.filter((shape) => !shape.isLocked)
.map(Utils.deepClone)
@ -198,7 +200,7 @@ export class TLDR {
// For a given array of shape ids, an array of all other shapes that may be affected by a mutation to it.
// Use this to decide which shapes to clone as before / after for a command.
static getAllEffectedShapeIds(data: Data, ids: string[], pageId: string): string[] {
const page = this.getPage(data, pageId)
const page = TLDR.getPage(data, pageId)
const visited = new Set(ids)
@ -241,105 +243,6 @@ export class TLDR {
return Array.from(visited.values())
}
static recursivelyUpdateChildren<T extends TLDrawShape>(
data: Data,
id: string,
beforeShapes: Record<string, Partial<TLDrawShape>> = {},
afterShapes: Record<string, Partial<TLDrawShape>> = {},
pageId: string
): Data {
const page = this.getPage(data, pageId)
const shape = page.shapes[id] as T
if (shape.children !== undefined) {
const deltas = this.getShapeUtils(shape).updateChildren(
shape,
shape.children.map((childId) => page.shapes[childId])
)
if (deltas) {
return deltas.reduce<Data>((cData, delta) => {
if (!delta.id) throw Error('Delta must include an id!')
const cPage = this.getPage(cData, pageId)
const deltaShape = this.getShape(cData, delta.id, pageId)
if (!beforeShapes[delta.id]) {
beforeShapes[delta.id] = deltaShape
}
cPage.shapes[delta.id] = this.getShapeUtils(deltaShape).mutate(deltaShape, delta)
afterShapes[delta.id] = cPage.shapes[delta.id]
if (deltaShape.children !== undefined) {
this.recursivelyUpdateChildren(cData, delta.id, beforeShapes, afterShapes, pageId)
}
return cData
}, data)
}
}
return data
}
static recursivelyUpdateParents<T extends TLDrawShape>(
data: Data,
id: string,
beforeShapes: Record<string, Partial<TLDrawShape>> = {},
afterShapes: Record<string, Partial<TLDrawShape>> = {},
pageId: string
): Data {
const page = { ...this.getPage(data, pageId) }
const shape = this.getShape<T>(data, id, pageId)
if (page.id === 'doc') {
throw Error('wtf')
}
if (shape.parentId !== page.id) {
const parent = this.getShape(data, shape.parentId, pageId)
if (!parent.children) throw Error('No children in parent!')
const delta = this.getShapeUtils(parent).onChildrenChange(
parent,
parent.children.map((childId) => this.getShape(data, childId, pageId))
)
if (delta) {
if (!beforeShapes[parent.id]) {
beforeShapes[parent.id] = parent
}
page.shapes[parent.id] = this.getShapeUtils(parent).mutate(parent, delta)
afterShapes[parent.id] = page.shapes[parent.id]
}
if (parent.parentId !== page.id) {
return this.recursivelyUpdateParents(
data,
parent.parentId,
beforeShapes,
afterShapes,
pageId
)
}
}
if (data.appState.currentPageId === 'doc') {
console.error('WTF?')
}
return {
...data,
document: {
...data.document,
pages: {
...data.document.pages,
[page.id]: page,
},
},
}
}
static updateBindings(
data: Data,
id: string,
@ -347,30 +250,28 @@ export class TLDR {
afterShapes: Record<string, Partial<TLDrawShape>> = {},
pageId: string
): Data {
const page = { ...this.getPage(data, pageId) }
const page = { ...TLDR.getPage(data, pageId) }
return Object.values(page.bindings)
.filter((binding) => binding.fromId === id || binding.toId === id)
.reduce((cData, binding) => {
if (!beforeShapes[binding.fromId]) {
beforeShapes[binding.fromId] = Utils.deepClone(
this.getShape(cData, binding.fromId, pageId)
TLDR.getShape(cData, binding.fromId, pageId)
)
}
if (!beforeShapes[binding.toId]) {
beforeShapes[binding.toId] = Utils.deepClone(this.getShape(cData, binding.toId, pageId))
beforeShapes[binding.toId] = Utils.deepClone(TLDR.getShape(cData, binding.toId, pageId))
}
this.onBindingChange(
cData,
this.getShape(cData, binding.fromId, pageId),
TLDR.onBindingChange(
TLDR.getShape(cData, binding.fromId, pageId),
binding,
this.getShape(cData, binding.toId, pageId),
pageId
TLDR.getShape(cData, binding.toId, pageId)
)
afterShapes[binding.fromId] = Utils.deepClone(this.getShape(cData, binding.fromId, pageId))
afterShapes[binding.toId] = Utils.deepClone(this.getShape(cData, binding.toId, pageId))
afterShapes[binding.fromId] = Utils.deepClone(TLDR.getShape(cData, binding.fromId, pageId))
afterShapes[binding.toId] = Utils.deepClone(TLDR.getShape(cData, binding.toId, pageId))
return cData
}, data)
@ -421,7 +322,7 @@ export class TLDR {
const afterShapes: Record<string, Partial<T>> = {}
ids.forEach((id, i) => {
const shape = this.getShape<T>(data, id, pageId)
const shape = TLDR.getShape<T>(data, id, pageId)
const change = fn(shape, i)
if (change) {
beforeShapes[id] = Object.fromEntries(
@ -440,18 +341,9 @@ export class TLDR {
},
},
})
const dataWithChildrenChanges = ids.reduce<Data>((cData, id) => {
return this.recursivelyUpdateChildren(cData, id, beforeShapes, afterShapes, pageId)
}, dataWithMutations)
const dataWithParentChanges = ids.reduce<Data>((cData, id) => {
return this.recursivelyUpdateParents(cData, id, beforeShapes, afterShapes, pageId)
}, dataWithChildrenChanges)
const dataWithBindingChanges = ids.reduce<Data>((cData, id) => {
return this.updateBindings(cData, id, beforeShapes, afterShapes, pageId)
}, dataWithParentChanges)
return TLDR.updateBindings(cData, id, beforeShapes, afterShapes, pageId)
}, dataWithMutations)
return {
before: beforeShapes,
@ -474,7 +366,7 @@ export class TLDR {
// If the shape is a child of another shape, also save that shape
if (shape.parentId !== pageId) {
const parent = this.getShape(data, shape.parentId, pageId)
const parent = TLDR.getShape(data, shape.parentId, pageId)
if (!parent.children) throw Error('No children in parent!')
results.push([parent.id, { children: parent.children }])
}
@ -502,7 +394,7 @@ export class TLDR {
// If the shape is a child of a different shape, update its parent
if (shape.parentId !== pageId) {
const parent = this.getShape(data, shape.parentId, pageId)
const parent = TLDR.getShape(data, shape.parentId, pageId)
if (!parent.children) throw Error('No children in parent!')
results.push([parent.id, { children: [...parent.children, shape.id] }])
}
@ -530,7 +422,7 @@ export class TLDR {
): TLDrawCommand {
pageId = pageId ? pageId : data.appState.currentPageId
const page = this.getPage(data, pageId)
const page = TLDR.getPage(data, pageId)
const shapeIds =
typeof shapes[0] === 'string'
@ -615,73 +507,56 @@ export class TLDR {
}
}
static mutate<T extends TLDrawShape>(data: Data, shape: T, props: Partial<T>, pageId: string) {
let next = getShapeUtils(shape).mutate(shape, props)
if (props.children) {
next = this.onChildrenChange(data, next, pageId) || next
}
return next
}
static onSessionComplete<T extends TLDrawShape>(data: Data, shape: T, pageId: string) {
const delta = getShapeUtils(shape).onSessionComplete(shape)
static onSessionComplete<T extends TLDrawShape>(shape: T) {
const delta = TLDR.getShapeUtils(shape).onSessionComplete(shape)
if (!delta) return shape
return this.mutate(data, shape, delta, pageId)
return { ...shape, ...delta }
}
static onChildrenChange<T extends TLDrawShape>(data: Data, shape: T, pageId: string) {
if (!shape.children) return
const delta = getShapeUtils(shape).onChildrenChange(
const delta = TLDR.getShapeUtils(shape).onChildrenChange(
shape,
shape.children.map((id) => this.getShape(data, id, pageId))
shape.children.map((id) => TLDR.getShape(data, id, pageId))
)
if (!delta) return shape
return this.mutate(data, shape, delta, pageId)
return { ...shape, ...delta }
}
static onBindingChange<T extends TLDrawShape>(
data: Data,
shape: T,
binding: TLDrawBinding,
otherShape: TLDrawShape,
pageId: string
otherShape: TLDrawShape
) {
const delta = getShapeUtils(shape).onBindingChange(
const delta = TLDR.getShapeUtils(shape).onBindingChange(
shape,
binding,
otherShape,
getShapeUtils(otherShape).getBounds(otherShape),
getShapeUtils(otherShape).getCenter(otherShape)
TLDR.getShapeUtils(otherShape).getBounds(otherShape),
TLDR.getShapeUtils(otherShape).getCenter(otherShape)
)
if (!delta) return shape
return this.mutate(data, shape, delta, pageId)
return { ...shape, ...delta }
}
static transform<T extends TLDrawShape>(
data: Data,
shape: T,
bounds: TLBounds,
info: TLTransformInfo<T>,
pageId: string
) {
const change = getShapeUtils(shape).transform(shape, bounds, info)
if (!change) return shape
return this.mutate(data, shape, change, pageId)
static transform<T extends TLDrawShape>(shape: T, bounds: TLBounds, info: TLTransformInfo<T>) {
const delta = TLDR.getShapeUtils(shape).transform(shape, bounds, info)
if (!delta) return shape
return { ...shape, ...delta }
}
static transformSingle<T extends TLDrawShape>(
data: Data,
shape: T,
bounds: TLBounds,
info: TLTransformInfo<T>,
pageId: string
info: TLTransformInfo<T>
) {
const change = getShapeUtils(shape).transformSingle(shape, bounds, info)
if (!change) return shape
return this.mutate(data, shape, change, pageId)
const delta = TLDR.getShapeUtils(shape).transformSingle(shape, bounds, info)
if (!delta) return shape
return { ...shape, ...delta }
}
/* -------------------------------------------------- */
@ -689,11 +564,11 @@ export class TLDR {
/* -------------------------------------------------- */
static updateParents(data: Data, pageId: string, changedShapeIds: string[]): void {
const page = this.getPage(data, pageId)
const page = TLDR.getPage(data, pageId)
if (changedShapeIds.length === 0) return
const { shapes } = this.getPage(data, pageId)
const { shapes } = TLDR.getPage(data, pageId)
const parentToUpdateIds = Array.from(
new Set(changedShapeIds.map((id) => shapes[id].parentId).values())
@ -706,10 +581,10 @@ export class TLDR {
throw Error('A shape is parented to a shape without a children array.')
}
this.onChildrenChange(data, parent, pageId)
TLDR.onChildrenChange(data, parent, pageId)
}
this.updateParents(data, pageId, parentToUpdateIds)
TLDR.updateParents(data, pageId, parentToUpdateIds)
}
static getSelectedStyle(data: Data, pageId: string): ShapeStyles | false {
@ -754,16 +629,16 @@ export class TLDR {
/* -------------------------------------------------- */
static getBinding(data: Data, id: string, pageId: string): TLDrawBinding {
return this.getPage(data, pageId).bindings[id]
return TLDR.getPage(data, pageId).bindings[id]
}
static getBindings(data: Data, pageId: string): TLDrawBinding[] {
const page = this.getPage(data, pageId)
const page = TLDR.getPage(data, pageId)
return Object.values(page.bindings)
}
static getBindableShapeIds(data: Data) {
return this.getShapes(data, data.appState.currentPageId)
return TLDR.getShapes(data, data.appState.currentPageId)
.filter((shape) => TLDR.getShapeUtils(shape).canBind)
.sort((a, b) => b.childIndex - a.childIndex)
.map((shape) => shape.id)
@ -772,7 +647,7 @@ export class TLDR {
static getBindingsWithShapeIds(data: Data, ids: string[], pageId: string): TLDrawBinding[] {
return Array.from(
new Set(
this.getBindings(data, pageId).filter((binding) => {
TLDR.getBindings(data, pageId).filter((binding) => {
return ids.includes(binding.toId) || ids.includes(binding.fromId)
})
).values()
@ -782,7 +657,7 @@ export class TLDR {
static getRelatedBindings(data: Data, ids: string[], pageId: string): TLDrawBinding[] {
const changedShapeIds = new Set(ids)
const page = this.getPage(data, pageId)
const page = TLDR.getPage(data, pageId)
// Find all bindings that we need to update
const bindingsArr = Object.values(page.bindings)

View file

@ -18,7 +18,6 @@ import {
import { Vec } from '@tldraw/vec'
import {
FlipType,
TextShape,
TLDrawDocument,
MoveType,
AlignType,
@ -139,6 +138,7 @@ export class TLDrawState extends StateManager<Data> {
this.session = undefined
this.pointedId = undefined
}
/* -------------------- Internal -------------------- */
onReady = () => {
@ -413,10 +413,7 @@ export class TLDrawState extends StateManager<Data> {
{
appState: {
activeTool: tool,
activeToolType:
tool === 'select'
? 'select'
: TLDR.getShapeUtils({ type: tool } as TLDrawShape).toolType,
activeToolType: tool === 'select' ? 'select' : TLDR.getShapeUtils(tool).toolType,
},
},
`selected_tool:${tool}`
@ -831,7 +828,7 @@ export class TLDrawState extends StateManager<Data> {
const childIndex =
this.getShapes().sort((a, b) => b.childIndex - a.childIndex)[0].childIndex + 1
const shape = TLDR.getShapeUtils<TextShape>(TLDrawShapeType.Text).create({
const shape = TLDR.getShapeUtils(TLDrawShapeType.Text).create({
id: Utils.uniqueId(),
parentId: this.appState.currentPageId,
childIndex,
@ -1638,7 +1635,7 @@ export class TLDrawState extends StateManager<Data> {
if (shapes.length === 0) return this
return this.create(
...shapes.map((shape) => {
return TLDR.getShapeUtils(shape as TLDrawShape).create({
return TLDR.getShapeUtils(shape.type).create({
...shape,
parentId: shape.parentId || this.currentPageId,
})
@ -2044,7 +2041,7 @@ export class TLDrawState extends StateManager<Data> {
if (!this.appState.activeToolType) throw Error
const utils = TLDR.getShapeUtils({ type: this.appState.activeTool } as TLDrawShape)
const utils = TLDR.getShapeUtils(this.appState.activeTool)
const shapes = this.getShapes()

View file

@ -1,7 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/ban-types */
import type { TLBinding, TLShapeProps } from '@tldraw/core'
import { TLShape, TLShapeUtil, TLHandle } from '@tldraw/core'
import type { TLShape, TLShapeUtil, TLHandle } from '@tldraw/core'
import type { TLPage, TLPageState } from '@tldraw/core'
import type { StoreApi } from 'zustand'
import type { Command, Patch } from 'rko'
@ -212,24 +212,16 @@ export type TLDrawShape =
| GroupShape
| PostItShape
export abstract class TLDrawShapeUtil<
T extends TLDrawShape,
E extends HTMLElement | SVGElement
> extends TLShapeUtil<T, E> {
abstract toolType: TLDrawToolType
export interface TLDrawShapeUtil<T extends TLDrawShape, E extends Element>
extends TLShapeUtil<T, E, TLDrawMeta> {
toolType: TLDrawToolType
}
export type TLDrawShapeUtils = Record<
TLDrawShapeType,
TLDrawShapeUtil<TLDrawShape, HTMLElement | SVGElement>
>
export interface ArrowBinding extends TLBinding {
type: 'arrow'
export type ArrowBinding = TLBinding<{
handleId: keyof ArrowShape['handles']
distance: number
point: number[]
}
}>
export type TLDrawBinding = ArrowBinding

View file

@ -8,9 +8,13 @@
"baseUrl": "src",
"emitDeclarationOnly": false,
"paths": {
"~*": ["./*"]
"~*": ["./*"],
"@tldraw/core": ["../core"],
"@tldraw/vec": ["../vec"],
"@tldraw/intersect": ["../intersect"]
}
},
"references": [{ "path": "../vec" }, { "path": "../intersect" }, { "path": "../core" }],
"typedocOptions": {
"entryPoints": ["src/index.ts"],
"out": "docs"

View file

@ -17,15 +17,11 @@
"baseUrl": ".",
"rootDir": ".",
"paths": {
"-*": ["./*"]
"-*": ["./*"],
"@tldraw/tldraw": ["../tldraw"]
}
},
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx"],
"exclude": ["node_modules"],
"references": [
{
"path": "../../packages/tldraw"
},
{ "path": "../../packages/core" }
]
"references": [{ "path": "../tldraw" }, { "path": "../core" }]
}

View file

@ -4,10 +4,10 @@
"exclude": ["node_modules", "**/*.test.ts", "**/*.spec.ts"],
"files": [],
"references": [
{ "path": "./packages/vec/src" },
{ "path": "./packages/intersect/src" },
{ "path": "./packages/tldraw/src" },
{ "path": "./packages/core/src" }
{ "path": "./packages/vec" },
{ "path": "./packages/intersect" },
{ "path": "./packages/tldraw" },
{ "path": "./packages/core" }
],
"compilerOptions": {
"baseUrl": ".",