[improvement] arrows binding logic (#542)

* Improve arrows binding logic

* Update ArrowSession.ts

* more arrow improvements

* major arrow cleanup / refactor

* point toward anchor rather than center
This commit is contained in:
Steve Ruiz 2022-01-30 21:13:57 +00:00 committed by GitHub
parent 0ff6f0628f
commit 03ff422680
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 63506 additions and 12267 deletions

View file

@ -1,4 +1,5 @@
import { TLBounds, TLTransformInfo, Utils, TLPageState } from '@tldraw/core' /* eslint-disable @typescript-eslint/no-non-null-assertion */
import { TLBounds, TLTransformInfo, Utils, TLPageState, TLHandle } from '@tldraw/core'
import { import {
TDSnapshot, TDSnapshot,
ShapeStyles, ShapeStyles,
@ -10,10 +11,16 @@ import {
TldrawPatch, TldrawPatch,
TDShapeType, TDShapeType,
ArrowShape, ArrowShape,
TDHandle,
} from '~types' } from '~types'
import { Vec } from '@tldraw/vec' import { Vec } from '@tldraw/vec'
import type { TDShapeUtil } from './shapes/TDShapeUtil' import type { TDShapeUtil } from './shapes/TDShapeUtil'
import { getShapeUtil } from './shapes' import { getShapeUtil } from './shapes'
import type { TldrawApp } from './TldrawApp'
import { deepCopy } from './StateManager/copy'
import { intersectRayBounds, intersectRayEllipse, intersectRayLineSegment } from '@tldraw/intersect'
import { getTrianglePoints } from './shapes/TriangleUtil/triangleHelpers'
import { BINDING_DISTANCE } from '~constants'
const isDev = process.env.NODE_ENV === 'development' const isDev = process.env.NODE_ENV === 'development'
export class TLDR { export class TLDR {
@ -247,10 +254,16 @@ export class TLDR {
return Object.values(page.bindings) return Object.values(page.bindings)
.filter((binding) => binding.fromId === id || binding.toId === id) .filter((binding) => binding.fromId === id || binding.toId === id)
.reduce((cTDSnapshot, binding) => { .reduce((cTDSnapshot, binding) => {
let oppositeShape: TDShape | undefined = undefined
if (!beforeShapes[binding.fromId]) { if (!beforeShapes[binding.fromId]) {
beforeShapes[binding.fromId] = Utils.deepClone( const arrowShape = TLDR.getShape<ArrowShape>(cTDSnapshot, binding.fromId, pageId)
TLDR.getShape(cTDSnapshot, binding.fromId, pageId) beforeShapes[binding.fromId] = Utils.deepClone(arrowShape)
) const oppositeHandle = arrowShape.handles[binding.handleId === 'start' ? 'end' : 'start']
if (oppositeHandle.bindingId) {
const oppositeBinding = page.bindings[oppositeHandle.bindingId]
oppositeShape = TLDR.getShape(data, oppositeBinding.toId, data.appState.currentPageId)
}
} }
if (!beforeShapes[binding.toId]) { if (!beforeShapes[binding.toId]) {
@ -259,11 +272,12 @@ export class TLDR {
) )
} }
TLDR.onBindingChange( // TLDR.onBindingChange(
TLDR.getShape(cTDSnapshot, binding.fromId, pageId), // TLDR.getShape(cTDSnapshot, binding.fromId, pageId),
binding, // binding,
TLDR.getShape(cTDSnapshot, binding.toId, pageId) // TLDR.getShape(cTDSnapshot, binding.toId, pageId),
) // oppositeShape
// )
afterShapes[binding.fromId] = Utils.deepClone( afterShapes[binding.fromId] = Utils.deepClone(
TLDR.getShape(cTDSnapshot, binding.fromId, pageId) TLDR.getShape(cTDSnapshot, binding.fromId, pageId)
@ -628,18 +642,229 @@ export class TLDR {
return { ...shape, ...delta } return { ...shape, ...delta }
} }
static onBindingChange<T extends TDShape>(shape: T, binding: TDBinding, otherShape: TDShape) { static updateArrowBindings(page: TDPage, arrowShape: ArrowShape) {
const delta = TLDR.getShapeUtil(shape).onBindingChange?.( const result = {
shape, start: deepCopy(arrowShape.handles.start),
binding, end: deepCopy(arrowShape.handles.end),
otherShape, }
TLDR.getShapeUtil(otherShape).getBounds(otherShape), type HandleInfo = {
TLDR.getShapeUtil(otherShape).getExpandedBounds(otherShape), handle: TDHandle
TLDR.getShapeUtil(otherShape).getCenter(otherShape) point: number[] // in page space
} & (
| {
isBound: false
}
| {
isBound: true
hasDecoration: boolean
binding: TDBinding
util: TDShapeUtil<TDShape, any>
target: TDShape
bounds: TLBounds
expandedBounds: TLBounds
intersectBounds: TLBounds
center: number[]
}
) )
if (!delta) return shape let start: HandleInfo = {
isBound: false,
handle: arrowShape.handles.start,
point: Vec.add(arrowShape.handles.start.point, arrowShape.point),
}
let end: HandleInfo = {
isBound: false,
handle: arrowShape.handles.end,
point: Vec.add(arrowShape.handles.end.point, arrowShape.point),
}
if (arrowShape.handles.start.bindingId) {
const hasDecoration = arrowShape.decorations?.start !== undefined
const handle = arrowShape.handles.start
const binding = page.bindings[arrowShape.handles.start.bindingId]
if (!binding) throw Error("Could not find a binding to match the start handle's bindingId")
const target = page.shapes[binding.toId]
const util = TLDR.getShapeUtil(target)
const bounds = util.getBounds(target)
const expandedBounds = util.getExpandedBounds(target)
const intersectBounds = hasDecoration ? Utils.expandBounds(bounds, binding.distance) : bounds
const { minX, minY, width, height } = expandedBounds
const anchorPoint = Vec.add(
[minX, minY],
Vec.mulV([width, height], Vec.rotWith(binding.point, [0.5, 0.5], target.rotation || 0))
)
start = {
isBound: true,
hasDecoration,
binding,
handle,
point: anchorPoint,
util,
target,
bounds,
expandedBounds,
intersectBounds,
center: util.getCenter(target),
}
}
if (arrowShape.handles.end.bindingId) {
const hasDecoration = arrowShape.decorations?.end !== undefined
const handle = arrowShape.handles.end
const binding = page.bindings[arrowShape.handles.end.bindingId]
if (!binding) throw Error("Could not find a binding to match the end handle's bindingId")
const target = page.shapes[binding.toId]
const util = TLDR.getShapeUtil(target)
const bounds = util.getBounds(target)
const expandedBounds = util.getExpandedBounds(target)
const intersectBounds = hasDecoration ? Utils.expandBounds(bounds, binding.distance) : bounds
const { minX, minY, width, height } = expandedBounds
const anchorPoint = Vec.add(
[minX, minY],
Vec.mulV([width, height], Vec.rotWith(binding.point, [0.5, 0.5], target.rotation || 0))
)
end = {
isBound: true,
hasDecoration,
binding,
handle,
point: anchorPoint,
util,
target,
bounds,
expandedBounds,
intersectBounds,
center: util.getCenter(target),
}
}
return { ...shape, ...delta } for (const ID of ['end', 'start'] as const) {
const A = ID === 'start' ? start : end
const B = ID === 'start' ? end : start
if (A.isBound) {
if (!A.binding.distance) {
// If the binding distance is zero, then the arrow is bound to a specific point
// in the target shape. The resulting handle should be exactly at that point.
result[ID].point = Vec.sub(A.point, arrowShape.point)
} else {
// We'll need to figure out the handle's true point based on some intersections
// between the opposite handle point and this handle point. This is different
// for each type of shape.
const direction = Vec.uni(Vec.sub(A.point, B.point))
switch (A.target.type) {
case TDShapeType.Ellipse: {
const hits = intersectRayEllipse(
B.point,
direction,
A.center,
A.target.radius[0] + (A.hasDecoration ? A.binding.distance : 0),
A.target.radius[1] + (A.hasDecoration ? A.binding.distance : 0),
A.target.rotation || 0
).points.sort((a, b) => Vec.dist(a, B.point) - Vec.dist(b, B.point))
if (hits[0] !== undefined) {
result[ID].point = Vec.toFixed(Vec.sub(hits[0], arrowShape.point))
}
break
}
case TDShapeType.Triangle: {
const targetPoint = A.target.point
const points = getTrianglePoints(
A.target.size,
A.hasDecoration ? BINDING_DISTANCE : 0,
A.target.rotation
).map((pt) => Vec.add(pt, targetPoint))
const hits = Utils.pointsToLineSegments(points, true)
.map(([p0, p1]) => intersectRayLineSegment(B.point, direction, p0, p1))
.filter((intersection) => intersection.didIntersect)
.flatMap((intersection) => intersection.points)
.sort((a, b) => Vec.dist(a, B.point) - Vec.dist(b, B.point))
if (hits[0] !== undefined) {
result[ID].point = Vec.toFixed(Vec.sub(hits[0], arrowShape.point))
}
break
}
default: {
const hits = intersectRayBounds(
B.point,
direction,
A.intersectBounds,
A.target.rotation
)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(a, B.point) - Vec.dist(b, B.point))
let bHit: number[] | undefined = undefined
if (B.isBound) {
const bHits = intersectRayBounds(
B.point,
direction,
B.intersectBounds,
B.target.rotation
)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(a, B.point) - Vec.dist(b, B.point))
bHit = bHits[0]
}
if (
B.isBound &&
(hits.length < 2 ||
(bHit &&
hits[0] &&
Math.ceil(Vec.dist(hits[0], bHit)) < BINDING_DISTANCE * 2.5) ||
Utils.boundsContain(A.expandedBounds, B.expandedBounds) ||
Utils.boundsCollide(A.expandedBounds, B.expandedBounds))
) {
// If the other handle is bound, and if we need to fallback to the short arrow method...
const shortArrowDirection = Vec.uni(Vec.sub(B.point, A.point))
const shortArrowHits = intersectRayBounds(
A.point,
shortArrowDirection,
A.bounds,
A.target.rotation
)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
result[ID].point = Vec.toFixed(Vec.sub(shortArrowHits[0], arrowShape.point))
result[ID === 'start' ? 'end' : 'start'].point = Vec.toFixed(
Vec.add(
Vec.sub(shortArrowHits[0], arrowShape.point),
Vec.mul(
shortArrowDirection,
Math.min(
Vec.dist(shortArrowHits[0], B.point),
BINDING_DISTANCE *
2.5 *
(Utils.boundsContain(B.bounds, A.intersectBounds) ? -1 : 1)
)
)
)
)
} else if (
!B.isBound &&
((hits[0] && Vec.dist(hits[0], B.point) < BINDING_DISTANCE * 2.5) ||
Utils.pointInBounds(B.point, A.intersectBounds))
) {
// Short arrow time!
const shortArrowDirection = Vec.uni(Vec.sub(A.center, B.point))
return TLDR.getShapeUtil<ArrowShape>(arrowShape).onHandleChange?.(arrowShape, {
[ID]: {
...arrowShape.handles[ID],
point: Vec.toFixed(
Vec.add(
Vec.sub(B.point, arrowShape.point),
Vec.mul(shortArrowDirection, BINDING_DISTANCE * 2.5)
)
),
},
})
} else if (hits[0]) {
result[ID].point = Vec.toFixed(Vec.sub(hits[0], arrowShape.point))
}
}
}
}
}
}
return TLDR.getShapeUtil<ArrowShape>(arrowShape).onHandleChange?.(arrowShape, result)
} }
static transform<T extends TDShape>(shape: T, bounds: TLBounds, info: TLTransformInfo<T>) { static transform<T extends TDShape>(shape: T, bounds: TLBounds, info: TLTransformInfo<T>) {
@ -690,8 +915,7 @@ export class TLDR {
const point = Vec.toFixed(Vec.rotWith(handle.point, relativeCenter, delta)) const point = Vec.toFixed(Vec.rotWith(handle.point, relativeCenter, delta))
return [handleId, { ...handle, point }] return [handleId, { ...handle, point }]
}) })
) as T['handles'], ) as T['handles']
{ shiftKey: false }
) )
return change return change

View file

@ -349,8 +349,10 @@ describe('TldrawTestApp', () => {
.startSession(SessionType.Arrow, 'arrow', 'start') .startSession(SessionType.Arrow, 'arrow', 'start')
.movePointer([10, 10]) .movePointer([10, 10])
.completeSession() .completeSession()
.selectAll()
.style({ color: ColorStyle.Red }) expect(app.bindings.length).toBe(1)
app.selectAll().style({ color: ColorStyle.Red })
expect(app.getShape('arrow').style.color).toBe(ColorStyle.Red) expect(app.getShape('arrow').style.color).toBe(ColorStyle.Red)
expect(app.getShape('rect').style.color).toBe(ColorStyle.Red) expect(app.getShape('rect').style.color).toBe(ColorStyle.Red)

View file

@ -41,6 +41,7 @@ import {
TDAssets, TDAssets,
TDExport, TDExport,
ImageShape, ImageShape,
ArrowShape,
} from '~types' } from '~types'
import { import {
migrate, migrate,
@ -78,6 +79,7 @@ import { LineTool } from './tools/LineTool'
import { ArrowTool } from './tools/ArrowTool' import { ArrowTool } from './tools/ArrowTool'
import { StickyTool } from './tools/StickyTool' import { StickyTool } from './tools/StickyTool'
import { StateManager } from './StateManager' import { StateManager } from './StateManager'
import { deepCopy } from './StateManager/copy'
const uuid = Utils.uniqueId() const uuid = Utils.uniqueId()
@ -368,25 +370,15 @@ export class TldrawApp extends StateManager<TDSnapshot> {
} }
const toShape = page.shapes[binding.toId] const toShape = page.shapes[binding.toId]
const fromShape = page.shapes[binding.fromId] const fromShape = page.shapes[binding.fromId] as ArrowShape
if (!(toShape && fromShape)) { if (!(toShape && fromShape)) {
delete next.document.pages[pageId].bindings[binding.id] delete next.document.pages[pageId].bindings[binding.id]
return return
} }
const toUtils = TLDR.getShapeUtil(toShape) // We only need to update the binding's "from" shape (an arrow)
const fromUtils = TLDR.getShapeUtil(fromShape) const fromDelta = TLDR.updateArrowBindings(page, fromShape)
// We only need to update the binding's "from" shape
const fromDelta = fromUtils.onBindingChange?.(
fromShape,
binding,
toShape,
toUtils.getBounds(toShape),
toUtils.getExpandedBounds(toShape),
toUtils.getCenter(toShape)
)
if (fromDelta) { if (fromDelta) {
const nextShape = { const nextShape = {
@ -795,22 +787,10 @@ export class TldrawApp extends StateManager<TDSnapshot> {
return return
} }
const toShape = page.shapes[binding.toId] const fromShape = page.shapes[binding.fromId] as ArrowShape
const fromShape = page.shapes[binding.fromId]
const toUtils = TLDR.getShapeUtil(toShape) // We only need to update the binding's "from" shape (an arrow)
const fromDelta = TLDR.updateArrowBindings(page, fromShape)
const fromUtils = TLDR.getShapeUtil(fromShape)
// We only need to update the binding's "from" shape
const fromDelta = fromUtils.onBindingChange?.(
fromShape,
binding,
toShape,
toUtils.getBounds(toShape),
toUtils.getExpandedBounds(toShape),
toUtils.getCenter(toShape)
)
if (fromDelta) { if (fromDelta) {
const nextShape = { const nextShape = {

View file

@ -134,23 +134,39 @@ describe('Arrow session', () => {
describe('when dragging a bound shape', () => { describe('when dragging a bound shape', () => {
it('updates the arrow', () => { it('updates the arrow', () => {
const app = new TldrawTestApp() const app = new TldrawTestApp()
.reset()
app .createShapes(
.loadDocument(restoreDoc) { type: TDShapeType.Rectangle, id: 'target1', point: [0, 0], size: [100, 100] },
// Select the arrow and begin a session on the handle's start handle { type: TDShapeType.Arrow, id: 'arrow1', point: [200, 200] }
.movePointer([200, 200]) )
.select('arrow1') expect(app.bindings.length).toBe(0)
.startSession(SessionType.Arrow, 'arrow1', 'start') expect(app.getShape<ArrowShape>('arrow1').point).toStrictEqual([200, 200])
// Move to [50,50]
app.movePointer([50, 50])
// Both handles will keep the same screen positions, but their points will have changed.
expect(app.getShape<ArrowShape>('arrow1').point).toStrictEqual([116, 116])
expect(app.getShape<ArrowShape>('arrow1').handles.start.point).toStrictEqual([0, 0]) expect(app.getShape<ArrowShape>('arrow1').handles.start.point).toStrictEqual([0, 0])
expect(app.getShape<ArrowShape>('arrow1').handles.end.point).toStrictEqual([85, 85]) expect(app.getShape<ArrowShape>('arrow1').handles.end.point).toStrictEqual([1, 1])
// Select the arrow at [200,200] and begin a session on the handle's start handle
app.movePointer([200, 200])
app.startSession(SessionType.Arrow, 'arrow1', 'start')
// expect(app.getShape<ArrowShape>('arrow1').point).toStrictEqual([200, 200])
// expect(app.bindings.length).toBe(0)
// Move the pointer to update the session...
// app.movePointer([190, 190])
// expect(app.getShape<ArrowShape>('arrow1').point).toStrictEqual([190, 190])
// expect(app.bindings.length).toBe(0)
// Move the pointer over another shape to create a binding...
app.movePointer([50, 50])
expect(app.getShape<ArrowShape>('arrow1').point).toStrictEqual([100, 100])
expect(app.bindings.length).toBe(1)
const shape = app.getShape<ArrowShape>('arrow1')
expect(shape.handles.start.bindingId).toBe(app.bindings[0].id)
// Both handles will keep the same screen positions, but their points will have changed.
expect(app.getShape<ArrowShape>('arrow1').handles.start.point).toStrictEqual([0, 0])
expect(app.getShape<ArrowShape>('arrow1').handles.end.point).toStrictEqual([101, 101])
expect(app.getShape<ArrowShape>('arrow1').point).toStrictEqual([100, 100])
}) })
it.todo('updates the arrow when bound on both sides') it.todo('updates the arrow when bound on both sides')
it.todo('snaps the bend to zero when dragging the bend handle toward the center') it.todo('snaps the bend to zero when dragging the bend handle toward the center')
}) })
}) })
@ -170,15 +186,26 @@ describe('When creating with an arrow session', () => {
}) })
it("Doesn't corrupt a shape after undoing", () => { it("Doesn't corrupt a shape after undoing", () => {
const app = new TldrawTestApp() const app = new TldrawTestApp().reset()
expect(app.bindings.length).toBe(0)
app
.createShapes( .createShapes(
{ type: TDShapeType.Rectangle, id: 'rect1', point: [200, 200], size: [200, 200] }, { type: TDShapeType.Rectangle, id: 'rect1', point: [200, 200], size: [100, 100] },
{ type: TDShapeType.Rectangle, id: 'rect2', point: [400, 200], size: [200, 200] } { type: TDShapeType.Rectangle, id: 'rect2', point: [400, 400], size: [100, 100] }
) )
.selectTool(TDShapeType.Arrow) .selectTool(TDShapeType.Arrow)
.pointShape('rect1', { x: 250, y: 250 }) .pointShape('rect1', { x: 210, y: 210 })
.movePointer([450, 250]) app.movePointer([350, 200])
.stopPointing()
expect(app.bindings.length).toBe(1) // Start
app.movePointer([450, 450])
expect(app.bindings.length).toBe(2) // Start and end
app.stopPointing()
expect(app.bindings.length).toBe(2) expect(app.bindings.length).toBe(2)
@ -186,58 +213,93 @@ describe('When creating with an arrow session', () => {
expect(app.bindings.length).toBe(0) expect(app.bindings.length).toBe(0)
app.select('rect1').pointShape('rect1', [250, 250]).movePointer([275, 275]).completeSession() app.select('rect1').pointShape('rect1', [210, 210]).movePointer([275, 275]).completeSession()
expect(app.bindings.length).toBe(0) expect(app.bindings.length).toBe(0)
}) })
it('Creates a start binding if possible', () => { it('Creates a start binding if possible', () => {
const app = new TldrawTestApp() const app = new TldrawTestApp()
.selectAll()
.delete()
.createShapes( .createShapes(
{ type: TDShapeType.Rectangle, id: 'rect1', point: [200, 200], size: [200, 200] }, { type: TDShapeType.Rectangle, id: 'rect1', point: [200, 200], size: [100, 100] },
{ type: TDShapeType.Rectangle, id: 'rect2', point: [400, 200], size: [200, 200] } { type: TDShapeType.Rectangle, id: 'rect2', point: [400, 400], size: [100, 100] }
) )
.selectTool(TDShapeType.Arrow) .selectTool(TDShapeType.Arrow)
.pointShape('rect1', { x: 250, y: 250 }) .pointShape('rect1', { x: 251, y: 251 })
.movePointer([450, 250]) .movePointer([350, 350])
.movePointer([450, 450])
.completeSession() .completeSession()
const arrow = app.shapes.find((shape) => shape.type === TDShapeType.Arrow) as ArrowShape const arrow = app.shapes.find((shape) => shape.type === TDShapeType.Arrow) as ArrowShape
expect(arrow).toBeTruthy() expect(arrow).toBeTruthy()
expect(app.bindings.length).toBe(2)
expect(arrow.handles.start.bindingId).not.toBe(undefined) expect(arrow.handles.start.bindingId).not.toBe(undefined)
expect(arrow.handles.end.bindingId).not.toBe(undefined) expect(arrow.handles.end.bindingId).not.toBe(undefined)
expect(app.bindings.length).toBe(2)
}) })
it('Removes a binding when dragged away', () => { it('Creates a start binding if started in dead center', () => {
const app = new TldrawTestApp() const app = new TldrawTestApp()
.selectAll() .selectAll()
.delete() .delete()
.createShapes( .createShapes(
{ type: TDShapeType.Rectangle, id: 'rect1', point: [200, 200], size: [200, 200] }, { type: TDShapeType.Rectangle, id: 'rect1', point: [200, 200], size: [100, 100] },
{ type: TDShapeType.Rectangle, id: 'rect2', point: [400, 200], size: [200, 200] }, { type: TDShapeType.Rectangle, id: 'rect2', point: [400, 400], size: [100, 100] }
{ type: TDShapeType.Arrow, id: 'arrow1', point: [250, 250] }
) )
.select('arrow1') .selectTool(TDShapeType.Arrow)
.movePointer([250, 250]) .pointShape('rect1', { x: 250, y: 250 })
.startSession(SessionType.Arrow, 'arrow1', 'end', true) .movePointer([350, 350])
.movePointer([450, 250]) .movePointer([450, 450])
.completeSession()
.select('arrow1')
.startSession(SessionType.Arrow, 'arrow1', 'start', false)
.movePointer([0, 0])
.completeSession() .completeSession()
const arrow = app.shapes.find((shape) => shape.type === TDShapeType.Arrow) as ArrowShape const arrow = app.shapes.find((shape) => shape.type === TDShapeType.Arrow) as ArrowShape
expect(arrow).toBeTruthy() expect(arrow).toBeTruthy()
expect(arrow.handles.start.bindingId).not.toBe(undefined)
expect(app.bindings.length).toBe(1)
expect(arrow.handles.start.point).toStrictEqual([0, 0])
expect(arrow.handles.start.bindingId).toBe(undefined)
expect(arrow.handles.end.bindingId).not.toBe(undefined) expect(arrow.handles.end.bindingId).not.toBe(undefined)
expect(app.bindings.length).toBe(2)
})
it('Removes a binding when dragged away', () => {
const app = new TldrawTestApp()
.reset()
.createShapes(
{ type: TDShapeType.Rectangle, id: 'rect1', point: [0, 0], size: [100, 100] },
{ type: TDShapeType.Arrow, id: 'arrow1', point: [200, 200] }
)
expect(app.bindings.length).toBe(0)
expect(app.getShape('arrow1').handles?.end.bindingId).toBeUndefined()
// Select the arrow and create a binding from its end handle to rect1
app
.movePointer([201, 201])
.startSession(SessionType.Arrow, 'arrow1', 'end', false)
.movePointer([50, 50])
.completeSession()
// Expect a binding to exist on the shape's end handle
expect(app.bindings.length).toBe(1)
let arrow = app.getShape<ArrowShape>('arrow1')
expect(arrow.handles?.end.bindingId).toBeDefined()
expect(arrow.point).toStrictEqual([116, 116])
expect(arrow.handles.start.point).toStrictEqual([84, 84])
expect(arrow.handles.end.point).toStrictEqual([0, 0])
// Drag the shape away by [10,10]
app.movePointer([50, 50]).pointShape('arrow1', [50, 50]).movePointer([60, 60]).stopPointing()
arrow = app.getShape<ArrowShape>('arrow1')
// The shape should have moved
expect(arrow.point).toStrictEqual([126, 126])
// The handles should be in the same place
expect(arrow.handles.start.point).toStrictEqual([84, 84])
expect(arrow.handles.end.point).toStrictEqual([0, 0])
// The bindings should have been removed
expect(app.bindings.length).toBe(0)
expect(arrow.handles.start.bindingId).toBe(undefined)
expect(arrow.handles.end.bindingId).toBe(undefined)
}) })
}) })
@ -296,6 +358,7 @@ describe('When drawing an arrow', () => {
it('create a short arrow if start handle is bound', () => { it('create a short arrow if start handle is bound', () => {
const app = new TldrawTestApp() const app = new TldrawTestApp()
.reset()
.createShapes({ .createShapes({
type: TDShapeType.Rectangle, type: TDShapeType.Rectangle,
id: 'rect1', id: 'rect1',
@ -304,9 +367,31 @@ describe('When drawing an arrow', () => {
}) })
.selectTool(TDShapeType.Arrow) .selectTool(TDShapeType.Arrow)
.pointCanvas([101, 100]) // Inside of shape .pointCanvas([101, 100]) // Inside of shape
.movePointer([100, 100]) .movePointer([50, 100])
.stopPointing() .stopPointing()
expect(app.shapes.length).toBe(2) expect(app.shapes.length).toBe(2)
}) })
}) })
describe('When creating arrows inside of other shapes...', () => {
it('does not bind an arrow to shapes that contain the whole arrow', () => {
const app = new TldrawTestApp()
.reset()
.selectTool(TDShapeType.Arrow)
.createShapes({
id: 'rect1',
type: TDShapeType.Rectangle,
point: [0, 0],
size: [200, 200],
})
.pointCanvas([50, 50])
.movePointer([150, 150])
.stopPointing()
const arrow = app.shapes[1] as ArrowShape
expect(arrow.type).toBe(TDShapeType.Arrow)
expect(app.bindings.length).toBe(0)
expect(app.shapes.length).toBe(2)
})
})

View file

@ -1,3 +1,4 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import { import {
ArrowBinding, ArrowBinding,
ArrowShape, ArrowShape,
@ -15,6 +16,7 @@ import { shapeUtils } from '~state/shapes'
import { BaseSession } from '../BaseSession' import { BaseSession } from '../BaseSession'
import type { TldrawApp } from '../../internal' import type { TldrawApp } from '../../internal'
import { Utils } from '@tldraw/core' import { Utils } from '@tldraw/core'
import { deepCopy } from '~state/StateManager/copy'
export class ArrowSession extends BaseSession { export class ArrowSession extends BaseSession {
type = SessionType.Arrow type = SessionType.Arrow
@ -36,14 +38,15 @@ export class ArrowSession extends BaseSession {
const { currentPageId } = app.state.appState const { currentPageId } = app.state.appState
const page = app.state.document.pages[currentPageId] const page = app.state.document.pages[currentPageId]
this.handleId = handleId this.handleId = handleId
this.initialShape = page.shapes[shapeId] as ArrowShape this.initialShape = deepCopy(page.shapes[shapeId] as ArrowShape)
this.bindableShapeIds = TLDR.getBindableShapeIds(app.state).filter( this.bindableShapeIds = TLDR.getBindableShapeIds(app.state).filter(
(id) => !(id === this.initialShape.id || id === this.initialShape.parentId) (id) => !(id === this.initialShape.id || id === this.initialShape.parentId)
) )
// TODO: find out why this the oppositeHandleBindingId is sometimes missing
const oppositeHandleBindingId = const oppositeHandleBindingId =
this.initialShape.handles[handleId === 'start' ? 'end' : 'start']?.bindingId this.initialShape.handles[handleId === 'start' ? 'end' : 'start']?.bindingId
if (oppositeHandleBindingId) { if (oppositeHandleBindingId) {
// TODO: find out why this the binding here is sometimes missing
const oppositeToId = page.bindings[oppositeHandleBindingId]?.toId const oppositeToId = page.bindings[oppositeHandleBindingId]?.toId
if (oppositeToId) { if (oppositeToId) {
this.bindableShapeIds = this.bindableShapeIds.filter((id) => id !== oppositeToId) this.bindableShapeIds = this.bindableShapeIds.filter((id) => id !== oppositeToId)
@ -56,9 +59,13 @@ export class ArrowSession extends BaseSession {
// bindable shape under the pointer. // bindable shape under the pointer.
this.startBindingShapeId = this.bindableShapeIds this.startBindingShapeId = this.bindableShapeIds
.map((id) => page.shapes[id]) .map((id) => page.shapes[id])
.find((shape) => .filter((shape) =>
Utils.pointInBounds(originPoint, TLDR.getShapeUtil(shape).getBounds(shape)) Utils.pointInBounds(originPoint, TLDR.getShapeUtil(shape).getBounds(shape))
)?.id )
.sort((a, b) => {
// TODO - We should be smarter here, what's the right logic?
return b.childIndex - a.childIndex
})[0]?.id
if (this.startBindingShapeId) { if (this.startBindingShapeId) {
this.bindableShapeIds.splice(this.bindableShapeIds.indexOf(this.startBindingShapeId), 1) this.bindableShapeIds.splice(this.bindableShapeIds.indexOf(this.startBindingShapeId), 1)
} }
@ -105,45 +112,54 @@ export class ArrowSession extends BaseSession {
delta = Vec.add(delta, Vec.sub(adjusted, C)) delta = Vec.add(delta, Vec.sub(adjusted, C))
} }
const nextPoint = Vec.sub(Vec.add(handles[handleId].point, delta), shape.point) const nextPoint = Vec.sub(Vec.add(handles[handleId].point, delta), shape.point)
const handle = { const draggedHandle = {
...handles[handleId], ...handles[handleId],
point: showGrid ? Vec.snap(nextPoint, currentGrid) : Vec.toFixed(nextPoint), point: showGrid ? Vec.snap(nextPoint, currentGrid) : Vec.toFixed(nextPoint),
bindingId: undefined, bindingId: undefined,
} }
const utils = shapeUtils[TDShapeType.Arrow] const utils = shapeUtils[TDShapeType.Arrow]
const change = utils.onHandleChange?.(shape, { const handleChange = utils.onHandleChange?.(shape, {
[handleId]: handle, [handleId]: draggedHandle,
}) })
// If the handle changed produced no change, bail here // If the handle changed produced no change, bail here
if (!change) return if (!handleChange) return
// If nothing changes, we want these to be the same object reference as // If nothing changes, we want these to be the same object reference as
// before. If it does change, we'll redefine this later on. And if we've // before. If it does change, we'll redefine this later on. And if we've
// made it this far, the shape should be a new object reference that // made it this far, the shape should be a new object reference that
// incorporates the changes we've made due to the handle movement. // incorporates the changes we've made due to the handle movement.
const next: { shape: ArrowShape; bindings: Record<string, TDBinding | undefined> } = { const next: { shape: ArrowShape; bindings: Record<string, TDBinding | undefined> } = {
shape: Utils.deepMerge(shape, change), shape: Utils.deepMerge(shape, handleChange),
bindings: {}, bindings: {},
} }
if (this.initialBinding) { let draggedBinding: ArrowBinding | undefined
next.bindings[this.initialBinding.id] = undefined const draggingHandle = next.shape.handles[this.handleId]
} const oppositeHandle = next.shape.handles[this.handleId === 'start' ? 'end' : 'start']
// START BINDING // START BINDING
// If we have a start binding shape id, the recompute the binding // If we have a start binding shape id, the recompute the binding
// point based on the current end handle position // point based on the current end handle position
if (this.startBindingShapeId) { if (this.startBindingShapeId) {
let startBinding: ArrowBinding | undefined let nextStartBinding: ArrowBinding | undefined
const target = this.app.page.shapes[this.startBindingShapeId] const startTarget = this.app.page.shapes[this.startBindingShapeId]
const targetUtils = TLDR.getShapeUtil(target) const startTargetUtils = TLDR.getShapeUtil(startTarget)
if (!metaKey) { const center = startTargetUtils.getCenter(startTarget)
const center = targetUtils.getCenter(target) const startHandle = next.shape.handles.start
const handle = next.shape.handles.start const endHandle = next.shape.handles.end
const rayPoint = Vec.add(handle.point, next.shape.point) const rayPoint = Vec.add(startHandle.point, next.shape.point)
const rayOrigin = center if (Vec.isEqual(rayPoint, center)) rayPoint[1]++ // Fix bug where ray and center are identical
const rayDirection = Vec.uni(Vec.sub(rayPoint, rayOrigin)) const rayOrigin = center
const isInsideShape = targetUtils.hitTestPoint(target, currentPoint) const isInsideShape = startTargetUtils.hitTestPoint(startTarget, currentPoint)
startBinding = this.findBindingPoint( const rayDirection = Vec.uni(Vec.sub(rayPoint, rayOrigin))
const hasStartBinding = this.app.getBinding(this.newStartBindingId) !== undefined
// Don't bind the start handle if both handles are inside of the target shape.
if (
!metaKey &&
!startTargetUtils.hitTestPoint(startTarget, Vec.add(next.shape.point, endHandle.point))
) {
nextStartBinding = this.findBindingPoint(
shape, shape,
target, startTarget,
'start', 'start',
this.newStartBindingId, this.newStartBindingId,
center, center,
@ -152,52 +168,45 @@ export class ArrowSession extends BaseSession {
isInsideShape isInsideShape
) )
} }
if (startBinding) { if (nextStartBinding && !hasStartBinding) {
// Bind the arrow's start handle to the start target
this.didBind = true this.didBind = true
next.bindings[this.newStartBindingId] = startBinding next.bindings[this.newStartBindingId] = nextStartBinding
next.shape.handles = { next.shape = Utils.deepMerge(next.shape, {
...next.shape.handles, handles: {
start: { start: {
...next.shape.handles.start, bindingId: nextStartBinding.id,
bindingId: startBinding.id, },
}, },
} })
const target = this.app.page.shapes[this.startBindingShapeId] } else if (!nextStartBinding && hasStartBinding) {
const targetUtils = TLDR.getShapeUtil(target) // Remove the start binding
const arrowChange = TLDR.getShapeUtil<ArrowShape>(next.shape.type).onBindingChange?.( this.didBind = false
next.shape, next.bindings[this.newStartBindingId] = undefined
startBinding, next.shape = Utils.deepMerge(initialShape, {
target, handles: {
targetUtils.getBounds(target),
targetUtils.getExpandedBounds(target),
targetUtils.getCenter(target)
)
if (arrowChange) Object.assign(next.shape, arrowChange)
} else {
this.didBind = this.didBind || false
if (this.app.page.bindings[this.newStartBindingId]) {
next.bindings[this.newStartBindingId] = undefined
}
if (shape.handles.start.bindingId === this.newStartBindingId) {
next.shape.handles = {
...next.shape.handles,
start: { start: {
...next.shape.handles.start,
bindingId: undefined, bindingId: undefined,
}, },
} },
} })
} }
} }
// DRAGGED POINT BINDING // DRAGGED POINT BINDING
let draggedBinding: ArrowBinding | undefined
if (!metaKey) { if (!metaKey) {
const handle = next.shape.handles[this.handleId]
const oppositeHandle = next.shape.handles[this.handleId === 'start' ? 'end' : 'start']
const rayOrigin = Vec.add(oppositeHandle.point, next.shape.point) const rayOrigin = Vec.add(oppositeHandle.point, next.shape.point)
const rayPoint = Vec.add(handle.point, next.shape.point) const rayPoint = Vec.add(draggingHandle.point, next.shape.point)
const rayDirection = Vec.uni(Vec.sub(rayPoint, rayOrigin)) const rayDirection = Vec.uni(Vec.sub(rayPoint, rayOrigin))
const targets = this.bindableShapeIds.map((id) => this.app.page.shapes[id]) const startPoint = Vec.add(next.shape.point!, next.shape.handles!.start.point!)
const endPoint = Vec.add(next.shape.point!, next.shape.handles!.end.point!)
const targets = this.bindableShapeIds
.map((id) => this.app.page.shapes[id])
.sort((a, b) => b.childIndex - a.childIndex)
.filter((shape) => {
const utils = TLDR.getShapeUtil(shape)
return ![startPoint, endPoint].every((point) => utils.hitTestPoint(shape, point))
})
for (const target of targets) { for (const target of targets) {
draggedBinding = this.findBindingPoint( draggedBinding = this.findBindingPoint(
shape, shape,
@ -213,52 +222,43 @@ export class ArrowSession extends BaseSession {
} }
} }
if (draggedBinding) { if (draggedBinding) {
// Create the dragged point binding
this.didBind = true this.didBind = true
next.bindings[this.draggedBindingId] = draggedBinding next.bindings[this.draggedBindingId] = draggedBinding
next.shape.handles = { next.shape = Utils.deepMerge(next.shape, {
...next.shape.handles, handles: {
[this.handleId]: { [this.handleId]: {
...next.shape.handles[this.handleId], bindingId: this.draggedBindingId,
bindingId: this.draggedBindingId, },
}, },
} })
const target = this.app.page.shapes[draggedBinding.toId]
const targetUtils = TLDR.getShapeUtil(target)
const utils = shapeUtils[TDShapeType.Arrow]
const arrowChange = utils.onBindingChange(
next.shape,
draggedBinding,
target,
targetUtils.getBounds(target),
targetUtils.getExpandedBounds(target),
targetUtils.getCenter(target)
)
if (arrowChange) {
Object.assign(next.shape, arrowChange)
}
} else { } else {
// Remove the dragging point binding
this.didBind = this.didBind || false this.didBind = this.didBind || false
const currentBindingId = shape.handles[this.handleId].bindingId const currentBindingId = shape.handles[this.handleId].bindingId
if (currentBindingId) { if (currentBindingId !== undefined) {
next.bindings = { next.bindings[currentBindingId] = undefined
...next.bindings, next.shape = Utils.deepMerge(next.shape, {
[currentBindingId]: undefined, handles: {
} [this.handleId]: {
next.shape.handles = { bindingId: undefined,
...next.shape.handles, },
[this.handleId]: {
...next.shape.handles[this.handleId],
bindingId: undefined,
}, },
} })
} }
} }
const change = TLDR.getShapeUtil<ArrowShape>(next.shape).onHandleChange?.(
next.shape,
next.shape.handles
)
return { return {
document: { document: {
pages: { pages: {
[this.app.currentPageId]: { [this.app.currentPageId]: {
shapes: { shapes: {
[shape.id]: next.shape, [shape.id]: Utils.deepMerge(next.shape, change ?? {}),
}, },
bindings: next.bindings, bindings: next.bindings,
}, },
@ -335,6 +335,7 @@ export class ArrowSession extends BaseSession {
beforeBindings[newStartBindingId] = undefined beforeBindings[newStartBindingId] = undefined
afterBindings[newStartBindingId] = this.app.page.bindings[newStartBindingId] afterBindings[newStartBindingId] = this.app.page.bindings[newStartBindingId]
} }
return { return {
id: 'arrow', id: 'arrow',
before: { before: {
@ -392,7 +393,14 @@ export class ArrowSession extends BaseSession {
) => { ) => {
const util = TLDR.getShapeUtil<TDShape>(target.type) const util = TLDR.getShapeUtil<TDShape>(target.type)
const bindingPoint = util.getBindingPoint(target, shape, point, origin, direction, bindAnywhere) const bindingPoint = util.getBindingPoint(
target,
shape,
point, // fix dead center bug
origin,
direction,
bindAnywhere
)
// Not all shapes will produce a binding point // Not all shapes will produce a binding point
if (!bindingPoint) return if (!bindingPoint) return

File diff suppressed because it is too large Load diff

View file

@ -48,13 +48,9 @@ export class HandleSession extends BaseSession {
} }
// First update the handle's next point // First update the handle's next point
const change = TLDR.getShapeUtil(shape).onHandleChange?.( const change = TLDR.getShapeUtil(shape).onHandleChange?.(shape, {
shape, [handleId]: handle,
{ })
[handleId]: handle,
},
{ delta, shiftKey, altKey, metaKey }
)
if (!change) return if (!change) return

View file

@ -38,6 +38,7 @@ import { getTextLabelSize } from '../shared/getTextSize'
import { StraightArrow } from './components/StraightArrow' import { StraightArrow } from './components/StraightArrow'
import { CurvedArrow } from './components/CurvedArrow.tsx' import { CurvedArrow } from './components/CurvedArrow.tsx'
import { LabelMask } from '../shared/LabelMask' import { LabelMask } from '../shared/LabelMask'
import { TLDR } from '~state/TLDR'
type T = ArrowShape type T = ArrowShape
type E = HTMLDivElement type E = HTMLDivElement
@ -413,77 +414,6 @@ export class ArrowUtil extends TDShapeUtil<T, E> {
return this return this
} }
onBindingChange = (
shape: T,
binding: TDBinding,
target: TDShape,
targetBounds: TLBounds,
expandedBounds: TLBounds,
center: number[]
): Partial<T> | void => {
const handle = shape.handles[binding.handleId as keyof ArrowShape['handles']]
let handlePoint = Vec.sub(
Vec.add(
[expandedBounds.minX, expandedBounds.minY],
Vec.mulV(
[expandedBounds.width, expandedBounds.height],
Vec.rotWith(binding.point, [0.5, 0.5], target.rotation || 0)
)
),
shape.point
)
if (binding.distance) {
const intersectBounds = Utils.expandBounds(targetBounds, binding.distance)
// The direction vector starts from the arrow's opposite handle
const origin = Vec.add(
shape.point,
shape.handles[handle.id === 'start' ? 'end' : 'start'].point
)
// And passes through the dragging handle
const direction = Vec.uni(Vec.sub(Vec.add(handlePoint, shape.point), origin))
if (target.type === TDShapeType.Ellipse) {
const hits = intersectRayEllipse(
origin,
direction,
center,
(target as EllipseShape).radius[0] + binding.distance,
(target as EllipseShape).radius[1] + binding.distance,
target.rotation || 0
).points.sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin))
if (hits[0]) handlePoint = Vec.sub(hits[0], shape.point)
} else if (target.type === TDShapeType.Triangle) {
const points = getTrianglePoints(target.size, BINDING_DISTANCE, target.rotation).map((pt) =>
Vec.add(pt, target.point)
)
const segments = Utils.pointsToLineSegments(points, true)
const hits = segments
.map((segment) => intersectRayLineSegment(origin, direction, segment[0], segment[1]))
.filter((intersection) => intersection.didIntersect)
.flatMap((intersection) => intersection.points)
.sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin))
if (hits[0]) handlePoint = Vec.sub(hits[0], shape.point)
} else {
let hits = intersectRayBounds(origin, direction, intersectBounds, target.rotation)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin))
if (hits.length < 2) {
hits = intersectRayBounds(origin, Vec.neg(direction), intersectBounds)
.filter((int) => int.didIntersect)
.map((int) => int.points[0])
.sort((a, b) => Vec.dist(a, origin) - Vec.dist(b, origin))
}
if (hits[0]) handlePoint = Vec.sub(hits[0], shape.point)
}
}
return this.onHandleChange(shape, {
[handle.id]: {
...handle,
point: Vec.toFixed(handlePoint),
},
})
}
onHandleChange = (shape: T, handles: Partial<T['handles']>): Partial<T> | void => { onHandleChange = (shape: T, handles: Partial<T['handles']>): Partial<T> | void => {
let nextHandles = Utils.deepMerge<ArrowShape['handles']>(shape.handles, handles) let nextHandles = Utils.deepMerge<ArrowShape['handles']>(shape.handles, handles)
let nextBend = shape.bend let nextBend = shape.bend
@ -549,6 +479,9 @@ export class ArrowUtil extends TDShapeUtil<T, E> {
handle.point = Vec.toFixed(Vec.sub(handle.point, offset)) handle.point = Vec.toFixed(Vec.sub(handle.point, offset))
}) })
nextShape.point = Vec.toFixed(Vec.add(nextShape.point, offset)) nextShape.point = Vec.toFixed(Vec.add(nextShape.point, offset))
if (Vec.isEqual(nextShape.point, [0, 0])) {
throw Error('here!')
}
} }
return nextShape return nextShape
} }

View file

@ -159,20 +159,20 @@ export abstract class TDShapeUtil<T extends TDShape, E extends Element = any> ex
onChildrenChange?: (shape: T, children: TDShape[]) => Partial<T> | void onChildrenChange?: (shape: T, children: TDShape[]) => Partial<T> | void
onBindingChange?: ( // onBindingChange?: (
shape: T, // shape: T,
binding: TDBinding, // binding: TDBinding,
target: TDShape, // target: TDShape,
targetBounds: TLBounds, // targetBounds: TLBounds,
expandedBounds: TLBounds, // targetExpandedBounds: TLBounds,
center: number[] // targetCenter: number[],
) => Partial<T> | void // oppositeShape?: TDShape,
// oppositeShapeTargetBounds?: TLBounds,
// oppositeShapeTargetExpandedBounds?: TLBounds,
// oppositeShapeTargetCenter?: number[]
// ) => Partial<T> | void
onHandleChange?: ( onHandleChange?: (shape: T, handles: Partial<T['handles']>) => Partial<T> | void
shape: T,
handles: Partial<T['handles']>,
info: Partial<TLPointerInfo>
) => Partial<T> | void
onRightPointHandle?: ( onRightPointHandle?: (
shape: T, shape: T,

View file

@ -295,7 +295,7 @@ export interface TDBaseShape extends TLShape {
style: ShapeStyles style: ShapeStyles
type: TDShapeType type: TDShapeType
label?: string label?: string
handles?: Record<string, TldrawHandle> handles?: Record<string, TDHandle>
} }
export interface DrawShape extends TDBaseShape { export interface DrawShape extends TDBaseShape {
@ -305,7 +305,7 @@ export interface DrawShape extends TDBaseShape {
} }
// The extended handle (used for arrows) // The extended handle (used for arrows)
export interface TldrawHandle extends TLHandle { export interface TDHandle extends TLHandle {
canBind?: boolean canBind?: boolean
bindingId?: string bindingId?: string
} }
@ -336,9 +336,9 @@ export interface ArrowShape extends TDBaseShape {
type: TDShapeType.Arrow type: TDShapeType.Arrow
bend: number bend: number
handles: { handles: {
start: TldrawHandle start: TDHandle
bend: TldrawHandle bend: TDHandle
end: TldrawHandle end: TDHandle
} }
decorations?: { decorations?: {
start?: Decoration start?: Decoration

23414
yarn.lock

File diff suppressed because it is too large Load diff