Adds arrows

This commit is contained in:
Steve Ruiz 2021-08-11 13:26:34 +01:00
parent 923dad6dbe
commit 283e678a4d
25 changed files with 874 additions and 277 deletions

View file

@ -0,0 +1,15 @@
import type { TLBinding } from '@tldraw/core/src/types'
interface BindingProps {
point: number[]
type: TLBinding['type']
}
export function Binding({ point: [x, y], type }: BindingProps): JSX.Element {
return (
<g pointerEvents="none">
{type === 'center' && <circle className="tl-binding" cx={x} cy={y} r={8} />}
{type !== 'pin' && <use className="tl-binding" href="#cross" x={x} y={y} />}
</g>
)
}

View file

@ -0,0 +1 @@
export * from './binding'

View file

@ -218,6 +218,12 @@ const tlcss = css`
.tl-current-parent > *[data-shy='true'] {
opacity: 1;
}
.tl-binding {
fill: none;
stroke: var(--tl-selectStroke);
stroke-width: calc(2px * var(--tl-scale));
}
`
export function useTLTheme(theme?: Partial<TLTheme>) {

View file

@ -15,7 +15,7 @@ export interface TLPageState {
pointedId?: string
hoveredId?: string
editingId?: string
editingBindingId?: string
bindingId?: string
boundsRotation?: number
currentParentId?: string
selectedIds: string[]
@ -29,6 +29,8 @@ export interface TLHandle {
id: string
index: number
point: number[]
canBind?: boolean
bindingId?: string
}
export interface TLShape {
@ -258,6 +260,7 @@ export abstract class TLShapeUtil<T extends TLShape> {
isEditableText = false
isAspectRatioLocked = false
canEdit = false
canBind = false
abstract type: T['type']
@ -294,6 +297,17 @@ export abstract class TLShapeUtil<T extends TLShape> {
return [bounds.width / 2, bounds.height / 2]
}
getBindingPoint(
shape: T,
point: number[],
origin: number[],
direction: number[],
padding: number,
anywhere: boolean
): { point: number[]; distance: number } | undefined {
return undefined
}
create(props: Partial<T>): T {
return { ...this.defaultProps, ...props }
}
@ -317,7 +331,7 @@ export abstract class TLShapeUtil<T extends TLShape> {
_targetBounds: TLBounds,
_center: number[]
): Partial<T> | void {
return
return undefined
}
onHandleChange(

View file

@ -238,8 +238,6 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape> {
style,
} = shape
const circle = getCtp(shape)
const path = Utils.getFromCache(this.simplePathCache, shape, () =>
getArrowArcPath(start, end, getCtp(shape), bend)
)
@ -251,7 +249,21 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape> {
const arrowHeadlength = Math.min(arrowDist / 3, strokeWidth * 8)
const arcLength = Utils.getArcLength([circle[0], circle[1]], circle[2], start.point, end.point)
let insetStart: number[]
let insetEnd: number[]
if (bend === 0) {
insetStart = Vec.nudge(start.point, end.point, arrowHeadlength)
insetEnd = Vec.nudge(end.point, start.point, arrowHeadlength)
} else {
const circle = getCtp(shape)
const arcLength = Utils.getArcLength(
[circle[0], circle[1]],
circle[2],
start.point,
end.point
)
const center = [circle[0], circle[1]]
const radius = circle[2]
@ -259,8 +271,9 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape> {
const ea = Vec.angle(center, end.point)
const t = arrowHeadlength / Math.abs(arcLength)
const insetStart = Vec.nudgeAtAngle(center, Utils.lerpAngles(sa, ea, t), radius)
const insetEnd = Vec.nudgeAtAngle(center, Utils.lerpAngles(ea, sa, t), radius)
insetStart = Vec.nudgeAtAngle(center, Utils.lerpAngles(sa, ea, t), radius)
insetEnd = Vec.nudgeAtAngle(center, Utils.lerpAngles(ea, sa, t), radius)
}
return (
<g>
@ -413,9 +426,10 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape> {
center: number[]
): void | Partial<ArrowShape> => {
const handle = shape.handles[binding.handleId]
const bounds = this.getBounds(shape)
const expandedBounds = Utils.expandBounds(bounds, binding.distance)
const expandedBounds = Utils.expandBounds(targetBounds, 32)
// The anchor is the "actual" point in the target shape
// (Remember that the binding.point is normalized)
const anchor = Vec.sub(
Vec.add(
[expandedBounds.minX, expandedBounds.minY],
@ -424,20 +438,22 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape> {
shape.point
)
let handlePoint: number[]
// We're looking for the point to put the dragging handle
let handlePoint = anchor
const origin = Vec.add(
shape.point,
shape.handles[binding.handleId === 'start' ? 'end' : 'start'].point
)
const direction = Vec.uni(Vec.sub(Vec.add(anchor, shape.point), origin))
// TODO: Abstract this part onto individual shape utils?
if ([TLDrawShapeType.Rectangle, TLDrawShapeType.Text].includes(target.type)) {
if (binding.distance) {
const intersectBounds = Utils.expandBounds(targetBounds, binding.distance)
// The direction vector starts from the arrow's opposite handle
const origin = Vec.add(
shape.point,
shape.handles[handle.id === 'start' ? 'end' : 'start'].point
)
// And passes through the dragging handle
const direction = Vec.uni(Vec.sub(Vec.add(anchor, shape.point), origin))
if ([TLDrawShapeType.Rectangle, TLDrawShapeType.Text].includes(target.type)) {
let hits = Intersect.ray
.bounds(origin, direction, intersectBounds)
.filter((int) => int.didIntersect)
@ -478,14 +494,12 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape> {
origin,
binding.distance
)
} else {
handlePoint = anchor
}
}
return this.onHandleChange(
shape,
{
...shape.handles,
[handle.id]: {
...handle,
point: Vec.round(handlePoint),
@ -497,30 +511,23 @@ export class Arrow extends TLDrawShapeUtil<ArrowShape> {
onHandleChange = (
shape: ArrowShape,
handles: ArrowShape['handles'],
handles: Partial<ArrowShape['handles']>,
{ shiftKey }: Partial<TLPointerInfo>
) => {
let nextHandles = Utils.deepMerge(shape.handles, handles)
let nextHandles = Utils.deepMerge<ArrowShape['handles']>(shape.handles, handles)
let nextBend = shape.bend
// If the user is holding shift, we want to snap the handles to angles
for (const id in handles) {
if ((id === 'start' || id === 'end') && shiftKey) {
const point = handles[id].point
const other = id === 'start' ? shape.handles.end : shape.handles.start
Object.values(handles).forEach((handle) => {
if ((handle.id === 'start' || handle.id === 'end') && shiftKey) {
const point = handle.point
const other = handle.id === 'start' ? shape.handles.end : shape.handles.start
const angle = Vec.angle(other.point, point)
const distance = Vec.dist(other.point, point)
const newAngle = Utils.clampToRotationToSegments(angle, 24)
nextHandles = {
...nextHandles,
[id]: {
...nextHandles[id],
point: Vec.nudgeAtAngle(other.point, newAngle, distance),
},
}
}
handle.point = Vec.nudgeAtAngle(other.point, newAngle, distance)
}
})
// If the user is moving the bend handle, we want to move the bend point
if ('bend' in handles) {

View file

@ -14,6 +14,7 @@ export class Ellipse extends TLDrawShapeUtil<EllipseShape> {
type = TLDrawShapeType.Ellipse as const
toolType = TLDrawToolType.Bounds
pathCache = new WeakMap<EllipseShape, string>([])
canBind = true
defaultProps = {
id: 'id',

View file

@ -13,6 +13,7 @@ import {
export class Rectangle extends TLDrawShapeUtil<RectangleShape> {
type = TLDrawShapeType.Rectangle as const
toolType = TLDrawToolType.Bounds
canBind = true
pathCache = new WeakMap<number[], string>([])
@ -179,6 +180,82 @@ export class Rectangle extends TLDrawShapeUtil<RectangleShape> {
return Utils.getBoundsCenter(this.getBounds(shape))
}
getBindingPoint(
shape: RectangleShape,
point: number[],
origin: number[],
direction: number[],
padding: number,
anywhere: boolean
) {
const bounds = this.getBounds(shape)
const expandedBounds = Utils.expandBounds(bounds, padding)
let bindingPoint: number[]
let distance: number
// The point must be inside of the expanded bounding box
if (!Utils.pointInBounds(point, expandedBounds)) return
// The point is inside of the shape, so we'll assume the user is
// indicating a specific point inside of the shape.
if (anywhere) {
if (Vec.dist(point, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
bindingPoint = Vec.divV(Vec.sub(point, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
distance = 0
} else {
// Find furthest intersection between ray from
// origin through point and expanded bounds.
// TODO: Make this a ray vs rounded rect intersection
const intersection = Intersect.ray
.bounds(origin, direction, expandedBounds)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(b, origin) - Vec.dist(a, origin))[0]
// The anchor is a point between the handle and the intersection
const anchor = Vec.med(point, intersection)
// If we're close to the center, snap to the center
if (Vec.distanceToLineSegment(point, anchor, this.getCenter(shape)) < 12) {
bindingPoint = [0.5, 0.5]
} else {
// Or else calculate a normalized point
bindingPoint = Vec.divV(Vec.sub(anchor, [expandedBounds.minX, expandedBounds.minY]), [
expandedBounds.width,
expandedBounds.height,
])
}
if (Utils.pointInBounds(point, bounds)) {
distance = 16
} else {
// If the binding point was close to the shape's center, snap to the center
// Find the distance between the point and the real bounds of the shape
distance = Math.max(
16,
Utils.getBoundsSides(bounds)
.map((side) => Vec.distanceToLineSegment(side[1][0], side[1][1], point))
.sort((a, b) => a - b)[0]
)
}
}
return {
point: bindingPoint,
distance,
}
}
hitTest(shape: RectangleShape, point: number[]) {
return Utils.pointInBounds(point, this.getBounds(shape))
}

View file

@ -50,8 +50,8 @@ export class Text extends TLDrawShapeUtil<TextShape> {
type = TLDrawShapeType.Text as const
toolType = TLDrawToolType.Text
canChangeAspectRatio = false
canBind = true
isEditableText = true
canBind = true
pathCache = new WeakMap<number[], string>([])

View file

@ -1,24 +1,75 @@
import type { Data, Command } from '../../state-types'
// - [x] Delete shapes
// - [ ] Delete bindings too
// - [ ] Update parents and possibly delete parents
export function deleteShapes(data: Data, ids: string[]): Command {
// We also need to delete any bindings that reference the deleted shapes
const bindingIdsToDelete = Object.values(data.page.bindings)
.filter((binding) => ids.includes(binding.fromId) || ids.includes(binding.toId))
.map((binding) => binding.id)
// We also need to update any shapes that reference the deleted bindings
const shapesWithBindingsToUpdate = Object.values(data.page.shapes).filter(
(shape) =>
shape.handles &&
Object.values(shape.handles).some(
(handle) => handle.bindingId && bindingIdsToDelete.includes(handle.bindingId)
)
)
return {
id: 'toggle_shapes',
id: 'delete_shapes',
before: {
page: {
shapes: Object.fromEntries(ids.map((id) => [id, data.page.shapes[id]])),
shapes: {
...Object.fromEntries(ids.map((id) => [id, data.page.shapes[id]])),
...Object.fromEntries(
shapesWithBindingsToUpdate.map((shape) => {
let handle = Object.values(shape.handles!).find((handle) => {
const bindingId = handle.bindingId
if (bindingId && bindingIdsToDelete.includes(bindingId)) {
return handle
}
return false
})!
return [shape.id, { handles: { [handle.id]: { bindingId: handle } } }]
})
),
},
bindings: Object.fromEntries(bindingIdsToDelete.map((id) => [id, data.page.bindings[id]])),
},
pageState: {
selectedIds: [...data.pageState.selectedIds],
hoveredId: undefined
hoveredId: undefined,
},
},
after: {
page: {
shapes: Object.fromEntries(ids.map((id) => [id, undefined])),
shapes: {
...Object.fromEntries(ids.map((id) => [id, undefined])),
...Object.fromEntries(
shapesWithBindingsToUpdate.map((shape) => {
for (const id in shape.handles) {
const handle = shape.handles[id as keyof typeof shape.handles]
const bindingId = handle.bindingId
if (bindingId && bindingIdsToDelete.includes(bindingId)) {
handle.bindingId = undefined
}
}
return [shape.id, shape]
})
),
},
bindings: Object.fromEntries(bindingIdsToDelete.map((id) => [id, undefined])),
},
pageState: {
selectedIds: [],
hoveredId: undefined
hoveredId: undefined,
},
},
}

View file

@ -4,7 +4,6 @@ import { mockDocument } from '../../test-helpers'
describe('Duplicate command', () => {
const tlstate = new TLDrawState()
tlstate.loadDocument(mockDocument)
tlstate.reset()
tlstate.select('rect1')
it('does, undoes and redoes command', () => {

View file

@ -5,8 +5,7 @@ import { mockDocument } from '../../test-helpers'
describe('Style command', () => {
const tlstate = new TLDrawState()
tlstate.loadDocument(mockDocument)
tlstate.reset()
tlstate.setSelectedIds(['rect1'])
tlstate.select('rect1')
it('does, undoes and redoes command', () => {
expect(tlstate.getShape('rect1').style.size).toEqual(SizeStyle.Medium)

View file

@ -0,0 +1,41 @@
import { TLDrawState } from '../../../tlstate'
import { mockDocument } from '../../../test-helpers'
import { TLDR } from '../../../tldr'
import type { TLDrawShape } from '../../../../shape'
describe('Handle session', () => {
const tlstate = new TLDrawState()
it('begins, updates and completes session', () => {
tlstate
.loadDocument(mockDocument)
.create(
TLDR.getShapeUtils({ type: 'arrow' } as TLDrawShape).create({
id: 'arrow1',
parentId: 'page1',
})
)
.select('arrow1')
.startHandleSession([-10, -10], 'end')
.updateHandleSession([10, 10])
.completeSession()
.undo()
.redo()
})
it('cancels session', () => {
tlstate
.loadDocument(mockDocument)
.create({
...TLDR.getShapeUtils({ type: 'arrow' } as TLDrawShape).defaultProps,
id: 'arrow1',
parentId: 'page1',
})
.select('arrow1')
.startHandleSession([-10, -10], 'end')
.updateHandleSession([10, 10])
.cancelSession()
expect(tlstate.getShape('rect1').point).toStrictEqual([0, 0])
})
})

View file

@ -0,0 +1,251 @@
import type { ArrowBinding, ArrowShape } from '../../../../shape'
import type { TLDrawShape } from '../../../../shape'
import type { Session } from '../../../state-types'
import type { Data } from '../../../state-types'
import { Vec, Utils, TLBinding } from '@tldraw/core'
import { TLDR } from '../../../tldr'
export class ArrowSession implements Session {
id = 'transform_single'
newBindingId = Utils.uniqueId()
delta = [0, 0]
origin: number[]
shiftKey = false
initialShape: ArrowShape
handleId: 'start' | 'end'
bindableShapeIds: string[]
initialBinding: TLBinding | undefined
didBind = false
constructor(data: Data, handleId: 'start' | 'end', point: number[]) {
const shapeId = data.pageState.selectedIds[0]
this.origin = point
this.handleId = handleId
this.initialShape = TLDR.getShape<ArrowShape>(data, shapeId)
this.bindableShapeIds = TLDR.getBindableShapeIds(data)
const initialBindingId = this.initialShape.handles[this.handleId].bindingId
if (initialBindingId) {
this.initialBinding = data.page.bindings[initialBindingId]
}
}
start = (data: Data) => data
update = (
data: Data,
point: number[],
shiftKey: boolean,
altKey: boolean,
metaKey: boolean
): Partial<Data> => {
const { initialShape, origin } = this
const shape = TLDR.getShape<ArrowShape>(data, initialShape.id)
TLDR.assertShapeHasProperty(shape, 'handles')
this.shiftKey = shiftKey
const delta = Vec.sub(point, origin)
const handles = shape.handles
const handleId = this.handleId as keyof typeof handles
const handle = handles[handleId]
let nextPoint = Vec.round(Vec.add(this.initialShape.handles[handleId].point, delta))
// First update the handle's next point
const change = TLDR.getShapeUtils(shape).onHandleChange(
shape,
{
[handleId]: {
...shape.handles[handleId],
point: nextPoint, // Vec.rot(delta, shape.rotation)),
},
},
{ delta, shiftKey, altKey, metaKey }
)
if (!change) return data
let nextBindings: Record<string, TLBinding> = { ...data.page.bindings }
let nextShape: ArrowShape = { ...shape, ...change }
let nextBinding: ArrowBinding | undefined = undefined
let nextTarget: TLDrawShape | undefined = undefined
if (handle.canBind) {
const oppositeHandle = handles[handle.id === 'start' ? 'end' : 'start']
// Find the origin and direction of the handle
const rayOrigin = Vec.add(oppositeHandle.point, shape.point)
const rayPoint = Vec.add(nextPoint, shape.point)
const rayDirection = Vec.uni(Vec.sub(rayPoint, rayOrigin))
// From all bindable shapes on the page...
for (const id of this.bindableShapeIds) {
if (id === initialShape.id) continue
const target = TLDR.getShape(data, id)
const util = TLDR.getShapeUtils(target)
const bindingPoint = util.getBindingPoint(
target,
rayPoint,
rayOrigin,
rayDirection,
32,
metaKey
)
// Not all shapes will produce a binding point
if (!bindingPoint) continue
// Stop at the first shape that will produce a binding point
nextTarget = target
nextBinding = {
id: this.newBindingId,
type: 'arrow',
fromId: initialShape.id,
handleId: this.handleId,
toId: target.id,
point: Vec.round(bindingPoint.point),
distance: bindingPoint.distance,
}
break
}
// If we didn't find a target...
if (nextBinding === undefined) {
this.didBind = false
if (handle.bindingId) {
delete nextBindings[handle.bindingId]
}
nextShape.handles[handleId].bindingId = undefined
} else if (nextTarget) {
this.didBind = true
if (handle.bindingId && handle.bindingId !== this.newBindingId) {
delete nextBindings[handle.bindingId]
nextShape.handles[handleId].bindingId = undefined
}
// If we found a new binding, add its id to the handle...
nextShape = {
...nextShape,
handles: {
...nextShape.handles,
[handleId]: {
...nextShape.handles[handleId],
bindingId: nextBinding.id,
},
},
}
// and add it to the page's bindings
nextBindings = {
...nextBindings,
[nextBinding.id]: nextBinding,
}
// Now update the arrow in response to the new binding
const arrowChange = TLDR.getShapeUtils(nextShape).onBindingChange(
nextShape,
nextBinding,
nextTarget,
TLDR.getShapeUtils(nextTarget).getBounds(nextTarget),
TLDR.getShapeUtils(nextTarget).getCenter(nextTarget)
)
if (arrowChange) {
nextShape = {
...nextShape,
...arrowChange,
}
}
}
}
return {
page: {
...data.page,
shapes: {
...data.page.shapes,
[shape.id]: nextShape,
},
bindings: nextBindings,
},
pageState: {
...data.pageState,
bindingId: nextShape.handles[handleId].bindingId,
},
}
}
cancel = (data: Data) => {
const { initialShape, newBindingId } = this
const nextBindings = { ...data.page.bindings }
if (this.didBind) {
delete nextBindings[newBindingId]
}
return {
page: {
...data.page,
shapes: {
...data.page.shapes,
[initialShape.id]: initialShape,
},
bindings: nextBindings,
},
}
}
complete(data: Data) {
let beforeBindings: Partial<Record<string, TLBinding>> = {}
let afterBindings: Partial<Record<string, TLBinding>> = {}
const currentShape = TLDR.getShape<ArrowShape>(data, this.initialShape.id)
const currentBindingId = currentShape.handles[this.handleId].bindingId
if (this.initialBinding) {
beforeBindings[this.initialBinding.id] = this.initialBinding
afterBindings[this.initialBinding.id] = undefined
}
if (currentBindingId) {
beforeBindings[currentBindingId] = undefined
afterBindings[currentBindingId] = data.page.bindings[currentBindingId]
}
return {
id: 'arrow',
before: {
page: {
shapes: {
[this.initialShape.id]: this.initialShape,
},
bindings: beforeBindings,
},
},
after: {
page: {
shapes: {
[this.initialShape.id]: TLDR.onSessionComplete(
data,
data.page.shapes[this.initialShape.id]
),
},
bindings: afterBindings,
},
},
}
}
}

View file

@ -0,0 +1 @@
export * from './arrow.session'

View file

@ -57,11 +57,10 @@ export class BrushSession implements Session {
selectedIds.size === data.pageState.selectedIds.length &&
data.pageState.selectedIds.every((id) => selectedIds.has(id))
) {
return data
return {}
}
return {
...data,
pageState: {
...data.pageState,
selectedIds: Array.from(selectedIds.values()),

View file

@ -29,12 +29,7 @@ export class DrawSession implements Session {
start = (data: Data) => data
update = (
data: Data,
point: number[],
pressure: number,
isLocked = false
) => {
update = (data: Data, point: number[], pressure: number, isLocked = false) => {
const { snapshot } = this
// Drawing while holding shift will "lock" the pen to either the
@ -82,10 +77,7 @@ export class DrawSession implements Session {
// Don't add duplicate points. It's important to test against the
// adjusted (low-passed) point rather than the input point.
const newPoint = Vec.round([
...Vec.sub(this.previous, this.origin),
pressure,
])
const newPoint = Vec.round([...Vec.sub(this.previous, this.origin), pressure])
if (Vec.isEqual(this.last, newPoint)) return data
@ -98,7 +90,6 @@ export class DrawSession implements Session {
if (this.points.length <= 2) return data
return {
...data,
page: {
...data.page,
shapes: {
@ -119,7 +110,6 @@ export class DrawSession implements Session {
cancel = (data: Data): Data => {
const { snapshot } = this
return {
...data,
page: {
...data.page,
// @ts-ignore
@ -152,10 +142,7 @@ export class DrawSession implements Session {
after: {
page: {
shapes: {
[snapshot.id]: TLDR.onSessionComplete(
data,
data.page.shapes[snapshot.id]
),
[snapshot.id]: TLDR.onSessionComplete(data, data.page.shapes[snapshot.id]),
},
},
pageState: {

View file

@ -1,3 +1,4 @@
import { ArrowBinding } from './../../../../shape/shape-types'
import { Vec } from '@tldraw/core'
import type { TLDrawShape } from '../../../../shape'
import type { Session } from '../../../state-types'
@ -13,12 +14,7 @@ export class HandleSession implements Session {
initialShape: TLDrawShape
handleId: string
constructor(
data: Data,
handleId: string,
point: number[],
commandId = 'move_handle'
) {
constructor(data: Data, handleId: string, point: number[], commandId = 'move_handle') {
const shapeId = data.pageState.selectedIds[0]
this.origin = point
this.handleId = handleId
@ -49,12 +45,17 @@ export class HandleSession implements Session {
const handleId = this.handleId as keyof typeof handles
const handle = handles[handleId]
let nextPoint = Vec.round(Vec.add(handle.point, delta))
// Now update the handle's next point
const change = TLDR.getShapeUtils(shape).onHandleChange(
shape,
{
[handleId]: {
...shape.handles[handleId],
point: Vec.round(Vec.add(handles[handleId].point, delta)), // Vec.rot(delta, shape.rotation)),
point: nextPoint, // Vec.rot(delta, shape.rotation)),
},
},
{ delta, shiftKey, altKey, metaKey }

View file

@ -6,3 +6,4 @@ export * from './draw'
export * from './rotate'
export * from './handle'
export * from './text'
export * from './arrow'

View file

@ -19,11 +19,10 @@ export class RotateSession implements Session {
start = (data: Data) => data
update = (data: Data, point: number[], isLocked = false): Data => {
update = (data: Data, point: number[], isLocked = false) => {
const { commonBoundsCenter, initialShapes } = this.snapshot
const next = {
...data,
page: {
...data.page,
},
@ -45,8 +44,7 @@ export class RotateSession implements Session {
rot = Utils.clampToRotationToSegments(rot, 24)
}
pageState.boundsRotation =
(PI2 + (this.snapshot.boundsRotation + rot)) % PI2
pageState.boundsRotation = (PI2 + (this.snapshot.boundsRotation + rot)) % PI2
next.page.shapes = {
...next.page.shapes,
@ -58,10 +56,7 @@ export class RotateSession implements Session {
? Utils.clampToRotationToSegments(rotation + rot, 24)
: rotation + rot
const nextPoint = Vec.sub(
Vec.rotWith(center, commonBoundsCenter, rot),
offset
)
const nextPoint = Vec.sub(Vec.rotWith(center, commonBoundsCenter, rot), offset)
return [
id,
@ -77,7 +72,9 @@ export class RotateSession implements Session {
),
}
return next
return {
page: next.page,
}
}
cancel = (data: Data) => {
@ -88,16 +85,12 @@ export class RotateSession implements Session {
}
return {
...data,
page: {
...data.page,
shapes: {
...data.page.shapes,
...Object.fromEntries(
initialShapes.map(({ id, shape }) => [
id,
TLDR.onSessionComplete(data, shape),
])
initialShapes.map(({ id, shape }) => [id, TLDR.onSessionComplete(data, shape)])
),
},
},
@ -114,11 +107,9 @@ export class RotateSession implements Session {
before: {
page: {
shapes: Object.fromEntries(
initialShapes.map(
({ shape: { id, point, rotation = undefined } }) => {
initialShapes.map(({ shape: { id, point, rotation = undefined } }) => {
return [id, { point, rotation }]
}
)
})
),
},
},
@ -169,10 +160,7 @@ export function getRotateSnapshot(data: Data) {
const center = Utils.getBoundsCenter(bounds)
const offset = Vec.sub(center, shape.point)
const rotationOffset = Vec.sub(
center,
Utils.getBoundsCenter(rotatedBounds[shape.id])
)
const rotationOffset = Vec.sub(center, Utils.getBoundsCenter(rotatedBounds[shape.id]))
return {
id: shape.id,

View file

@ -27,7 +27,7 @@ export class TransformSingleSession implements Session {
start = (data: Data) => data
update = (data: Data, point: number[], isAspectRatioLocked = false): Data => {
update = (data: Data, point: number[], isAspectRatioLocked = false): Partial<Data> => {
const { transformType } = this
const { initialShapeBounds, initialShape, id } = this.snapshot
@ -41,13 +41,10 @@ export class TransformSingleSession implements Session {
transformType,
Vec.sub(point, this.origin),
shape.rotation,
isAspectRatioLocked ||
shape.isAspectRatioLocked ||
utils.isAspectRatioLocked
isAspectRatioLocked || shape.isAspectRatioLocked || utils.isAspectRatioLocked
)
return {
...data,
page: {
...data.page,
shapes: {
@ -72,7 +69,6 @@ export class TransformSingleSession implements Session {
data.page.shapes[id] = initialShape
return {
...data,
page: {
...data.page,
shapes: {
@ -98,10 +94,7 @@ export class TransformSingleSession implements Session {
after: {
page: {
shapes: {
[this.snapshot.id]: TLDR.onSessionComplete(
data,
data.page.shapes[this.snapshot.id]
),
[this.snapshot.id]: TLDR.onSessionComplete(data, data.page.shapes[this.snapshot.id]),
},
},
},
@ -130,6 +123,4 @@ export function getTransformSingleSnapshot(
}
}
export type TransformSingleSnapshot = ReturnType<
typeof getTransformSingleSnapshot
>
export type TransformSingleSnapshot = ReturnType<typeof getTransformSingleSnapshot>

View file

@ -28,13 +28,13 @@ export class TransformSession implements Session {
point: number[],
isAspectRatioLocked = false,
_altKey = false
): Data => {
): Partial<Data> => {
const {
transformType,
snapshot: { shapeBounds, initialBounds, isAllAspectRatioLocked },
} = this
const next = {
const next: Data = {
...data,
page: {
...data.page,
@ -89,23 +89,21 @@ export class TransformSession implements Session {
),
}
return next
return {
page: next.page,
}
}
cancel = (data: Data) => {
const { shapeBounds } = this.snapshot
return {
...data,
page: {
...data.page,
shapes: {
...data.page.shapes,
...Object.fromEntries(
Object.entries(shapeBounds).map(([id, { initialShape }]) => [
id,
initialShape,
])
Object.entries(shapeBounds).map(([id, { initialShape }]) => [id, initialShape])
),
},
},
@ -122,10 +120,7 @@ export class TransformSession implements Session {
before: {
page: {
shapes: Object.fromEntries(
Object.entries(shapeBounds).map(([id, { initialShape }]) => [
id,
initialShape,
])
Object.entries(shapeBounds).map(([id, { initialShape }]) => [id, initialShape])
),
},
},
@ -143,17 +138,13 @@ export class TransformSession implements Session {
}
}
export function getTransformSnapshot(
data: Data,
transformType: TLBoundsEdge | TLBoundsCorner
) {
export function getTransformSnapshot(data: Data, transformType: TLBoundsEdge | TLBoundsCorner) {
const initialShapes = TLDR.getSelectedBranchSnapshot(data)
const hasUnlockedShapes = initialShapes.length > 0
const isAllAspectRatioLocked = initialShapes.every(
(shape) =>
shape.isAspectRatioLocked || TLDR.getShapeUtils(shape).isAspectRatioLocked
(shape) => shape.isAspectRatioLocked || TLDR.getShapeUtils(shape).isAspectRatioLocked
)
const shapesBounds = Object.fromEntries(
@ -164,9 +155,7 @@ export function getTransformSnapshot(
const commonBounds = Utils.getCommonBounds(boundsArr)
const initialInnerBounds = Utils.getBoundsFromPoints(
boundsArr.map(Utils.getBoundsCenter)
)
const initialInnerBounds = Utils.getBoundsFromPoints(boundsArr.map(Utils.getBoundsCenter))
// Return a mapping of shapes to bounds together with the relative
// positions of the shape's bounds within the common bounds shape.

View file

@ -21,12 +21,7 @@ export class TranslateSession implements Session {
return data
}
update = (
data: Data,
point: number[],
isAligned = false,
isCloning = false
) => {
update = (data: Data, point: number[], isAligned = false, isCloning = false) => {
const { clones, initialShapes } = this.snapshot
const next = {
@ -91,15 +86,13 @@ export class TranslateSession implements Session {
clone.id,
{
...clone,
point: Vec.round(
Vec.add(next.page.shapes[clone.id].point, trueDelta)
),
point: Vec.round(Vec.add(next.page.shapes[clone.id].point, trueDelta)),
},
])
),
}
return next
return { page: { ...next.page }, pageState: { ...next.pageState } }
}
// If not cloning...
@ -137,28 +130,22 @@ export class TranslateSession implements Session {
shape.id,
{
...next.page.shapes[shape.id],
point: Vec.round(
Vec.add(next.page.shapes[shape.id].point, trueDelta)
),
point: Vec.round(Vec.add(next.page.shapes[shape.id].point, trueDelta)),
},
])
),
}
return next
return { page: { ...next.page }, pageState: { ...next.pageState } }
}
cancel = (data: Data): Data => {
return {
...data,
page: {
...data.page,
// @ts-ignore - We need to set deleted shapes to undefined in order to correctly deep merge them away.
shapes: {
...data.page.shapes,
...Object.fromEntries(
this.snapshot.clones.map((clone) => [clone.id, undefined])
),
...Object.fromEntries(this.snapshot.clones.map((clone) => [clone.id, undefined])),
...Object.fromEntries(
this.snapshot.initialShapes.map((shape) => [
shape.id,
@ -178,38 +165,23 @@ export class TranslateSession implements Session {
return {
id: 'translate',
before: {
...data,
page: {
...data.page,
shapes: {
...data.page.shapes,
...Object.fromEntries(this.snapshot.clones.map((clone) => [clone.id, undefined])),
...Object.fromEntries(
this.snapshot.clones.map((clone) => [clone.id, undefined])
),
...Object.fromEntries(
this.snapshot.initialShapes.map((shape) => [
shape.id,
{ point: shape.point },
])
this.snapshot.initialShapes.map((shape) => [shape.id, { point: shape.point }])
),
},
},
pageState: {
...data.pageState,
selectedIds: this.snapshot.selectedIds,
},
},
after: {
...data,
page: {
...data.page,
shapes: {
...data.page.shapes,
...Object.fromEntries(
this.snapshot.clones.map((clone) => [
clone.id,
data.page.shapes[clone.id],
])
this.snapshot.clones.map((clone) => [clone.id, data.page.shapes[clone.id]])
),
...Object.fromEntries(
this.snapshot.initialShapes.map((shape) => [
@ -220,7 +192,6 @@ export class TranslateSession implements Session {
},
},
pageState: {
...data.pageState,
selectedIds: [...data.pageState.selectedIds],
},
},
@ -234,9 +205,7 @@ export function getTranslateSnapshot(data: Data) {
const hasUnlockedShapes = selectedShapes.length > 0
const initialParents = Array.from(
new Set(selectedShapes.map((s) => s.parentId)).values()
)
const initialParents = Array.from(new Set(selectedShapes.map((s) => s.parentId)).values())
.filter((id) => id !== data.page.id)
.map((id) => {
const shape = TLDR.getShape(data, id)

View file

@ -1,11 +1,6 @@
/* eslint-disable @typescript-eslint/ban-types */
import type { TLPage, TLPageState } from '@tldraw/core'
import type {
ShapeStyles,
TLDrawShape,
TLDrawShapeType,
TLDrawToolType,
} from '../shape'
import type { ShapeStyles, TLDrawShape, TLDrawShapeType, TLDrawToolType } from '../shape'
import type { TLDrawSettings } from '../types'
import type { StoreApi } from 'zustand'
@ -51,10 +46,10 @@ export interface History {
export interface Session {
id: string
start: (data: Readonly<Data>, ...args: any[]) => Data
update: (data: Readonly<Data>, ...args: any[]) => Data
complete: (data: Readonly<Data>, ...args: any[]) => Data | Command
cancel: (data: Readonly<Data>, ...args: any[]) => Data
start: (data: Readonly<Data>, ...args: any[]) => Partial<Data>
update: (data: Readonly<Data>, ...args: any[]) => Partial<Data>
complete: (data: Readonly<Data>, ...args: any[]) => Partial<Data> | Command
cancel: (data: Readonly<Data>, ...args: any[]) => Partial<Data>
}
export type TLDrawStatus =
@ -72,11 +67,6 @@ export type TLDrawStatus =
| 'editing-text'
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type ParametersExceptFirst<F> = F extends (
arg0: any,
...rest: infer R
) => any
? R
: never
export type ParametersExceptFirst<F> = F extends (arg0: any, ...rest: infer R) => any ? R : never
export {}

View file

@ -1,6 +1,6 @@
import { TLBinding, TLBounds, TLTransformInfo, Vec, Utils } from '@tldraw/core'
import { getShapeUtils, ShapeStyles, ShapesWithProp, TLDrawShape, TLDrawShapeUtil } from '../shape'
import type { Data } from './state-types'
import type { Data, DeepPartial } from './state-types'
export class TLDR {
static getShapeUtils<T extends TLDrawShape>(shape: T | T['type']): TLDrawShapeUtil<T> {
@ -389,31 +389,132 @@ export class TLDR {
}
}
static createShapes(data: Data, shapes: TLDrawShape[]): void {
static createShapes(
data: Data,
shapes: TLDrawShape[]
): { before: DeepPartial<Data>; after: DeepPartial<Data> } {
const page = this.getPage(data)
const shapeIds = shapes.map((shape) => shape.id)
// Update selected ids
this.setSelectedIds(data, shapeIds)
const before: DeepPartial<Data> = {
page: {
shapes: {
...Object.fromEntries(
shapes.flatMap((shape) => {
const results: [string, Partial<TLDrawShape> | undefined][] = [[shape.id, undefined]]
// Restore deleted shapes
shapes.forEach((shape) => {
const newShape = { ...shape }
page.shapes[shape.id] = newShape
// If the shape is a child of another shape, also add that shape
if (shape.parentId !== data.page.id) {
const parent = page.shapes[shape.parentId]
results.push([parent.id, { children: parent.children! }])
}
return results
})
),
},
},
}
// Update parents
shapes.forEach((shape) => {
if (shape.parentId === data.page.id) return
const after: DeepPartial<Data> = {
page: {
shapes: {
...Object.fromEntries(
shapes.flatMap((shape) => {
const results: [string, Partial<TLDrawShape> | undefined][] = [[shape.id, shape]]
// If the shape is a child of a different shape, update its parent
if (shape.parentId !== data.page.id) {
const parent = page.shapes[shape.parentId]
results.push([parent.id, { children: [...parent.children!, shape.id] }])
}
return results
})
),
},
},
}
return {
before,
after,
}
}
static deleteShapes(
data: Data,
shapes: TLDrawShape[] | string[]
): { before: DeepPartial<Data>; after: DeepPartial<Data> } {
const page = this.getPage(data)
const shapeIds =
typeof shapes[0] === 'string'
? (shapes as string[])
: (shapes as TLDrawShape[]).map((shape) => shape.id)
const before: DeepPartial<Data> = {
page: {
shapes: {
// These are the shapes that we're going to delete
...Object.fromEntries(
shapeIds.flatMap((id) => {
const shape = page.shapes[id]
const results: [string, Partial<TLDrawShape> | undefined][] = [[shape.id, shape]]
// If the shape is a child of another shape, also add that shape
if (shape.parentId !== data.page.id) {
const parent = page.shapes[shape.parentId]
results.push([parent.id, { children: parent.children! }])
}
return results
})
),
},
bindings: {
// These are the bindings that we're going to delete
...Object.fromEntries(
Object.values(page.bindings)
.filter((binding) => {
return shapeIds.includes(binding.fromId) || shapeIds.includes(binding.toId)
})
.map((binding) => {
return [binding.id, binding]
})
),
},
},
}
const after: DeepPartial<Data> = {
page: {
shapes: {
...Object.fromEntries(
shapeIds.flatMap((id) => {
const shape = page.shapes[id]
const results: [string, Partial<TLDrawShape> | undefined][] = [[shape.id, undefined]]
// If the shape is a child of a different shape, update its parent
if (shape.parentId !== data.page.id) {
const parent = page.shapes[shape.parentId]
this.mutate(data, parent, {
children: parent.children!.includes(shape.id)
? parent.children
: [...parent.children!, shape.id],
})
results.push([
parent.id,
{ children: parent.children!.filter((id) => id !== shape.id) },
])
}
return results
})
),
},
},
}
return {
before,
after,
}
}
static onSessionComplete<T extends TLDrawShape>(data: Data, shape: T) {
@ -515,7 +616,9 @@ export class TLDR {
return currentStyle
}
const shapeStyles = data.pageState.selectedIds.map((id) => page.shapes[id].style)
const shapeStyles = data.pageState.selectedIds.map((id) => {
return page.shapes[id].style
})
const commonStyle: ShapeStyles = {} as ShapeStyles
@ -552,6 +655,13 @@ export class TLDR {
return Object.values(page.bindings)
}
static getBindableShapeIds(data: Data) {
return Object.values(data.page.shapes)
.filter((shape) => TLDR.getShapeUtils(shape).canBind)
.sort((a, b) => b.childIndex - a.childIndex)
.map((shape) => shape.id)
}
static getBindingsWithShapeIds(data: Data, ids: string[]): TLBinding[] {
return Array.from(
new Set(
@ -567,13 +677,11 @@ export class TLDR {
bindings.forEach((binding) => (page.bindings[binding.id] = binding))
}
static deleteBindings(data: Data, ids: string[]): void {
if (ids.length === 0) return
const page = this.getPage(data)
ids.forEach((id) => delete page.bindings[id])
}
// static deleteBindings(data: Data, ids: string[]): void {
// if (ids.length === 0) return
// const page = this.getPage(data)
// ids.forEach((id) => delete page.bindings[id])
// }
/* -------------------------------------------------- */
/* Assertions */

View file

@ -1,3 +1,4 @@
import { ArrowSession } from './session/sessions/arrow/arrow.session'
import type { TextShape } from './../shape/shape-types'
import { FlipType } from './../types'
import createReact, { PartialState } from 'zustand'
@ -111,14 +112,81 @@ export class TLDrawState implements TLCallbacks {
let next = { ...current, ...result }
if ('page' in result) {
if (result.page) {
const shapes = { ...next.page.shapes }
for (let id in shapes) {
if (!shapes[id]) delete shapes[id]
}
const bindings = { ...next.page.bindings }
for (let id in bindings) {
if (!bindings[id]) delete bindings[id]
}
const changedShapeIds = new Set(
Object.values(shapes)
.filter((shape) => current.page.shapes[shape.id] !== shape)
.map((shape) => shape.id)
)
// Find all shapes that we need to update due to bindings
const bindingsArr = Object.values(bindings)
const bindingsToUpdate = new Set(
bindingsArr.filter(
(binding) => changedShapeIds.has(binding.toId) || changedShapeIds.has(binding.fromId)
)
)
let prevSize = bindingsToUpdate.size
while (true) {
bindingsToUpdate.forEach((binding) => {
const fromId = binding.fromId
for (const otherBinding of bindingsArr) {
if (otherBinding.fromId === fromId) {
bindingsToUpdate.add(otherBinding)
}
if (otherBinding.toId === fromId) {
bindingsToUpdate.add(otherBinding)
}
}
})
if (bindingsToUpdate.size === prevSize) break
prevSize = bindingsToUpdate.size
}
bindingsToUpdate.forEach((binding) => {
// Update the binding
const toShape = shapes[binding.toId]
const fromShape = shapes[binding.fromId]
const toUtils = TLDR.getShapeUtils(toShape)
const fromDelta = TLDR.getShapeUtils(fromShape).onBindingChange(
fromShape,
binding,
toShape,
toUtils.getBounds(toShape),
toUtils.getCenter(toShape)
)
if (fromDelta) {
shapes[fromShape.id] = {
...fromShape,
...fromDelta,
} as TLDrawShape
}
})
next.page = {
...next.page,
shapes: Object.fromEntries(
Object.entries(next.page.shapes).filter(([_, shape]) => {
return shape && (shape.parentId === next.page.id || next.page.shapes[shape.parentId])
})
),
shapes,
bindings,
}
}
@ -126,13 +194,10 @@ export class TLDrawState implements TLCallbacks {
const newSelectedStyle = TLDR.getSelectedStyle(next as Data)
if (newSelectedStyle) {
next = {
...next,
appState: {
next.appState = {
...current.appState,
...next.appState,
selectedStyle: newSelectedStyle,
},
}
}
@ -212,6 +277,18 @@ export class TLDrawState implements TLCallbacks {
...data.appState,
...initialData.settings,
},
page: {
...data.page,
shapes: {},
bindings: {},
},
pageState: {
...data.pageState,
editingId: undefined,
bindingId: undefined,
hoveredId: undefined,
selectedIds: [],
},
}))
this._onChange?.(this, `reset`)
return this
@ -519,7 +596,13 @@ export class TLDrawState implements TLCallbacks {
history.pointer = history.stack.length - 1
this.setState((data) => Utils.deepMerge<Data>(data, history.stack[history.pointer].after))
this.setState((data) =>
Object.fromEntries(
Object.entries(command.after).map(([key, partial]) => {
return [key, Utils.deepMerge(data[key as keyof Data], partial)]
})
)
)
this._onChange?.(this, `command:${command.id}`)
@ -533,7 +616,13 @@ export class TLDrawState implements TLCallbacks {
const command = history.stack[history.pointer]
this.setState((data) => Utils.deepMerge<Data>(data, command.before))
this.setState((data) =>
Object.fromEntries(
Object.entries(command.before).map(([key, partial]) => {
return [key, Utils.deepMerge(data[key as keyof Data], partial)]
})
)
)
history.pointer--
@ -551,8 +640,13 @@ export class TLDrawState implements TLCallbacks {
const command = history.stack[history.pointer]
this.setState((data) => Utils.deepMerge<Data>(data, command.after))
this.setState((data) =>
Object.fromEntries(
Object.entries(command.after).map(([key, partial]) => {
return [key, Utils.deepMerge(data[key as keyof Data], partial)]
})
)
)
this._onChange?.(this, `redo:${command.id}`)
return this
@ -956,14 +1050,21 @@ export class TLDrawState implements TLCallbacks {
}
startHandleSession = (point: number[], handleId: string, commandId?: string) => {
const selectedShape = this.page.shapes[this.selectedIds[0]]
if (selectedShape.type === TLDrawShapeType.Arrow) {
this.startSession<ArrowSession>(
new ArrowSession(this.store.getState(), handleId as 'start' | 'end', point)
)
} else {
this.startSession<HandleSession>(
new HandleSession(this.store.getState(), handleId, point, commandId)
)
}
return this
}
updateHandleSession = (point: number[], shiftKey = false, altKey = false, metaKey = false) => {
this.updateSession<HandleSession>(point, shiftKey, altKey, metaKey)
this.updateSession<HandleSession | ArrowSession>(point, shiftKey, altKey, metaKey)
return this
}
@ -1003,7 +1104,12 @@ export class TLDrawState implements TLCallbacks {
break
}
case 'translatingHandle': {
this.updateHandleSession(this.getPagePoint(info.point), info.shiftKey, info.altKey)
this.updateHandleSession(
this.getPagePoint(info.point),
info.shiftKey,
info.altKey,
info.metaKey
)
break
}
case 'creating': {
@ -1017,7 +1123,12 @@ export class TLDrawState implements TLCallbacks {
break
}
case 'handle': {
this.updateHandleSession(this.getPagePoint(info.point), info.shiftKey, info.altKey)
this.updateHandleSession(
this.getPagePoint(info.point),
info.shiftKey,
info.altKey,
info.metaKey
)
break
}
case 'point': {